Skip to content

Add GEMM performance model support for Jax#139

Merged
gabeweisz merged 48 commits intomainfrom
gw_jax_treeperf
May 6, 2025
Merged

Add GEMM performance model support for Jax#139
gabeweisz merged 48 commits intomainfrom
gw_jax_treeperf

Conversation

@gabeweisz
Copy link
Copy Markdown
Collaborator

Add performance model support for GEMMs for Jax
Adds more enumerated values for standard trace event items
Uses enumerate values for most of the treeperf code

Needs to be sequenced after #132

gabeweisz and others added 30 commits April 10, 2025 10:48
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@gabeweisz gabeweisz requested a review from Copilot May 5, 2025 14:39
@gabeweisz
Copy link
Copy Markdown
Collaborator Author

Hold until #132 is complete

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This pull request adds GEMM performance model support for Jax by extending TraceLens functionalities and updating performance analyses, with associated documentation improvements. Key changes include:

  • Adding detailed GEMM performance metrics examples and Jax trace analysis code in the documentation.
  • Enhancing TraceLens utility and tree-building modules with new enums, event categorizations, and support for Jax-specific GPU performance metrics.
  • Modifying performance model computations to use updated operator naming (e.g. replacing "B" with "Op B") for GEMMs.

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
docs/jax_analyses.md Added examples and code snippets for detailed GEMM performance metrics.
TraceLens/util.py Expanded enums and utility functions; added StrEnum support for metadata.
TraceLens/init.py Updated module exports to include new utility classes.
TraceLens/TreePerf/tree_perf.py Updated event categorization to use new categorizer functions.
TraceLens/TreePerf/jax_analyses.py Introduced Jax-specific performance analysis functions and GEMM support.
TraceLens/TreePerf/init.py Updated exports to include Jax analysis components.
TraceLens/Trace2Tree/trace_to_tree.py Revised tree building logic to use updated keys and event categorizers.
TraceLens/PerfModel/perf_model.py Adjusted GEMM performance computations and parameter naming for consistency.
Comments suppressed due to low confidence (1)

TraceLens/TreePerf/jax_analyses.py:157

  • The regex pattern is a raw string so the {event_key} placeholder is not interpolated; consider using an f-string (e.g. rf"...{event_key}...") if dynamic substitution is intended.
pattern = re.compile(r"^.*value:.*({event_key})\.?(\d+)?*size=(\d+).*: ([a-zA-Z\d].*)\[.*$")

(GW) This was correct - I fixed this in the code. We need a raw string to avoid "invalid escape sequence" warnings in the regex

@gabeweisz gabeweisz requested review from ajassani and Copilot May 5, 2025 14:47
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This pull request adds GEMM performance model support for Jax, extends trace event processing with additional enumerated values, and updates performance model parameters to align with Jax GEMM semantics. Key changes include:

  • New documentation and sample scripts in docs/jax_analyses.md illustrating how to extract and display detailed GEMM performance metrics.
  • Updates to trace event utilities (in util.py, Trace2Tree, and TreePerf) and the introduction of new Jax-specific analysis logic in jax_analyses.py.
  • Modifications to performance model files (perf_model.py) to use new parameter names (e.g. "Op B") for backwards compatibility and clarity.

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.

Show a summary per file
File Description
docs/jax_analyses.md Added code blocks demonstrating GEMM performance metrics for Jax.
TraceLens/util.py Updated enum definitions and utility functions for trace event processing.
TraceLens/init.py Updated imports to expose TraceEventUtils and related classes.
TraceLens/TreePerf/tree_perf.py Integrated Jax support in TreePerfAnalyzer and updated event categorization.
TraceLens/jax_analyses.py Introduced new GEMM analysis routines and updated regex patterns.
TraceLens/TreePerf/init.py Exported new Jax analyses and performance helper classes.
TraceLens/Trace2Tree/trace_to_tree.py Enhanced tree construction to account for updated categorization logic.
TraceLens/PerfModel/perf_model.py Adjusted GEMM parameter names to reflect Jax GEMM semantics (“Op B”).
Comments suppressed due to low confidence (2)

TraceLens/TreePerf/tree_perf.py:47

  • The variable 'jax' is used in this conditional but is not defined within the function. Consider extracting the 'jax' flag from kwargs or providing a default value.
categorizer =  TraceToTree.default_categorizer if not jax else JaxAnalyses.prepare_event_categorizer(data)

TraceLens/jax_analyses.py:157

  • Switching from an f-string to a raw string prevents the variable 'event_key' from being interpolated. Revert to an f-string or explicitly format the regex pattern to include 'event_key'.
pattern = re.compile(r"^.*value:.*({event_key})\.?([\d]+)?.*size=(\d+).*: ([a-zA-Z\d].*)\[.*")


tree = TraceToTree(data['traceEvents'])
return TreePerfAnalyzer(tree, *args, **kwargs)
categorizer = TraceToTree.default_categorizer if not jax else JaxAnalyses.prepare_event_categorizer(data)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this jax variable does not seem to be defined here

@gabeweisz
Copy link
Copy Markdown
Collaborator Author

Per discussion with @ajassani, rolling up one PR with this one and #132

@gabeweisz gabeweisz merged commit 900eccf into main May 6, 2025
@gabeweisz gabeweisz deleted the gw_jax_treeperf branch May 6, 2025 19:37
lauri9 pushed a commit that referenced this pull request Jun 11, 2025
Add performance model support for GEMMs for Jax
Adds more enumerated values for standard trace event items
Uses enumerate values for most of the treeperf code

Needs to be sequenced after #132

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Adeem Jassani <adeem.jassani@amd.com>
Co-authored-by: root <ajassani@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants