
feat: transformer engine ver 1 GEMM ops te_gemm_ts added #125

Merged
ajassani merged 4 commits into main from feat/te-ver1-gemm-ops
May 5, 2025

Conversation

@olehtika
Contributor

- Add Transformer Engine ver 1 GEMM op tex_ts::te_gemm_ts kernel computation
- Modify the GEMM base class init to parse transpose information before the matrix dimension calculation, since the dimensions depend on it
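The description notes that transpose information must be parsed before computing matrix dimensions. A minimal sketch of why, assuming BLAS-style transa/transb flags (the function and parameter names here are illustrative, not from this PR):

```python
# Hypothetical illustration: the effective (M, N, K) of a GEMM depends on
# the transpose flags, so they must be known before dims are derived.
def gemm_dims(shape_A, shape_B, transa, transb):
    """Return (M, N, K) for C[M,N] = op(A) @ op(B), where op transposes
    its operand when the corresponding flag is set."""
    # op(A) is M x K: if A is stored transposed, swap its axes.
    M, K = (shape_A[1], shape_A[0]) if transa else shape_A
    # op(B) is K x N: likewise swap when transposed.
    K2, N = (shape_B[1], shape_B[0]) if transb else shape_B
    assert K == K2, "inner dimensions must match"
    return M, N, K
```

With the same storage shapes, flipping a flag changes which axis is M or N, which is why the base class must parse the flags first.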

@olehtika
Contributor Author

Refs #28

def bytes(self):
    dtype_A_B = self.param_details['dtype_A_B']
    if dtype_A_B[0] != dtype_A_B[1]:
        raise ValueError(f"Data types of A and B are different: {dtype_A_B}")

Why do the data types have to be the same? With fp8, we might have dtype_A=fp8 and dtype_B=bf8, and that should be fine. You probably need to generalize this bytes calculation to handle different A and B data types.
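A generalized bytes calculation along the lines the reviewer suggests might look like the following sketch. The dimension tuple, the dtype-size table, and the function name are assumptions for illustration, not this repository's API:

```python
# Hypothetical sketch: count bytes moved per operand with independent
# data types for A, B, and the output, instead of requiring A == B.
BYTES_PER_ELEMENT = {'fp8': 1, 'bf8': 1, 'fp16': 2, 'bf16': 2, 'fp32': 4}

def gemm_bytes(dims, dtype_A, dtype_B, dtype_out):
    """Bytes read/written for C[M,N] = A[M,K] @ B[K,N]."""
    M, N, K = dims
    bytes_A = M * K * BYTES_PER_ELEMENT[dtype_A]  # read A
    bytes_B = K * N * BYTES_PER_ELEMENT[dtype_B]  # read B
    bytes_C = M * N * BYTES_PER_ELEMENT[dtype_out]  # write C
    return bytes_A + bytes_B + bytes_C
```

This makes the fp8/bf8 mixed case (common in fp8 training recipes) fall out naturally rather than raising.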

Contributor Author


The base class GEMM seems to support different data types, so there should be no reason to require the same data type. Is this correct @ajassani?

        bpe_output=self.bpe)

def flops_bwd(self):
    raise NotImplementedError("Backward pass for aten::addmm is not defined.")

Backward pass for aten::addmm is not defined -> Backward pass for tex_ts::te_gemm_ts is not defined.

Has this flops_bwd been implemented for any ops? Are there examples of how this is calculated?
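For reference, the backward FLOPs of a plain GEMM are usually derived from the two gradient GEMMs; a minimal sketch under standard assumptions (these helper names are illustrative, not from this codebase):

```python
# Hypothetical sketch of flops for C[M,N] = A[M,K] @ B[K,N].
def gemm_flops_fwd(M, N, K):
    # One multiply-add per (m, n, k) triple, counted as 2 flops.
    return 2 * M * N * K

def gemm_flops_bwd(M, N, K):
    # dA = dC @ B^T is an (M,N)x(N,K) GEMM -> 2*M*N*K flops
    # dB = A^T @ dC is a (K,M)x(M,N) GEMM -> 2*M*N*K flops
    return 2 * gemm_flops_fwd(M, N, K)
```

That is, the backward pass of a GEMM is conventionally counted as twice the forward FLOPs.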

def flops_bwd(self):
    raise NotImplementedError("Backward pass for aten::addmm is not defined.")

def bytes_bwd(self, bytes_per_element):
    raise NotImplementedError("Backward pass for aten::addmm is not defined.")

Backward pass for aten::addmm is not defined -> Backward pass for tex_ts::te_gemm_ts is not defined

@ajassani ajassani merged commit 2f94270 into main May 5, 2025
@ajassani ajassani deleted the feat/te-ver1-gemm-ops branch May 5, 2025 15:11
lauri9 pushed a commit that referenced this pull request Jun 11, 2025
- Add transformer engine ver 1 GEMM ops tex_ts::te_gemm_ts GEMM kernel computation
- Modify GEMM base class init to parse transpose information before matrix dimension calculation since it is needed
