diff --git a/TraceLens/PerfModel/perf_model.py b/TraceLens/PerfModel/perf_model.py
index e4db703f..11e44f0c 100644
--- a/TraceLens/PerfModel/perf_model.py
+++ b/TraceLens/PerfModel/perf_model.py
@@ -1194,6 +1194,52 @@ def get_param_details(event):
         return {"B": B, "N_Q": N_Q, "H_Q": H_Q, "N_KV": N_KV, "H_KV": H_KV, "d_h": d_h,
                 "dropout": dropout_p, "causal": is_causal, "flash_impl": False}
 
+class aten__scaled_dot_product_flash_attention(SDPA):
+    """Parameter extraction for ``aten::_scaled_dot_product_flash_attention`` events."""
+
+    @staticmethod
+    def get_param_details(event):
+        """Build the attention parameter dict from a profiler event.
+
+        Reads ``event['args']['Input Dims']`` (query/key/value shapes; the
+        query shape is unpacked as ``(B, H_Q, N_Q, d_h)`` and the key shape as
+        ``(_, H_KV, N_KV, _)``) and ``event['args']['Concrete Inputs']``
+        (``dropout_p`` at index 3, ``is_causal`` at index 4).
+
+        Returns a dict with keys ``B``, ``N_Q``, ``H_Q``, ``N_KV``, ``H_KV``,
+        ``d_h``, ``dropout``, ``causal`` and ``flash_impl`` (always True here).
+
+        Raises AssertionError if the key and value shapes differ.
+        """
+        # the order of arguments for aten::_scaled_dot_product_flash_attention is:
+        # query: Tensor
+        # key: Tensor
+        # value: Tensor
+        # dropout_p: float
+        # is_causal: bool
+        # return_debug_mask: bool
+        # *
+        # scale: Optional[float]
+        input_dims = event['args']['Input Dims']
+        concrete_inputs = event['args']['Concrete Inputs']
+        q_shape, k_shape, v_shape = input_dims[0], input_dims[1], input_dims[2]
+        B, H_Q, N_Q, d_h = q_shape
+        assert k_shape == v_shape, f"Key and value shapes are different: {k_shape} != {v_shape}"
+        _, H_KV, N_KV, _ = k_shape
+        dropout_p = 0.0
+        if concrete_inputs[3] not in ('', 'None'):
+            try:
+                dropout_p = float(concrete_inputs[3])
+            except (ValueError, TypeError):
+                # best effort: keep the 0.0 default when the recorded value is not numeric
+                pass
+        is_causal = concrete_inputs[4].lower() == 'true' if concrete_inputs[4] not in ('', 'None') else False
+        # scale is not read: per the argument order above it would sit at index 6
+        # (index 5 is return_debug_mask), and nothing returned below depends on it.
+
+        return {"B": B, "N_Q": N_Q, "H_Q": H_Q, "N_KV": N_KV, "H_KV": H_KV, "d_h": d_h,
+                "dropout": dropout_p, "causal": is_causal, "flash_impl": True}
+
 class UnaryElementwise:
 
     def __init__(self, event, arch=None, python_path=None):
diff --git a/TraceLens/PerfModel/torch_op_mapping.py b/TraceLens/PerfModel/torch_op_mapping.py
index 386a96ce..6c75a2ff 100644
--- a/TraceLens/PerfModel/torch_op_mapping.py
+++ b/TraceLens/PerfModel/torch_op_mapping.py
@@ -43,6 +43,7 @@
     'flash_attn::_flash_attn_forward': perf_model.flash_attention,
     'aten::_scaled_dot_product_cudnn_attention': perf_model.aten__scaled_dot_product_cudnn_attention,
     'aten::_scaled_dot_product_efficient_attention': perf_model.aten__scaled_dot_product_efficient_attention,
+    'aten::_scaled_dot_product_flash_attention': perf_model.aten__scaled_dot_product_flash_attention,
     'aten::convolution': perf_model.aten_conv,
 }
 