Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions TraceLens/PerfModel/perf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,37 @@ def get_param_details(event):
return {"B": B, "N_Q": N_Q, "H_Q": H_Q, "N_KV": N_KV, "H_KV": H_KV, "d_h": d_h,
"dropout": dropout_p, "causal": is_causal, "flash_impl": False}

class aten__scaled_dot_product_flash_attention(SDPA):
    """Perf model for the aten::_scaled_dot_product_flash_attention op.

    Mirrors the sibling SDPA subclasses (cudnn / efficient attention) but
    reports the parameters with flash_impl=True so downstream modeling can
    account for the flash-attention kernel.
    """

    @staticmethod
    def get_param_details(event):
        """Extract attention parameters from a profiler trace event.

        The argument order for aten::_scaled_dot_product_flash_attention is:
            query: Tensor
            key: Tensor
            value: Tensor
            dropout_p: float
            is_causal: bool
            return_debug_mask: bool
            *
            scale: Optional[float]

        Args:
            event: profiler event dict; reads event['args']['Input Dims']
                and event['args']['Concrete Inputs'].

        Returns:
            dict with batch size (B), sequence lengths (N_Q, N_KV), head
            counts (H_Q, H_KV), head dim (d_h), dropout probability,
            causal flag, and flash_impl=True.
        """
        input_dims = event['args']['Input Dims']
        concrete_inputs = event['args']['Concrete Inputs']
        q_shape, k_shape, v_shape = input_dims[0], input_dims[1], input_dims[2]
        # Tensor layout is assumed to be (B, H, N, d_h), matching the
        # convention used by the sibling SDPA models — TODO confirm against
        # the trace producer.
        B, H_Q, N_Q, d_h = q_shape
        assert k_shape == v_shape, f"Key and value shapes are different: {k_shape} != {v_shape}"
        # Unpack from k_shape directly (guaranteed equal to v_shape by the
        # assert above) instead of re-indexing input_dims.
        _, H_KV, N_KV, _ = k_shape
        # Concrete Inputs are recorded as strings; best-effort parse — keep
        # the 0.0 default rather than failing the whole perf model on a
        # value that is missing or unparseable.
        dropout_p = 0.0
        if concrete_inputs[3] not in ('', 'None'):
            try:
                dropout_p = float(concrete_inputs[3])
            except (ValueError, TypeError):
                pass
        is_causal = concrete_inputs[4].lower() == 'true' if concrete_inputs[4] not in ('', 'None') else False
        # NOTE(review): the optional trailing `scale` argument
        # (concrete_inputs[5]) is intentionally not extracted — the perf
        # model does not consume it.
        return {"B": B, "N_Q": N_Q, "H_Q": H_Q, "N_KV": N_KV, "H_KV": H_KV, "d_h": d_h,
                "dropout": dropout_p, "causal": is_causal, "flash_impl": True}

class UnaryElementwise:

def __init__(self, event, arch=None, python_path=None):
Expand Down
1 change: 1 addition & 0 deletions TraceLens/PerfModel/torch_op_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
'flash_attn::_flash_attn_forward': perf_model.flash_attention,
'aten::_scaled_dot_product_cudnn_attention': perf_model.aten__scaled_dot_product_cudnn_attention,
'aten::_scaled_dot_product_efficient_attention': perf_model.aten__scaled_dot_product_efficient_attention,
'aten::_scaled_dot_product_flash_attention': perf_model.aten__scaled_dot_product_flash_attention,
'aten::convolution': perf_model.aten_conv,
}

Expand Down