Feat/qkv stride #159

Merged
ajassani merged 6 commits into main from feat/qkv_stride on May 15, 2025

@ajassani
Collaborator

No description provided.

ajassani requested a review from Copilot on May 15, 2025 17:11
Contributor

Copilot AI left a comment

Pull Request Overview

This PR adds support for extracting and returning query, key, and value strides from the event arguments in both the performance report generator and the flash attention model.

  • Extract Input Strides, convert each to a tuple, and include them in the returned parameter dict in generate_perf_report_megatron_lm.py.
  • Unpack and include q_stride, k_stride, and v_stride in flash_attention.get_param_details within perf_model.py.
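
The extraction described above can be sketched roughly as follows. This is a minimal illustration, not the merged code: the helper name `extract_qkv_strides`, the default `q_idx=0`, and the sample event are assumptions, and it presumes that `event['args']['Input Strides']` holds one stride list per input with q, k, and v stored consecutively.

```python
def extract_qkv_strides(event, q_idx=0):
    """Return q/k/v strides from a trace event as tuples.

    Hypothetical helper: assumes event['args']['Input Strides'] is a list of
    per-input stride lists, with q, k and v consecutive starting at q_idx.
    """
    strides = event['args']['Input Strides']
    # Convert each stride list to a tuple so the values are immutable.
    q_stride, k_stride, v_stride = (tuple(s) for s in strides[q_idx:q_idx + 3])
    return {'q_stride': q_stride, 'k_stride': k_stride, 'v_stride': v_stride}


# Made-up event: three contiguous tensors of shape (2, 128, 8, 64).
event = {'args': {'Input Strides': [[65536, 512, 64, 1]] * 3}}
params = extract_qkv_strides(event)
```

If fewer than three stride entries are available from `q_idx`, the unpacking raises a `ValueError`, which may or may not be the desired behavior in the real workflow.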

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| examples/custom_workflows/generate_perf_report_megatron_lm.py | Added logic to slice, tuple-convert, and return q/k/v strides |
| TraceLens/PerfModel/perf_model.py | Introduced q/k/v stride extraction and added them to the params |

Comments suppressed due to low confidence (3)

examples/custom_workflows/generate_perf_report_megatron_lm.py:33

  • [nitpick] The variable names use plural q_strides, k_strides, v_strides but in perf_model.py you use singular (q_stride). Consider unifying the naming convention across modules for consistency.
q_strides, k_strides, v_strides = strides[q_idx: q_idx+3]

TraceLens/PerfModel/perf_model.py:853

  • [nitpick] You already unpacked k_shape above; consider using k_shape here (e.g. _, N_KV, H_KV, _ = k_shape) to avoid mixing direct indexing with named variables.
_, N_KV, H_KV, _ = input_dims[1]
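
A sketch of what this suggestion might look like in practice. Only `k_shape`, `N_KV`, `H_KV`, and `input_dims` come from the comment; the function name, the other variable names, and the (B, N, H, d) layout are assumptions made for illustration.

```python
def unpack_attention_dims(input_dims):
    """Unpack q/k/v shapes by name (layout assumed to be B, N, H, d)."""
    q_shape, k_shape, v_shape = input_dims[:3]
    B, N_Q, H_Q, d_h = q_shape
    # Reuse the named k_shape rather than indexing input_dims[1] again.
    _, N_KV, H_KV, _ = k_shape
    return {'B': B, 'N_Q': N_Q, 'H_Q': H_Q,
            'N_KV': N_KV, 'H_KV': H_KV, 'd_h': d_h}


# Made-up shapes: 8 query heads, 2 key/value heads.
dims = unpack_attention_dims([(2, 128, 8, 64), (2, 256, 2, 64), (2, 256, 2, 64)])
```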

examples/custom_workflows/generate_perf_report_megatron_lm.py:32

  • New behavior for extracting Input Strides is introduced; ensure you add or update tests to cover these stride values and their conversion to tuples.
strides = event['args']['Input Strides']
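
One possible shape for such tests, written as plain functions that pytest could collect. The helper `strides_to_tuples` and the sample events are hypothetical stand-ins for the script's logic, not its actual API.

```python
def strides_to_tuples(strides, start, count=3):
    """Hypothetical stand-in: slice `count` stride lists and tuple-convert them."""
    return [tuple(s) for s in strides[start:start + count]]


def test_strides_are_converted_to_tuples():
    event = {'args': {'Input Strides': [[65536, 512, 64, 1],
                                        [32768, 128, 64, 1],
                                        [32768, 128, 64, 1]]}}
    q, k, v = strides_to_tuples(event['args']['Input Strides'], start=0)
    assert all(isinstance(s, tuple) for s in (q, k, v))
    assert q == (65536, 512, 64, 1)
    assert k == v == (32768, 128, 64, 1)


def test_missing_input_strides_raises_key_error():
    event = {'args': {}}
    try:
        event['args']['Input Strides']
    except KeyError:
        return
    raise AssertionError("expected KeyError when 'Input Strides' is absent")
```

The second test documents the current failure mode for events without `Input Strides`; whether the script should instead fall back to a default is a separate design question.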

ajassani merged commit d996ada into main on May 15, 2025
ajassani deleted the feat/qkv_stride branch on May 15, 2025 17:12
lauri9 pushed a commit that referenced this pull request Jun 11, 2025