
Commit 6f8c194 (merge commit, parents 2bb979e + 5523ca5)


43 files changed: +902 additions, -1713 deletions

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
@@ -125,7 +125,7 @@ jobs:
     needs: [Build-Camb]
     runs-on: github-poc-ci
     env:
-      MLU_REQUESTS: 1
+      MLU_REQUESTS: 4
     steps:
       - name: Run-test
         run: |
@@ -210,7 +210,7 @@ jobs:
     needs: [Build-Camb-Pt211]
     runs-on: github-poc-ci
    env:
-      MLU_REQUESTS: 1
+      MLU_REQUESTS: 4
     steps:
       - name: Run-test
         run: |
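The workflow change bumps MLU_REQUESTS from 1 to 4 in the env block of both Camb test jobs. The diff does not show how the test scripts consume the variable; assuming it is read as an ordinary environment variable that caps how many MLU cards a run asks for, a minimal sketch of that contract might look like the following (the helper name and default are illustrative, not taken from the repository):

import os

def requested_mlu_count(default: int = 1) -> int:
    # Hypothetical consumer of the MLU_REQUESTS variable exported above.
    # The real CI scripts may interpret it differently; after this commit the
    # workflow sets MLU_REQUESTS=4 for both Camb test jobs.
    return int(os.environ.get("MLU_REQUESTS", str(default)))

if __name__ == "__main__":
    print(f"This job requests {requested_mlu_count()} MLU card(s)")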

dicp/dicp/vendor/AscendGraph/codegen/ascend.py

Lines changed: 2 additions & 24 deletions
@@ -61,7 +61,6 @@ def __init__(self, graph, aten_graph=None, folder=None, graph_key=None):
         self.py_output_names = []
         self.graph_output_names = []
         self.build_options = []
-        self.output_nodes = []
 
         self.folder = folder
         self.graph_key = graph_key
@@ -195,26 +194,8 @@ def parse_outputs(self):
             self.py_output_names.append(str(node))
         self.output_args = real_output_args
 
-        if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0:
-            for output in self.output_args:
-                info = {}
-                info['format'] = 'ND'
-                if hasattr(output, 'meta'):
-                    output = output.meta['val']
-                if isinstance(output, torch.SymInt):
-                    info['data_type'] = 'INT32'
-                elif isinstance(output, torch.SymBool):
-                    info['data_type'] = 'BOOL'
-                info['data_type'] = get_ascend_dtype(output.dtype)
-                self.output_nodes.append(info)
         if len(self.assign_args) > 0:
             self.graph_output_names.extend(list(zip(*self.assign_args))[0])
-            for item in self.assign_args:
-                index = item[1]
-                info = {}
-                info['format'] = self.data_nodes[index]['format']
-                info['data_type'] = self.data_nodes[index]['data_type']
-                self.output_nodes.append(info)
 
     def gen_import_code(self):
         self.import_code.splice(
@@ -225,7 +206,7 @@ def gen_import_code(self):
             import random
             from torch import empty_strided, as_strided, device
             from dicp.dynamo_bridge.compile import AsyncCompileKernel
-            from dicp.vendor.AscendGraph.compile_job import AscendGECompileAclRunJob, AscendGECompileGERunJob
+            from dicp.vendor.AscendGraph.compile_job import AscendCompileJob
 
             aten = torch.ops.aten
             assert_size_stride = torch._C._dynamo.guards.assert_size_stride
@@ -480,19 +461,16 @@ def gen_graph_json(self):
             "build_options": self.build_options,
             "data_nodes": self.data_nodes,
             "common_nodes": self.common_nodes,
-            "output_nodes": self.output_nodes,
         }
         self.remove_symint(graph)
         return json.dumps(graph)
 
     def gen_compile_graph_code(self):
         compile_graph_code = IndentedBuffer()
         graph_json = self.gen_graph_json()
-        compile_job_type = os.environ.get("DICP_ASCEND_COMPILE_JOB_TYPE", "AscendGECompileGERunJob")
-        assert compile_job_type in ["AscendGECompileGERunJob", "AscendGECompileAclRunJob"]
         compile_graph_code.splice(
             f"""
-            ascend_compile_job = {compile_job_type}('''{graph_json}''')
+            ascend_compile_job = AscendCompileJob('''{graph_json}''')
             async_compile = AsyncCompileKernel()
             kernel_cpp_0 = async_compile.compile_kernel(ascend_compile_job)
             """, strip=True

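Taken together, the ascend.py hunks remove the output_nodes bookkeeping from codegen and hard-wire the generated wrapper to a single AscendCompileJob, instead of choosing between AscendGECompileGERunJob and AscendGECompileAclRunJob through the DICP_ASCEND_COMPILE_JOB_TYPE environment variable. A rough sketch of the wrapper code that gen_import_code and gen_compile_graph_code now emit, with the graph JSON abbreviated to the keys visible in this diff:

# Sketch only: the real JSON comes from gen_graph_json and carries the full
# node lists; after this commit it no longer includes an "output_nodes" entry.
from dicp.dynamo_bridge.compile import AsyncCompileKernel
from dicp.vendor.AscendGraph.compile_job import AscendCompileJob

graph_json = '''{"build_options": [], "data_nodes": [], "common_nodes": []}'''

ascend_compile_job = AscendCompileJob(graph_json)  # single, fixed compile-job class
async_compile = AsyncCompileKernel()
kernel_cpp_0 = async_compile.compile_kernel(ascend_compile_job)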