
Commit e8aa105

Author: Prashant Kumar

Divide iree_utils and do module imports on function calls.
1 parent 08eda2c commit e8aa105

File tree: 147 files changed, +4710 -2667 lines changed
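The "module imports on function calls" half of this change moves heavyweight imports out of module scope and into the functions that need them, so importing shark stays cheap and a backend is only loaded when its code path actually runs. A minimal sketch of the pattern, assuming an IREE compile path (the function name and call site below are illustrative, not the actual iree_utils API):

    # Sketch of a deferred (call-time) import; names are illustrative.
    def compile_for_device(mlir_module: str, device: str):
        # Importing here, rather than at the top of the module, means the
        # IREE compiler package is only loaded when compilation is requested.
        import iree.compiler as ireec  # hypothetical use of the real package

        return ireec.compile_str(mlir_module, target_backends=[device])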


.github/workflows/nightly.yml

Lines changed: 1 addition & 2 deletions
@@ -53,15 +53,14 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install flake8 pytest yapf toml
+          python -m pip install flake8 pytest toml
           if [ -f requirements.txt ]; then pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/SHARK-Runtime/releases; fi
       - name: Lint with flake8
         run: |
           # stop the build if there are Python syntax errors or undefined names
           flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
           # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
           flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
-          yapf -i --style .style.yapf shark/*.py
 
       - name: Build and validate the package
         run: |

.github/workflows/python-format.yml

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+name: black-formatter
+on: [pull_request]
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v1
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v1
+        with:
+          python-version: 3.10.5
+      - name: Install Black
+        run: pip install black
+      - name: Run formatter check on the entire project.
+        run: black --line-length 80 --check .
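For reference, the same check this workflow performs can be reproduced locally with "pip install black" followed by "black --line-length 80 --check ."; dropping --check makes black rewrite the files in place instead of only reporting which ones would change.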

.github/workflows/test-models.yml

Lines changed: 1 addition & 2 deletions
@@ -37,15 +37,14 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install flake8 pytest yapf toml
+          python -m pip install flake8 pytest toml
 
       - name: Lint with flake8
         run: |
           # stop the build if there are Python syntax errors or undefined names
           flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude lit.cfg.py
           # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
           flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude lit.cfg.py
-          yapf -i --style .style.yapf shark/*.py
 
       - name: Validate Models
         run: |

benchmarks/hf_model_benchmark.py

Lines changed: 4 additions & 4 deletions
@@ -6,16 +6,16 @@
     "--model_name",
     type=str,
     required=True,
-    help=
-    "Specifies name of HF model to benchmark. (For exmaple \"microsoft/MiniLM-L12-H384-uncased\""
+    help='Specifies name of HF model to benchmark. (For exmaple "microsoft/MiniLM-L12-H384-uncased"',
 )
 load_args, unknown = parser.parse_known_args()
 
 if __name__ == "__main__":
     model_name = load_args.model_name
     test_input = torch.randint(2, (1, 128))
-    shark_module = SharkHFBenchmarkRunner(model_name, (test_input,),
-                                          jit_trace=True)
+    shark_module = SharkHFBenchmarkRunner(
+        model_name, (test_input,), jit_trace=True
+    )
     shark_module.benchmark_c()
     shark_module.benchmark_python((test_input,))
     shark_module.benchmark_torch(test_input)

benchmarks/hf_transformer.py

Lines changed: 72 additions & 28 deletions
@@ -1,16 +1,19 @@
 import torch
-from shark.shark_runner import SharkBenchmarkRunner
+from shark.shark_benchmark_runner import SharkBenchmarkRunner
 from shark.parser import shark_args
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
-from onnxruntime.transformers.benchmark import run_pytorch, run_tensorflow, run_onnxruntime
+from onnxruntime.transformers.benchmark import (
+    run_pytorch,
+    run_tensorflow,
+    run_onnxruntime,
+)
 from onnxruntime.transformers.huggingface_models import MODELS
 from onnxruntime.transformers.benchmark_helper import ConfigModifier, Precision
 import os
 import psutil
 
 
 class OnnxFusionOptions(object):
-
     def __init__(self):
         self.disable_gelu = False
         self.disable_layer_norm = False
@@ -25,17 +28,13 @@ def __init__(self):
 
 
 class HuggingFaceLanguage(torch.nn.Module):
-
     def __init__(self, hf_model_name):
         super().__init__()
         self.model = AutoModelForSequenceClassification.from_pretrained(
             hf_model_name,  # The pretrained model.
-            num_labels=
-            2,  # The number of output labels--2 for binary classification.
-            output_attentions=
-            False,  # Whether the model returns attentions weights.
-            output_hidden_states=
-            False,  # Whether the model returns all hidden-states.
+            num_labels=2,  # The number of output labels--2 for binary classification.
+            output_attentions=False,  # Whether the model returns attentions weights.
+            output_hidden_states=False,  # Whether the model returns all hidden-states.
             torchscript=True,
         )
 
@@ -62,8 +61,16 @@ def __init__(
         )
         self.model_name = model_name
         model = HuggingFaceLanguage(model_name)
-        SharkBenchmarkRunner.__init__(self, model, input, dynamic, self.device,
-                                      jit_trace, from_aot, frontend)
+        SharkBenchmarkRunner.__init__(
+            self,
+            model,
+            input,
+            dynamic,
+            self.device,
+            jit_trace,
+            from_aot,
+            frontend,
+        )
 
     def benchmark_torch(self, inputs):
         use_gpu = self.device == "gpu"
@@ -74,10 +81,20 @@ def benchmark_torch(self, inputs):
         sequence_lengths = [inputs.shape[-1]]
         cache_dir = os.path.join(".", "cache_models")
         verbose = False
-        result = run_pytorch(use_gpu, [self.model_name], None, config_modifier,
-                             Precision.FLOAT32, num_threads, batch_sizes,
-                             sequence_lengths, shark_args.num_iterations, False,
-                             cache_dir, verbose)
+        result = run_pytorch(
+            use_gpu,
+            [self.model_name],
+            None,
+            config_modifier,
+            Precision.FLOAT32,
+            num_threads,
+            batch_sizes,
+            sequence_lengths,
+            shark_args.num_iterations,
+            False,
+            cache_dir,
+            verbose,
+        )
         print(
             f"ONNX Pytorch-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
         )
@@ -92,10 +109,19 @@ def benchmark_tf(self, inputs):
         sequence_lengths = [inputs.shape[-1]]
         cache_dir = os.path.join(".", "cache_models")
         verbose = False
-        result = run_tensorflow(use_gpu, [self.model_name], None,
-                                config_modifier, Precision.FLOAT32, num_threads,
-                                batch_sizes, sequence_lengths,
-                                shark_args.num_iterations, cache_dir, verbose)
+        result = run_tensorflow(
+            use_gpu,
+            [self.model_name],
+            None,
+            config_modifier,
+            Precision.FLOAT32,
+            num_threads,
+            batch_sizes,
+            sequence_lengths,
+            shark_args.num_iterations,
+            cache_dir,
+            verbose,
+        )
         print(
             f"ONNX TF-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
         )
@@ -105,7 +131,8 @@ def benchmark_onnx(self, inputs):
         print(
             f"{self.model_name} is currently not supported in ORT's HF. Check \
 https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/huggingface_models.py \
-            for currently supported models. Exiting benchmark ONNX.")
+            for currently supported models. Exiting benchmark ONNX."
+            )
             return
         use_gpu = self.device == "gpu"
         num_threads = psutil.cpu_count(logical=False)
@@ -121,17 +148,34 @@ def benchmark_onnx(self, inputs):
         use_raw_attention_mask = True
         model_fusion_statistics = {}
         overwrite = False
-        model_source = "pt" #Either "pt" or "tf"
+        model_source = "pt"  # Either "pt" or "tf"
         provider = None
         config_modifier = ConfigModifier(None)
        onnx_args = OnnxFusionOptions()
         result = run_onnxruntime(
-            use_gpu, provider, [self.model_name], None, config_modifier,
-            Precision.FLOAT32, num_threads, batch_sizes, sequence_lengths,
-            shark_args.num_iterations, input_counts, optimize_onnx,
-            validate_onnx, cache_dir, onnx_dir, verbose, overwrite,
-            disable_ort_io_binding, use_raw_attention_mask,
-            model_fusion_statistics, model_source, onnx_args)
+            use_gpu,
+            provider,
+            [self.model_name],
+            None,
+            config_modifier,
+            Precision.FLOAT32,
+            num_threads,
+            batch_sizes,
+            sequence_lengths,
+            shark_args.num_iterations,
+            input_counts,
+            optimize_onnx,
+            validate_onnx,
+            cache_dir,
+            onnx_dir,
+            verbose,
+            overwrite,
+            disable_ort_io_binding,
+            use_raw_attention_mask,
+            model_fusion_statistics,
+            model_source,
+            onnx_args,
+        )
         print(
             f"ONNX ORT-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
         )

0 commit comments