@@ -61,7 +61,6 @@ def __init__(self, graph, aten_graph=None, folder=None, graph_key=None):
61
61
self .py_output_names = []
62
62
self .graph_output_names = []
63
63
self .build_options = []
64
- self .output_nodes = []
65
64
66
65
self .folder = folder
67
66
self .graph_key = graph_key
@@ -195,26 +194,8 @@ def parse_outputs(self):
195
194
self .py_output_names .append (str (node ))
196
195
self .output_args = real_output_args
197
196
198
- if len (self .sym_in_args ) > 0 or len (self .sym_to_inputs ) > 0 :
199
- for output in self .output_args :
200
- info = {}
201
- info ['format' ] = 'ND'
202
- if hasattr (output , 'meta' ):
203
- output = output .meta ['val' ]
204
- if isinstance (output , torch .SymInt ):
205
- info ['data_type' ] = 'INT32'
206
- elif isinstance (output , torch .SymBool ):
207
- info ['data_type' ] = 'BOOL'
208
- info ['data_type' ] = get_ascend_dtype (output .dtype )
209
- self .output_nodes .append (info )
210
197
if len (self .assign_args ) > 0 :
211
198
self .graph_output_names .extend (list (zip (* self .assign_args ))[0 ])
212
- for item in self .assign_args :
213
- index = item [1 ]
214
- info = {}
215
- info ['format' ] = self .data_nodes [index ]['format' ]
216
- info ['data_type' ] = self .data_nodes [index ]['data_type' ]
217
- self .output_nodes .append (info )
218
199
219
200
def gen_import_code (self ):
220
201
self .import_code .splice (
@@ -225,7 +206,7 @@ def gen_import_code(self):
225
206
import random
226
207
from torch import empty_strided, as_strided, device
227
208
from dicp.dynamo_bridge.compile import AsyncCompileKernel
228
- from dicp.vendor.AscendGraph.compile_job import AscendGECompileAclRunJob, AscendGECompileGERunJob
209
+ from dicp.vendor.AscendGraph.compile_job import AscendCompileJob
229
210
230
211
aten = torch.ops.aten
231
212
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
@@ -480,19 +461,16 @@ def gen_graph_json(self):
480
461
"build_options" : self .build_options ,
481
462
"data_nodes" : self .data_nodes ,
482
463
"common_nodes" : self .common_nodes ,
483
- "output_nodes" : self .output_nodes ,
484
464
}
485
465
self .remove_symint (graph )
486
466
return json .dumps (graph )
487
467
488
468
def gen_compile_graph_code (self ):
489
469
compile_graph_code = IndentedBuffer ()
490
470
graph_json = self .gen_graph_json ()
491
- compile_job_type = os .environ .get ("DICP_ASCEND_COMPILE_JOB_TYPE" , "AscendGECompileGERunJob" )
492
- assert compile_job_type in ["AscendGECompileGERunJob" , "AscendGECompileAclRunJob" ]
493
471
compile_graph_code .splice (
494
472
f"""
495
- ascend_compile_job = { compile_job_type } ('''{ graph_json } ''')
473
+ ascend_compile_job = AscendCompileJob ('''{ graph_json } ''')
496
474
async_compile = AsyncCompileKernel()
497
475
kernel_cpp_0 = async_compile.compile_kernel(ascend_compile_job)
498
476
""" , strip = True
0 commit comments