support gqa#147

Merged
ajassani merged 1 commit into main from feat/gqa
May 11, 2025

Conversation

@ajassani
Collaborator

@ajassani ajassani commented May 11, 2025

This PR adds support for GQA by updating parameter naming and computation logic in performance-related functions.

  • Renames parameters (e.g., H → H_Q, d_k → d_h, N_K → N_KV) to align with GQA semantics.
  • Updates flops, bytes, and backward flops calculations and adjusts the extraction of tensor shapes in both forward and backward methods.
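For orientation, the forward attention FLOPs under the renamed parameters can be sketched as below. This is a minimal illustration only, not the PR's actual `perf_model.py` code; the function name and the 2-FLOP multiply-accumulate convention are assumptions:

```python
def gqa_attention_flops(B, H_Q, H_KV, N_Q, N_KV, d_h):
    """Hypothetical sketch of forward attention FLOPs for GQA.

    B: batch size, H_Q: query heads, H_KV: key/value heads,
    N_Q / N_KV: query / key-value sequence lengths, d_h: head dim.
    """
    # GQA requires query heads to split evenly into KV groups.
    assert H_Q % H_KV == 0, "H_Q must be an exact multiple of H_KV"
    # Q @ K^T: every query head attends over the (shared) KV heads,
    # so the score matmul still scales with H_Q, not H_KV.
    qk_flops = 2 * B * H_Q * N_Q * N_KV * d_h
    # softmax(scores) @ V: same shape of work as the score matmul.
    av_flops = 2 * B * H_Q * N_Q * N_KV * d_h
    return qk_flops + av_flops
```

Note that sharing K/V heads reduces memory traffic (bytes) rather than matmul FLOPs, which is why the renames matter mainly for the bytes and backward accumulation terms.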

@ajassani ajassani requested a review from Copilot May 11, 2025 19:17
Contributor

Copilot AI left a comment


Pull Request Overview

This PR adds support for GQA by updating parameter naming and computation logic in performance-related functions.

  • Renames parameters (e.g., H → H_Q, d_k → d_h, N_K → N_KV) to align with GQA semantics.
  • Updates flops, bytes, and backward flops calculations and adjusts the extraction of tensor shapes in both forward and backward methods.
Comments suppressed due to low confidence (3)

TraceLens/PerfModel/perf_model.py:873

  • The dimension ordering in backward get_param_details appears inconsistent with the forward function, where Q shape is expected as (B, N_Q, H_Q, d_h). Verify that the reordering to (B, H_Q, N_Q, d_h) is intentional.
B, H_Q, N_Q, d_h = q_shape
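The ordering hazard the reviewer flags can be seen with a toy shape tuple (values hypothetical): unpacking the same tuple under two different axis conventions silently swaps the meaning of the middle dimensions.

```python
# Forward convention assumed: Q shape is (B, N_Q, H_Q, d_h).
q_shape = (2, 128, 16, 64)

B, N_Q, H_Q, d_h = q_shape        # forward unpacking: N_Q=128, H_Q=16

# If the backward path unpacks the same tuple as (B, H_Q, N_Q, d_h),
# no error is raised, but the axes are silently transposed:
B2, H_Q2, N_Q2, d_h2 = q_shape    # here H_Q2=128 and N_Q2=16
```

If the backward tensors really are stored head-first, the reordering is correct; otherwise every FLOP/byte term that multiplies these dimensions is computed from swapped values.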

TraceLens/PerfModel/perf_model.py:798

  • [nitpick] The integer floor division `H_Q//H_KV` silently truncates whenever H_Q is not an exact multiple of H_KV; please confirm this is the intended scaling behavior in the GQA computations.
flops_vgrad += B * N_KV * d_h * (H_Q // H_KV - 1)

TraceLens/PerfModel/perf_model.py:809

  • [nitpick] Double-check the intended effect of the integer division on H_Q relative to H_KV in the backward gradient computations, as a similar pattern appears with flops_vgrad.
flops_k_grad += B * N_KV * d_h * (H_Q // H_KV - 1)
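One way to address both floor-division nitpicks is to fail loudly instead of truncating. The helper below is a hypothetical sketch, not code from this PR:

```python
def kv_grad_extra_flops(B, N_KV, d_h, H_Q, H_KV):
    """Extra accumulation FLOPs when H_Q // H_KV query heads share one KV head.

    Hypothetical helper; names mirror the PR's parameters but the
    function itself is illustrative only.
    """
    # Guard against silent truncation: GQA assumes H_Q is an exact
    # multiple of H_KV, so raise rather than floor away a remainder.
    if H_Q % H_KV != 0:
        raise ValueError(f"H_Q={H_Q} must be divisible by H_KV={H_KV}")
    # Each KV head's gradient sums contributions from its group of
    # query heads; the first contribution is a write, the rest are adds.
    return B * N_KV * d_h * (H_Q // H_KV - 1)
```

When divisibility is guaranteed upstream (as it is for standard GQA configurations), the floor division is harmless; the guard just makes that invariant explicit.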

@ajassani ajassani merged commit c2a0088 into main May 11, 2025
@ajassani ajassani deleted the feat/gqa branch May 11, 2025 19:18
lauri9 pushed a commit that referenced this pull request Jun 11, 2025
This PR adds support for GQA by updating parameter naming and
computation logic in performance-related functions.

Renames parameters (e.g., H → H_Q, d_k → d_h, N_K → N_KV) to align with
GQA semantics.
Updates flops, bytes, and backward flops calculations and adjusts the
extraction of tensor shapes in both forward and backward methods.
