Skip to content

Commit f7196e0

Browse files
authored
Megatron LM custom flow (#152)
This PR introduces a new custom workflow for generating performance reports for Megatron LM, utilizing trace data and custom performance model adjustments. Introduces a new script for consolidating trace data and outputting Excel reports. Updates performance model mappings to include FusedAttnFunc for SDPA operations.
1 parent f97d5a9 commit f7196e0

File tree

1 file changed

+105
-0
lines changed

1 file changed

+105
-0
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import argparse
2+
import json
3+
import pandas as pd
4+
from TraceLens import TraceToTree
5+
from TraceLens import TreePerfAnalyzer
6+
from TraceLens.PerfModel import dict_cat2names
7+
8+
from TraceLens.PerfModel import SDPA
9+
class transformer_engine_attention(SDPA):
    """Custom SDPA perf model for TransformerEngine's ``FusedAttnFunc`` op.

    Context: ``FusedAttnFunc`` is a PyTorch extension for the attention kernel.
    Unfortunately, its args do not include a bool flag for ``is_causal``;
    instead it takes a str arg which is not recorded in the trace.

    Solution: Based on the LLM use case we make the assumption that the
    attention is always causal. Since this might not be the case for other
    use cases, we don't add this natively to the perf model and instead
    add it here.
    """

    @staticmethod
    def get_param_details(event):
        """Extract SDPA perf-model parameters from a FusedAttnFunc trace event.

        Args:
            event: Trace event dict; reads ``event['args']['Input Dims']``.

        Returns:
            dict with keys ``B``, ``N_Q``, ``H_Q``, ``N_KV``, ``H_KV``,
            ``d_h``, ``dropout``, ``causal``, ``flash_impl``.

        Raises:
            ValueError: if no 4-D input tensor (the query) is found, or if
                the key and value shapes differ.
        """
        # ref TransformerEngine/transformer_engine/pytorch/cpp_extensions/fused_attn.py
        # https://github.com/NVIDIA/TransformerEngine/blob/51cd441501e8e6dee18c00056f008e1b53b89ebd/transformer_engine/pytorch/attention/dot_product_attention/backends.py#L881
        input_dims = event['args']['Input Dims']
        # The first 4-D entry is the query; q, k, v are assumed to be recorded
        # consecutively (see slice below) — TODO confirm against the trace format.
        q_idx = next((i for i, dim in enumerate(input_dims) if len(dim) == 4), None)
        if q_idx is None:
            # Raise instead of assert: asserts are stripped under `python -O`.
            raise ValueError("query index not found")
        q_shape, k_shape, v_shape = input_dims[q_idx: q_idx + 3]
        B, N_Q, H_Q, d_h = q_shape
        if k_shape != v_shape:
            raise ValueError(f"Key and value shapes are different: {k_shape} != {v_shape}")
        _, N_KV, H_KV, _ = k_shape
        # Hard-coded per the class docstring: causal attention assumed for LLMs,
        # and dropout is not recoverable from the trace so it is taken as 0.
        is_causal = True
        dropout_p = 0.0
        return {"B": B, "N_Q": N_Q, "H_Q": H_Q, "N_KV": N_KV, "H_KV": H_KV, "d_h": d_h,
                "dropout": dropout_p, "causal": is_causal, "flash_impl": False}
39+
40+
def _load_gpu_arch(path):
    """Load the optional GPU architecture JSON file; return None when no path is given."""
    if not path:
        return None
    with open(path, 'r') as f:
        return json.load(f)


def _build_op_dfs(perf_analyzer, agg_metrics, dict_name_to_custom_perf_model):
    """Build per-op-category summarized perf-metric DataFrames, keyed by sheet name.

    GEMM and elementwise categories get one combined table each; every other
    category (e.g. attention, conv) gets separate forward and backward tables.
    """
    op_dfs = {}
    for op_cat, op_names in dict_cat2names.items():
        # Filter events belonging to the current category.
        op_events = [event for event in perf_analyzer.tree.events if event['name'] in op_names]

        if op_cat in ['GEMM', 'UnaryElementwise', 'BinaryElementwise']:
            # One table for the whole category. NOTE(review): bwd=False here —
            # presumably these ops appear the same in fwd and bwd passes, so a
            # single table suffices; confirm against TreePerfAnalyzer semantics.
            df_ops = perf_analyzer.build_df_perf_metrics(op_events, bwd=False, include_kernel_names=True, include_args=True,
                                                         dict_name_to_perf_model=dict_name_to_custom_perf_model)
            op_dfs[op_cat] = perf_analyzer.summarize_df_perf_metrics(df_ops, agg_metrics)
        else:
            # All remaining categories (attention, conv, ...): separate tables
            # for the forward and backward passes.
            df_ops_fwd = perf_analyzer.build_df_perf_metrics(op_events, bwd=False, include_kernel_names=True, include_args=True,
                                                             dict_name_to_perf_model=dict_name_to_custom_perf_model)
            op_dfs[f"{op_cat}_fwd"] = perf_analyzer.summarize_df_perf_metrics(df_ops_fwd, agg_metrics)
            df_ops_bwd = perf_analyzer.build_df_perf_metrics(op_events, bwd=True, include_kernel_names=True, include_args=True,
                                                             dict_name_to_perf_model=dict_name_to_custom_perf_model)
            op_dfs[f"{op_cat}_bwd"] = perf_analyzer.summarize_df_perf_metrics(df_ops_bwd, agg_metrics)
    return op_dfs


def main():
    """Process a JSON trace profile and write performance report tables to Excel.

    Loads the trace (and an optional GPU architecture description), registers
    a custom perf model for TransformerEngine's FusedAttnFunc, builds summary
    tables per op category, and writes each table to its own sheet of one
    Excel workbook.
    """
    parser = argparse.ArgumentParser(description='Process a JSON trace profile and generate performance report tables.')
    parser.add_argument('--profile_json_path', type=str, required=True, help='Path to the profile.json file')
    parser.add_argument('--output_xlsx_path', type=str, required=True, help='Path to the output Excel file')
    parser.add_argument('--gpu_arch_json_path', type=str, default=None, help='Path to the GPU architecture JSON file')
    args = parser.parse_args()

    # Load the arch json (None lets TreePerfAnalyzer use its default arch).
    gpu_arch_json = _load_gpu_arch(args.gpu_arch_json_path)
    perf_analyzer = TreePerfAnalyzer.from_file(profile_filepath=args.profile_json_path, arch=gpu_arch_json)

    agg_metrics = ['mean', 'median', 'std', 'min', 'max']

    # Generate base DataFrames. df_kernel_launchers is only an intermediate
    # input for the two derived tables; it is not written out itself.
    df_gpu_timeline = perf_analyzer.get_df_gpu_timeline()
    df_kernel_launchers = perf_analyzer.get_df_kernel_launchers()
    df_kernel_launchers_summary = perf_analyzer.get_df_kernel_launchers_summary(df_kernel_launchers)
    df_kernel_launchers_unique_args = perf_analyzer.get_df_kernel_launchers_unique_args(df_kernel_launchers,
                                                                                        agg_metrics=agg_metrics,
                                                                                        include_pct=True)

    # Register FusedAttnFunc as an SDPA op and map it to the custom perf model.
    # Membership guard keeps the global mutation idempotent if main() runs twice.
    if 'FusedAttnFunc' not in dict_cat2names['SDPA']:
        dict_cat2names['SDPA'].append('FusedAttnFunc')
    dict_name_to_custom_perf_model = {'FusedAttnFunc': transformer_engine_attention}

    # Build the op-specific DataFrames, one entry per output sheet.
    op_dfs = _build_op_dfs(perf_analyzer, agg_metrics, dict_name_to_custom_perf_model)

    # Write all DataFrames to separate sheets in an Excel workbook.
    with pd.ExcelWriter(args.output_xlsx_path) as writer:
        df_gpu_timeline.to_excel(writer, sheet_name='gpu_timeline', index=False)
        df_kernel_launchers_summary.to_excel(writer, sheet_name='kernel_launchers_summary', index=False)
        df_kernel_launchers_unique_args.to_excel(writer, sheet_name='kernel_launchers_unique_args', index=False)

        # Write each op category DataFrame.
        for sheet_name, df in op_dfs.items():
            df.to_excel(writer, sheet_name=sheet_name, index=False)

    print(f"DataFrames successfully written to {args.output_xlsx_path}")
103+
104+
# Script entry point: run the report generation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)