[Example] Add the attention example with kv-cache #21
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Force-pushed from ec71335 to 4334118.
Pull Request Overview
This PR adds a new attention with KV cache example and makes several improvements to support it. The primary purpose is to demonstrate a high-performance attention mechanism implementation with key-value caching, which is crucial for efficient inference in large language models.
- Implements a complete attention with KV cache operator similar to flash-attention
- Updates the lambda function parameter handling in register_tensor initialization to use individual parameters instead of tuples
- Adds support for TensorElement in grid analyzer and improves type handling in transpiler
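
To make the example's goal concrete, here is a minimal single-head NumPy reference of the attention-with-KV-cache semantics (an illustrative sketch, not the Tilus kernel added in this PR; the function name and tensor layout are assumptions):

```python
import numpy as np

def attention_with_kvcache_ref(q, k_cache, v_cache, k_new, v_new, cache_len):
    """Single-head reference: append new keys/values to the cache, then run
    causal attention of the new queries against the valid cache prefix."""
    # q, k_new, v_new: [seqlen_q, head_dim]; k_cache, v_cache: [max_seq, head_dim]
    n = k_new.shape[0]
    k_cache[cache_len:cache_len + n] = k_new
    v_cache[cache_len:cache_len + n] = v_new
    total = cache_len + n

    # Scaled dot-product scores over the valid prefix of the cache.
    scores = q @ k_cache[:total].T / np.sqrt(q.shape[-1])

    # Causal mask: query i sits at absolute position cache_len + i and may
    # only attend to key positions <= its own position.
    q_pos = cache_len + np.arange(q.shape[0])[:, None]
    k_pos = np.arange(total)[None, :]
    scores = np.where(q_pos >= k_pos, scores, -np.inf)

    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v_cache[:total]
```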
Reviewed Changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 4 comments.
File | Description |
---|---|
examples/attention_with_kvcache/attention_v1.py | New attention with KV cache implementation with performance benchmarking |
python/tilus/lang/script.py | Updates register_tensor init parameter handling from tuple to individual parameters |
python/tilus/lang/transpiler.py | Improves type annotation handling and adds TensorElement support |
python/tilus/ir/analyzers/grid_analyzer.py | Adds TensorElement visitor method to grid analyzer |
tests/lang/test_simple_mma_matmul.py | Updates lambda to use individual parameters |
tests/kernels/matmul/test_matmul_v0.py | Updates lambda to use individual parameters |
tests/instructions/test_reduce.py | Updates lambda to use individual parameters |
examples/attention/flash_attention_v*.py | Updates lambda functions to use individual parameters |
python/tilus/lang/modules/cuda.py | Refactors shared layout creation to use function-based offset calculation |
python/tilus/lang/instantiated_script.py | Simplifies constant handling by removing Constant type requirements |
python/tilus/ir/builders/stmt_builder.py | Improves error message for broadcast operations |
pyproject.toml | Adds einops dependency for tensor operations |
@@ -161,7 +161,7 @@ def range(
        step: Optional[Expr | int] = None,
        /,
        *,
-       unroll: Optional[Literal["all"] | int],
+       unroll: Optional[Literal["all"] | int] = None,
The default value `None` for the `unroll` parameter changes the API behavior. Previously this parameter was required, but now it is optional. This could be a breaking change for existing code that relies on explicit unroll specification.
Suggested change:
-    unroll: Optional[Literal["all"] | int] = None,
+    unroll: Optional[Literal["all"] | int],
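For context on the compatibility point, a small standalone Python sketch (hypothetical function names, not the actual tilus code) showing how adding a default to a keyword-only parameter changes what call sites must pass:

```python
from typing import Literal, Optional

# Hypothetical signatures for illustration only.
def range_required(stop: int, /, *, unroll: Optional[Literal["all"] | int]) -> None:
    ...

def range_optional(stop: int, /, *, unroll: Optional[Literal["all"] | int] = None) -> None:
    ...

range_optional(16)                # now valid: unroll defaults to None
range_optional(16, unroll="all")  # explicit unroll still works
# range_required(16)              # TypeError: missing keyword-only argument 'unroll'
```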
    if resolved_annotation in (int, str, float):
        rhs = resolved_annotation(rhs)
The type conversion `resolved_annotation(rhs)` could raise an exception if `rhs` cannot be converted to the target type. This should be wrapped in a try/except block to provide a meaningful error message.
Suggested change:
-    if resolved_annotation in (int, str, float):
-        rhs = resolved_annotation(rhs)
+    try:
+        rhs = resolved_annotation(rhs)
+    except (ValueError, TypeError) as e:
+        raise TilusProgramError(
+            self,
+            lhs,
+            f"Failed to convert value '{rhs}' to type '{resolved_annotation.__name__}' for variable '{var_name}': {e}"
+        )
@@ -874,6 +878,8 @@ def visit_Subscript(self, expr: ast.Subscript) -> Any:
    if isinstance(base, Sequence):
        return base[indices]
    elif isinstance(base, GlobalTensor):
        if not isinstance(indices, Sequence):
The check for `Sequence` could incorrectly identify strings as sequences, since strings implement the Sequence protocol. This should explicitly exclude strings: `if not isinstance(indices, Sequence) or isinstance(indices, str):`.
Suggested change:
-    if not isinstance(indices, Sequence):
+    if not isinstance(indices, Sequence) or isinstance(indices, str):
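For context on this point, a standalone snippet (not from the PR) showing that `str` satisfies a bare `Sequence` check:

```python
from collections.abc import Sequence

# str is registered as a Sequence, so an isinstance check alone cannot
# distinguish a string from a tuple/list of indices.
print(isinstance("q", Sequence))     # True
print(isinstance((0, 1), Sequence))  # True
print(isinstance(5, Sequence))       # False
```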
    mask = self.register_tensor(
        dtype=boolean,
        shape=[self.block_q, self.block_kv],
        init=lambda i, j: i + q_offset + q_left_len >= j + kv_offset,
[nitpick] The mask calculation uses multiple variables captured from the outer scope (`q_offset`, `q_left_len`, `kv_offset`). Consider extracting this logic into a named function or adding a comment explaining the causal mask logic for better readability.
Suggested change:
-    mask = self.register_tensor(
-        dtype=boolean,
-        shape=[self.block_q, self.block_kv],
-        init=lambda i, j: i + q_offset + q_left_len >= j + kv_offset,
+    # Causal mask: ensures that each query position (i + q_offset + q_left_len)
+    # can only attend to key positions (j + kv_offset) that are not in the future.
+    def causal_mask(i, j):
+        return i + q_offset + q_left_len >= j + kv_offset
+    mask = self.register_tensor(
+        dtype=boolean,
+        shape=[self.block_q, self.block_kv],
+        init=causal_mask,
Force-pushed from d342973 to 1acf13f.
This PR adds an example: attention with KV cache. The example implements an attention operator that is a simplified version of `flash_attn_with_kvcache` in Dao-AILab/flash-attention.

Performance on RTX 4090:
We might need to implement a split-k optimization for the decode stage (e.g., seqlen_q = 1).
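
As a rough illustration of that idea, here is a minimal NumPy sketch (not part of the PR; function names are made up) of how splitting the KV sequence into chunks can be combined exactly for a single decode query using per-chunk softmax statistics:

```python
import numpy as np

def attention_ref(q, k, v):
    # q: [head_dim]; k, v: [seq_len, head_dim] -- single decode query
    s = (k @ q) / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ v

def attention_split_kv(q, k, v, num_splits=4):
    # Each split computes an unnormalized partial output plus its softmax
    # max and sum; the partials are then merged with a log-sum-exp rescale.
    outs, maxes, sums = [], [], []
    for idx in np.array_split(np.arange(k.shape[0]), num_splits):
        s = (k[idx] @ q) / np.sqrt(q.shape[-1])
        m = s.max()
        p = np.exp(s - m)
        outs.append(p @ v[idx])
        maxes.append(m)
        sums.append(p.sum())
    m_total = max(maxes)
    scales = [np.exp(m - m_total) for m in maxes]
    denom = sum(s_ * c for s_, c in zip(sums, scales))
    return sum(o * c for o, c in zip(outs, scales)) / denom

q = np.random.randn(64)
k, v = np.random.randn(1024, 64), np.random.randn(1024, 64)
assert np.allclose(attention_ref(q, k, v), attention_split_kv(q, k, v))
```

In a kernel, the per-chunk partials would be computed in parallel across blocks and the final rescale would run as a small reduction step.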
Minor changes needed by this example: