feat: transformer engine ver 1 GEMM ops te_gemm_ts added #125
Conversation
Refs #28
TraceLens/PerfModel/perf_model.py
Outdated
    def bytes(self):
        dtype_A_B = self.param_details['dtype_A_B']
        if dtype_A_B[0] != dtype_A_B[1]:
            raise ValueError(f"Data types of A and B are different: {dtype_A_B}")
Why do the data types have to be the same? With fp8 we might have dtype_A=fp8 and dtype_B=bf8 and this should be fine. You probably need to generalize this bytes calculation for different A and B data types.
The base class GEMM seems to support different data types, so there should be no reason to require the same data type here. Is this correct @ajassani?
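To illustrate the suggestion above, here is a minimal sketch of a bytes calculation that allows A and B to have different element sizes. The names (`DTYPE_BYTES`, `gemm_bytes`) and the dtype table are illustrative assumptions, not code from TraceLens:

```python
# Hypothetical sketch: per-operand element sizes instead of a single
# shared dtype, so mixed-precision GEMMs (e.g. fp8 A with bf8 B) work.
DTYPE_BYTES = {
    'fp32': 4, 'tf32': 4,
    'fp16': 2, 'bf16': 2,
    'fp8': 1, 'bf8': 1,
}

def gemm_bytes(M, N, K, dtype_A, dtype_B, dtype_out):
    """Bytes moved for C[M,N] = A[M,K] @ B[K,N], counting each
    operand with its own bytes-per-element."""
    bpe_A = DTYPE_BYTES[dtype_A]
    bpe_B = DTYPE_BYTES[dtype_B]
    bpe_C = DTYPE_BYTES[dtype_out]
    return M * K * bpe_A + K * N * bpe_B + M * N * bpe_C

# Mixed-precision example: fp8 A, bf8 B, bf16 output
print(gemm_bytes(128, 256, 64, 'fp8', 'bf8', 'bf16'))
```

With per-operand sizes, the equal-dtype check becomes unnecessary rather than an error condition.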
TraceLens/PerfModel/perf_model.py
Outdated
            bpe_output=self.bpe)

    def flops_bwd(self):
        raise NotImplementedError("Backward pass for aten::addmm is not defined.")
"Backward pass for aten::addmm is not defined." -> "Backward pass for tex_ts::te_gemm_ts is not defined."
Has this flops_bwd been implemented for any ops? Are there examples of how this is calculated?
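As one answer to the question above, a hedged sketch of the standard FLOP accounting for a GEMM backward pass; this is not TraceLens code, just the textbook derivation:

```python
# For C[M,N] = A[M,K] @ B[K,N]:
#   forward:  2*M*N*K FLOPs (one multiply + one add per MAC)
#   backward: dA = dC @ B^T and dB = A^T @ dC, each the same size
#             GEMM as the forward pass, so backward is 2x forward.

def flops_fwd(M, N, K):
    return 2 * M * N * K

def flops_bwd(M, N, K):
    # two GEMMs, each costing the same as the forward GEMM
    return 2 * flops_fwd(M, N, K)

print(flops_bwd(128, 256, 64))  # 4 * 128 * 256 * 64
```

A bias term (as in aten::addmm) adds only M*N FLOPs, which is usually negligible next to the 2*M*N*K matmul term.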
TraceLens/PerfModel/perf_model.py
Outdated
    def flops_bwd(self):
        raise NotImplementedError("Backward pass for aten::addmm is not defined.")

    def bytes_bwd(self, bytes_per_element):
        raise NotImplementedError("Backward pass for aten::addmm is not defined.")
"Backward pass for aten::addmm is not defined." -> "Backward pass for tex_ts::te_gemm_ts is not defined."
-Add transformer engine ver 1 GEMM op tex_ts::te_gemm_ts kernel computation
-Modify the GEMM base class init to parse transpose information before the matrix dimension calculation, since the calculation depends on it