Add GEMM performance model support for Jax#139
Conversation
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
|
Hold until #132 is complete |
There was a problem hiding this comment.
Pull Request Overview
This pull request adds GEMM performance model support for Jax by extending TraceLens functionalities and updating performance analyses, with associated documentation improvements. Key changes include:
- Adding detailed GEMM performance metrics examples and Jax trace analysis code in the documentation.
- Enhancing TraceLens utility and tree-building modules with new enums, event categorizations, and support for Jax-specific GPU performance metrics.
- Modifying performance model computations to use updated operator naming (e.g. replacing "B" with "Op B") for GEMMs.
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| docs/jax_analyses.md | Added examples and code snippets for detailed GEMM performance metrics. |
| TraceLens/util.py | Expanded enums and utility functions; added StrEnum support for metadata. |
| TraceLens/init.py | Updated module exports to include new utility classes. |
| TraceLens/TreePerf/tree_perf.py | Updated event categorization to use new categorizer functions. |
| TraceLens/TreePerf/jax_analyses.py | Introduced Jax-specific performance analysis functions and GEMM support. |
| TraceLens/TreePerf/init.py | Updated exports to include Jax analysis components. |
| TraceLens/Trace2Tree/trace_to_tree.py | Revised tree building logic to use updated keys and event categorizers. |
| TraceLens/PerfModel/perf_model.py | Adjusted GEMM performance computations and parameter naming for consistency. |
Comments suppressed due to low confidence (1)
TraceLens/TreePerf/jax_analyses.py:157
- The regex pattern is a raw string so the {event_key} placeholder is not interpolated; consider using an f-string (e.g. rf"...{event_key}...") if dynamic substitution is intended.
pattern = re.compile(r"^.*value:.*({event_key})\.?(\d+)?*size=(\d+).*: ([a-zA-Z\d].*)\[.*$")
(GW) This was correct - I fixed this in the code. We need a raw string to avoid "invalid escape sequence" warnings in the regex
There was a problem hiding this comment.
Pull Request Overview
This pull request adds GEMM performance model support for Jax, extends trace event processing with additional enumerated values, and updates performance model parameters to align with Jax GEMM semantics. Key changes include:
- New documentation and sample scripts in docs/jax_analyses.md illustrating how to extract and display detailed GEMM performance metrics.
- Updates to trace event utilities (in util.py, Trace2Tree, and TreePerf) and the introduction of new Jax-specific analysis logic in jax_analyses.py.
- Modifications to performance model files (perf_model.py) to use new parameter names (e.g. "Op B") for backwards compatibility and clarity.
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| docs/jax_analyses.md | Added code blocks demonstrating GEMM performance metrics for Jax. |
| TraceLens/util.py | Updated enum definitions and utility functions for trace event processing. |
| TraceLens/init.py | Updated imports to expose TraceEventUtils and related classes. |
| TraceLens/TreePerf/tree_perf.py | Integrated Jax support in TreePerfAnalyzer and updated event categorization. |
| TraceLens/jax_analyses.py | Introduced new GEMM analysis routines and updated regex patterns. |
| TraceLens/TreePerf/init.py | Exported new Jax analyses and performance helper classes. |
| TraceLens/Trace2Tree/trace_to_tree.py | Enhanced tree construction to account for updated categorization logic. |
| TraceLens/PerfModel/perf_model.py | Adjusted GEMM parameter names to reflect Jax GEMM semantics (“Op B”). |
Comments suppressed due to low confidence (2)
TraceLens/TreePerf/tree_perf.py:47
- The variable 'jax' is used in this conditional but is not defined within the function. Consider extracting the 'jax' flag from kwargs or providing a default value.
categorizer = TraceToTree.default_categorizer if not jax else JaxAnalyses.prepare_event_categorizer(data)
TraceLens/jax_analyses.py:157
- Switching from an f-string to a raw string prevents the variable 'event_key' from being interpolated. Revert to an f-string or explicitly format the regex pattern to include 'event_key'.
pattern = re.compile(r"^.*value:.*({event_key})\.?([\d]+)?.*size=(\d+).*: ([a-zA-Z\d].*)\[.*")
|
|
||
| tree = TraceToTree(data['traceEvents']) | ||
| return TreePerfAnalyzer(tree, *args, **kwargs) | ||
| categorizer = TraceToTree.default_categorizer if not jax else JaxAnalyses.prepare_event_categorizer(data) |
There was a problem hiding this comment.
this jax variable does not seem to be defined here
Add performance model support for GEMMs for Jax Adds more enumerated values for standard trace event items Uses enumerate values for most of the treeperf code Needs to be sequenced after #132 --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Adeem Jassani <adeem.jassani@amd.com> Co-authored-by: root <ajassani@amd.com>
Add performance model support for GEMMs for Jax
Adds more enumerated values for standard trace event items
Uses enumerate values for most of the treeperf code
Needs to be sequenced after #132