-
Notifications
You must be signed in to change notification settings - Fork 10
Add GEMM performance model support for Jax #139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
48 commits
Select commit
Hold shift + click to select a range
d9fde3a
GEMM stuff works
gabeweisz dcf825e
merge
gabeweisz 186eef8
fix potential leak
gabeweisz 054f347
add basic documentation
gabeweisz bc05162
update doc to explain how to use the code
gabeweisz 8c40a27
merged main
gabeweisz dcbc13c
starting steps for treeperf for jax
gabeweisz 301e82c
Merge branch 'main' of https://github.com/AMD-AIG-AIMA/TraceLens into…
gabeweisz 9824fae
change Jax memset to count as compute to be more consistent with the …
gabeweisz f5a1b4e
change factory pattern to member
gabeweisz d929a05
fix bug with save preprocessed
gabeweisz fe609bc
add param to save
gabeweisz 079e566
add metadata processor
gabeweisz 96beb20
rename func for metadata extraction
gabeweisz c99a789
Merge branch 'main' of https://github.com/amd-aig-aima/TraceLens into…
gabeweisz 3e02d4f
add code for jax-based tree in tracetotree
gabeweisz b410d21
update docs
gabeweisz 8e37ba8
update exports
gabeweisz 9d0a501
merge main
gabeweisz 51ce3f6
Apply suggestions from code review
gabeweisz 94ee162
Update TraceLens/TreePerf/jax_analyses.py
gabeweisz f31772f
Update TraceLens/util.py
gabeweisz 573b965
fix trace_to_tree for jax
gabeweisz 431c32b
remove premature jax integration
gabeweisz 1b18805
tree perf updates will go in a different PR
gabeweisz 1cc2a4e
Merge branch 'gw_jax_tree' of https://github.com/amd-aig-aima/TraceLe…
gabeweisz f1a996b
remove premature jax integration
gabeweisz 0695fe1
fix preprocessing
gabeweisz 30a635a
give the metadata skipper a better name
gabeweisz 072800b
fix doc
gabeweisz 730f77d
wip
gabeweisz 47a15d0
add workaround for older python
gabeweisz 4f926ed
add fp32 and fp64 support to gemm extraction
gabeweisz b7b0e6b
fix gemm type generation
gabeweisz ef7f667
support types directly in operand
gabeweisz 34889d7
gemm_perf_metrics just about works
gabeweisz 55bcfc2
use raw strings for all regex, change 'B' to 'GEMM Batch' for all GEMMs
gabeweisz 25036a0
universally list operator batch as 'Op B'
gabeweisz 8581318
Merge branch 'main' into gw_jax_tree
ajassani 45cfe78
update doc and add arch to helper
gabeweisz ad22775
remove unwanted change
gabeweisz d253230
merge
gabeweisz 98a768c
merge other copilot fixes
gabeweisz d599279
fix error due to changing f string to raw string
gabeweisz 2bf32cb
change to rf string
gabeweisz 6dc34be
revert Op B to B
ajassani 246100e
Merge branch 'main' of https://github.com/amd-aig-aima/TraceLens into…
gabeweisz fe7c793
restore jax parameter
gabeweisz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| from .tree_perf import TreePerfAnalyzer | ||
| from .gpu_event_analyser import GPUEventAnalyser, PytorchGPUEventAnalyser, JaxGPUEventAnalyser | ||
| from .jax_analyses import JaxAnalyses | ||
| from .jax_analyses import JaxAnalyses, JaxProfileProcessor | ||
|
|
||
| __all__ = ["TreePerfAnalyzer", "GPUEventAnalyser", "PytorchGPUEventAnalyser", "JaxGPUEventAnalyser", "JaxAnalyses"] | ||
| __all__ = ["TreePerfAnalyzer", "GPUEventAnalyser", "PytorchGPUEventAnalyser", "JaxGPUEventAnalyser", "JaxAnalyses", "JaxProfileProcessor"] |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this jax variable does not seem to be defined here