Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
d9fde3a
GEMM stuff works
gabeweisz Apr 10, 2025
dcf825e
merge
gabeweisz Apr 10, 2025
186eef8
fix potential leak
gabeweisz Apr 10, 2025
054f347
add basic documentation
gabeweisz Apr 10, 2025
bc05162
update doc to explain how to use the code
gabeweisz Apr 10, 2025
8c40a27
merged main
gabeweisz Apr 11, 2025
dcbc13c
starting steps for treeperf for jax
gabeweisz Apr 11, 2025
301e82c
Merge branch 'main' of https://github.com/AMD-AIG-AIMA/TraceLens into…
gabeweisz Apr 21, 2025
9824fae
change Jax memset to count as compute to be more consistent with the …
gabeweisz Apr 21, 2025
f5a1b4e
change factory pattern to member
gabeweisz Apr 21, 2025
d929a05
fix bug with save preprocessed
gabeweisz Apr 22, 2025
fe609bc
add param to save
gabeweisz Apr 22, 2025
079e566
add metadata processor
gabeweisz Apr 23, 2025
96beb20
rename func for metadata extraction
gabeweisz Apr 23, 2025
c99a789
Merge branch 'main' of https://github.com/amd-aig-aima/TraceLens into…
gabeweisz Apr 25, 2025
3e02d4f
add code for jax-based tree in tracetotree
gabeweisz Apr 29, 2025
b410d21
update docs
gabeweisz Apr 29, 2025
8e37ba8
update exports
gabeweisz Apr 29, 2025
9d0a501
merge main
gabeweisz Apr 29, 2025
51ce3f6
Apply suggestions from code review
gabeweisz Apr 29, 2025
94ee162
Update TraceLens/TreePerf/jax_analyses.py
gabeweisz Apr 29, 2025
f31772f
Update TraceLens/util.py
gabeweisz Apr 29, 2025
573b965
fix trace_to_tree for jax
gabeweisz Apr 30, 2025
431c32b
remove premature jax integration
gabeweisz Apr 30, 2025
1b18805
tree perf updates will go in a different PR
gabeweisz Apr 30, 2025
1cc2a4e
Merge branch 'gw_jax_tree' of https://github.com/amd-aig-aima/TraceLe…
gabeweisz Apr 30, 2025
f1a996b
remove premature jax integration
gabeweisz Apr 30, 2025
0695fe1
fix preprocessing
gabeweisz Apr 30, 2025
30a635a
give the metadata skipper a better name
gabeweisz Apr 30, 2025
072800b
fix doc
gabeweisz Apr 30, 2025
730f77d
wip
gabeweisz May 1, 2025
47a15d0
add workaround for older python
gabeweisz May 1, 2025
4f926ed
add fp32 and fp64 support to gemm extraction
gabeweisz May 1, 2025
b7b0e6b
fix gemm type generation
gabeweisz May 2, 2025
ef7f667
support types directly in operand
gabeweisz May 2, 2025
34889d7
gemm_perf_metrics just about works
gabeweisz May 2, 2025
55bcfc2
use raw strings for all regex, change 'B' to 'GEMM Batch' for all GEMMs
gabeweisz May 5, 2025
25036a0
universally list operator batch as 'Op B'
gabeweisz May 5, 2025
8581318
Merge branch 'main' into gw_jax_tree
ajassani May 5, 2025
45cfe78
update doc and add arch to helper
gabeweisz May 5, 2025
ad22775
remove unwanted change
gabeweisz May 5, 2025
d253230
merge
gabeweisz May 5, 2025
98a768c
merge other copilot fixes
gabeweisz May 5, 2025
d599279
fix error due to changing f string to raw string
gabeweisz May 5, 2025
2bf32cb
change to rf string
gabeweisz May 5, 2025
6dc34be
revert Op B to B
ajassani May 6, 2025
246100e
Merge branch 'main' of https://github.com/amd-aig-aima/TraceLens into…
gabeweisz May 6, 2025
fe7c793
restore jax parameter
gabeweisz May 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions TraceLens/PerfModel/perf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,12 @@ def dim_efficiency_func(num_cus, M, N, K, mt_m, mt_n, depth_u):
M_pad = math.ceil(M / mt_m) * mt_m
N_pad = math.ceil(N / mt_n) * mt_n
tile_eff = (M * N) / (M_pad * N_pad)

# Wave quantization
num_blocks = M_pad * N_pad // (mt_m * mt_n)
num_rounds = math.ceil(num_blocks / num_cus)
wq_eff = num_blocks / (num_rounds * num_cus)

# Net dimensional efficiency = tile efficiency * wave efficiency
dim_eff = tile_eff * wq_eff
return {
Expand All @@ -167,7 +167,7 @@ def dim_efficiency_func(num_cus, M, N, K, mt_m, mt_n, depth_u):
'wq_eff': wq_eff,
'dim_eff': dim_eff,
}

def dim_efficiency(self, arch_dict):
"""
args:
Expand Down
152 changes: 82 additions & 70 deletions TraceLens/Trace2Tree/trace_to_tree.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions TraceLens/TreePerf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .tree_perf import TreePerfAnalyzer
from .gpu_event_analyser import GPUEventAnalyser, PytorchGPUEventAnalyser, JaxGPUEventAnalyser
from .jax_analyses import JaxAnalyses
from .jax_analyses import JaxAnalyses, JaxProfileProcessor

__all__ = ["TreePerfAnalyzer", "GPUEventAnalyser", "PytorchGPUEventAnalyser", "JaxGPUEventAnalyser", "JaxAnalyses"]
__all__ = ["TreePerfAnalyzer", "GPUEventAnalyser", "PytorchGPUEventAnalyser", "JaxGPUEventAnalyser", "JaxAnalyses", "JaxProfileProcessor"]
254 changes: 216 additions & 38 deletions TraceLens/TreePerf/jax_analyses.py

Large diffs are not rendered by default.

87 changes: 42 additions & 45 deletions TraceLens/TreePerf/tree_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,46 +23,47 @@
import json
import gzip
from collections import defaultdict
from typing import Dict, Any
from typing import Dict, Any, Callable

# TODO: warning should show the stack as well
import warnings
import pprint
import pandas as pd
from ..PerfModel.torch_op_mapping import op_to_perf_model_class_map
from .gpu_event_analyser import GPUEventAnalyser
from .gpu_event_analyser import GPUEventAnalyser, JaxGPUEventAnalyser
from .jax_analyses import JaxAnalyses
from ..Trace2Tree.trace_to_tree import TraceToTree
from ..util import DataLoader, TraceEventUtils

class TreePerfAnalyzer:
@staticmethod
def from_file(profile_filepath, *args, **kwargs) -> "TreePerfAnalyzer":
def from_file(profile_filepath, jax: bool = False, *args, **kwargs) -> "TreePerfAnalyzer":
# Creates a TreePerfAnalyzer from the trace in the provided filepath.
# *args, **kwargs are passed to the TreePerfAnalyzer constructor.

if profile_filepath.endswith('.json'):
with open(profile_filepath, 'r') as f:
data = json.load(f)
elif profile_filepath.endswith('.gz'):
with gzip.open(profile_filepath, 'rt') as f:
data = json.load(f)
else:
raise ValueError("Profile file should be either .json or .gz")
data = DataLoader.load_data(profile_filepath)
data = data['traceEvents']

tree = TraceToTree(data['traceEvents'])
return TreePerfAnalyzer(tree, *args, **kwargs)
categorizer = TraceToTree.default_categorizer if not jax else JaxAnalyses.prepare_event_categorizer(data)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This `jax` variable does not appear to be defined at this point in the function.

data = data if not jax else TraceEventUtils.non_metadata_events(data)
tree = TraceToTree(data, event_to_category=categorizer)
return TreePerfAnalyzer(tree, jax=jax, event_to_category=categorizer, *args, **kwargs)

def __init__(self, tree: TraceToTree, add_python_func=False, arch=None):
def __init__(self, tree: TraceToTree, add_python_func=False, arch=None, jax = False, event_to_category: Callable[[dict], str] = TraceEventUtils.default_categorizer):
self.jax = jax
self.GPUEventAnalyser = GPUEventAnalyser if not jax else JaxGPUEventAnalyser
self.tree = tree
self.add_python_func = add_python_func
self.add_python_func = add_python_func
self.arch = arch
self.event_to_category = event_to_category
# we check if profile contains python func events
self.with_python_stack = next((True for event in self.tree.events if event.get('cat') == 'python_func'), False)
self.with_python_stack = next((True for event in self.tree.events if self.event_to_category(event) == 'python_func'), False)
self.tree.build_tree(add_python_func=add_python_func)

def agg_kernels_in_subtree(self, event, filter_func=None, verbose=False):
if filter_func is None:
filter_func = lambda x: True
if event.get('cat') in {'kernel', 'gpu_memcpy', 'gpu_memset'}:
if self.event_to_category(event) in {'kernel', 'gpu_memcpy', 'gpu_memset'}:
if not filter_func(event):
return 0, []
if verbose:
Expand Down Expand Up @@ -91,7 +92,7 @@ def non_data_mov_filter(event):
DATA_MOVEMENT_PATTERNS = ['at::native::direct_copy_kernel_cuda', 'transpose_']
return not any(pattern in event['name'] for pattern in DATA_MOVEMENT_PATTERNS)

def compute_perf_metrics(self, event, bwd=False,
def compute_perf_metrics(self, event, bwd=False,
non_data_mov=False, perf_model_class=None,
detail_level=0):

Expand Down Expand Up @@ -154,10 +155,10 @@ def compute_fwd_perf_metrics(self, event, non_data_mov=False):
return self.compute_perf_metrics(event, bwd=False, non_data_mov=non_data_mov)
def compute_bwd_perf_metrics(self, event, non_data_mov=False):
return self.compute_perf_metrics(event, bwd=True, non_data_mov=non_data_mov)
def build_df_perf_metrics(self, events, bwd=False,

def build_df_perf_metrics(self, events, bwd=False,
non_data_mov=False, include_kernel_names=False, include_args=False,
dict_name_to_perf_model=None,
dict_name_to_perf_model=None,
detail_level=0):
if len(events) == 0:
warnings.warn("Input list of events is empty. Returning an empty DataFrame.")
Expand All @@ -166,22 +167,18 @@ def build_df_perf_metrics(self, events, bwd=False,
list_warn_non_zero_flops_and_zero_time = []
list_no_bwd_events = []
for event in events:
metrics_event = {'cat': event['cat'], 'name': event['name'],
metrics_event = {'cat': self.event_to_category(event), 'name': event['name'],
'UID': event['UID'],
'pid': event['pid'], 'tid': event['tid'],
'external_id': event['args']['External id']}
'external_id': event['args'].get('External id')}
if include_args:
args_cols = ['Input Dims', 'Input type', 'Input Strides', 'Concrete Inputs']
for arg in args_cols:
if arg in event['args']:
metrics_event[arg] = event['args'][arg]
else:
metrics_event[arg] = None
metrics_event.update((arg, event['args'].get(arg)) for arg in args_cols)
if dict_name_to_perf_model and event['name'] in dict_name_to_perf_model:
perf_model_class = dict_name_to_perf_model[event['name']]
else:
perf_model_class = None
dict_perf_metrics = self.compute_perf_metrics(event, bwd=bwd,
dict_perf_metrics = self.compute_perf_metrics(event, bwd=bwd,
non_data_mov=non_data_mov, perf_model_class=perf_model_class,
detail_level=detail_level)
# handle warnings
Expand Down Expand Up @@ -275,7 +272,7 @@ def get_kernel_launchers(self, include_nccl=False):
# by checking if grandchildren of CPU operations are kernel events.
kernel_launchers = []
for event in self.tree.events:
if event.get('cat') != 'cpu_op':
if self.event_to_category(event) != 'cpu_op':
continue
kernel_launcher = False
# total_direct_kernel_time = 0
Expand All @@ -285,7 +282,7 @@ def get_kernel_launchers(self, include_nccl=False):
child = self.tree.events_by_uid[child_UID]
for grand_child_UID in child.get('children', []):
grand_child = self.tree.events_by_uid[grand_child_UID]
is_kernel = grand_child.get('cat') == 'kernel'
is_kernel = self.event_to_category(grand_child) == 'kernel'
is_nccl = 'nccl' in grand_child['name']
should_include = is_kernel and (include_nccl or not is_nccl)
if should_include:
Expand Down Expand Up @@ -321,7 +318,7 @@ def list_to_tuple(obj):
if id_cols:
metrics_event['pid'] = event['pid']
metrics_event['tid'] = event['tid']
metrics_event['external_id'] = event['args']['External id']
metrics_event['external_id'] = event['args'].get('External id')
if include_kernel_names:
metrics_event['kernel_names'] = event['kernel_names']
rows.append(metrics_event)
Expand Down Expand Up @@ -372,7 +369,7 @@ def get_df_kernel_launchers_summary_by_shape(df_kernel_launchers, name):
return df_agg

@staticmethod
def get_df_kernel_launchers_unique_args(df_kernel_launchers: pd.DataFrame,
def get_df_kernel_launchers_unique_args(df_kernel_launchers: pd.DataFrame,
event_name=None, agg_metrics=['mean'], include_pct=False) -> pd.DataFrame:
"""
Generate a DataFrame with unique arguments for each operation in the input DataFrame.
Expand All @@ -393,7 +390,7 @@ def get_df_kernel_launchers_unique_args(df_kernel_launchers: pd.DataFrame,
df_filtered = df_kernel_launchers[df_kernel_launchers['name'] == event_name].copy()
else:
df_filtered = df_kernel_launchers.copy()

# 1. Create string representations of the grouping columns - so we can group by them
str_col_names, actual_grouping_cols = [], []
for col in grouping_cols_original:
Expand All @@ -405,7 +402,7 @@ def get_df_kernel_launchers_unique_args(df_kernel_launchers: pd.DataFrame,
str_col_names.append(str_col_name)
if not str_col_names:
raise ValueError("No valid columns found to group by.")

# 2. Aggregate the DataFrame by the string representations of the grouping columns
agg_dict = {}
if 'total_direct_kernel_time' in df_filtered.columns:
Expand All @@ -423,7 +420,7 @@ def get_df_kernel_launchers_unique_args(df_kernel_launchers: pd.DataFrame,
df_unique_args = df_filtered.groupby(str_col_names, dropna=False, sort=False).agg(agg_dict)
df_unique_args.columns = ['_'.join(col).strip() for col in df_unique_args.columns.values]
df_unique_args.reset_index(inplace=True)

# 3. Rename columns for clarity
rename_map = {'UID_count': 'operation_count'}
for col in columns_to_keep_first:
Expand All @@ -444,7 +441,7 @@ def get_df_kernel_launchers_unique_args(df_kernel_launchers: pd.DataFrame,
# 5. Sort the DataFrame by the sum of total_direct_kernel_time
if 'total_direct_kernel_time_sum' in df_unique_args.columns:
df_unique_args = df_unique_args.sort_values(by="total_direct_kernel_time_sum", ascending=False).reset_index(drop=True)

# 6. Calculate percentage of total time and cumulative percentage if requested
if include_pct and 'total_direct_kernel_time_sum' in df_unique_args.columns:
total_duration_ms = df_unique_args['total_direct_kernel_time_sum'].sum()
Expand All @@ -453,12 +450,12 @@ def get_df_kernel_launchers_unique_args(df_kernel_launchers: pd.DataFrame,
return df_unique_args

def get_df_gpu_timeline(self):
kernel_events = [event for event in self.tree.events if event.get('cat') in {'kernel', 'gpu_memcpy', 'gpu_memset'} and event.get('tree')]
gpu_event_analyser = GPUEventAnalyser(kernel_events)
kernel_events = [event for event in self.tree.events if self.event_to_category(event) in {'kernel', 'gpu_memcpy', 'gpu_memset'} and event.get('tree')]
gpu_event_analyser = self.GPUEventAnalyser(kernel_events)
df = gpu_event_analyser.get_breakdown_df()
return df

def get_kernel_details(self, kernel_event,
def get_kernel_details(self, kernel_event,
launcher_detail=False, cpu_op_detail = True, nn_module_detail=False):
"""
Extract detailed information for a given kernel event.
Expand All @@ -483,7 +480,7 @@ def list_to_tuple(obj):
return tuple(list_to_tuple(item) for item in obj) if isinstance(obj, list) else obj

# Verify that the event is a kernel event.
if kernel_event.get('cat') != 'kernel':
if self.event_to_category(kernel_event) != 'kernel':
return None

kernel_details = {
Expand All @@ -507,7 +504,7 @@ def list_to_tuple(obj):
cpu_op = None
evt = launcher
while evt:
if evt.get('cat') == 'cpu_op':
if self.event_to_category(evt) == 'cpu_op':
cpu_op = evt
break
evt = self.tree.get_parent_event(evt)
Expand Down Expand Up @@ -542,7 +539,7 @@ def list_to_tuple(obj):
# Attempt to find the nn.Module parent event.
evt = kernel_event
while evt:
if evt.get('cat') == 'python_function' and evt['name'].startswith('nn.Module:'):
if self.event_to_category(evt) == 'python_function' and evt['name'].startswith('nn.Module:'):
nn_module_event = evt
break
evt = self.tree.get_parent_event(evt)
Expand Down Expand Up @@ -589,9 +586,9 @@ def get_df_kernels(self,

# Extract details for all kernel events.
for event in self.tree.events:
if event.get('cat') != 'kernel':
if self.event_to_category(event) != 'kernel':
continue
details = self.get_kernel_details(event,
details = self.get_kernel_details(event,
launcher_detail=launcher_detail,
cpu_op_detail=cpu_op_detail,
nn_module_detail=nn_module_detail)
Expand Down
5 changes: 3 additions & 2 deletions TraceLens/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .TraceFusion.trace_fuse import TraceFuse
from .Trace2Tree.trace_to_tree import TraceToTree
from .NcclAnalyser.nccl_analyser import NcclAnalyser
from .util import DataLoader
from .util import DataLoader,TraceEventUtils
from .PerfModel import *
from .EventReplay.event_replay import EventReplayer

Expand All @@ -20,5 +20,6 @@
"PerfModel",
"EventReplay",
"EventReplayer",
"DataLoader"
"DataLoader",
"TraceEventUtils",
]
Loading