@yaoyaoding yaoyaoding commented Aug 22, 2025

This PR adds an example: attention with kv cache. The example implements an attention operator with the following parameters:

    q: torch.Tensor,  # fp16[batch_size, seqlen, num_heads, head_size]
    k_cache: torch.Tensor,  # fp16[num_blocks, page_block_size, num_heads_kv, head_size]
    v_cache: torch.Tensor,  # fp16[num_blocks, page_block_size, num_heads_kv, head_size]
    cache_seqlens: torch.Tensor,  # int32[batch_size]
    block_table: torch.Tensor,  # int32[batch_size, max_num_blocks_per_seq]

This is a simplified version of flash_attn_with_kvcache from Dao-AILab/flash-attention.
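
For reference, the paged-KV semantics can be described with a naive PyTorch sketch: gather the cache blocks listed in block_table, truncate them to cache_seqlens, expand the grouped KV heads, and apply a bottom-right-aligned causal mask. The function below is only an unoptimized illustration of what the kernel computes; attention_with_kvcache_ref is a hypothetical name, not code from this PR.

```python
import torch


def attention_with_kvcache_ref(q, k_cache, v_cache, cache_seqlens, block_table):
    # Naive per-sequence reference: gather the paged K/V blocks listed in
    # block_table, truncate to the cached length, and run standard attention.
    batch_size, seqlen_q, num_heads, head_size = q.shape
    _, page_block_size, num_heads_kv, _ = k_cache.shape
    outputs = []
    for b in range(batch_size):
        seqlen_kv = int(cache_seqlens[b])
        num_blocks = (seqlen_kv + page_block_size - 1) // page_block_size
        blocks = block_table[b, :num_blocks].long()
        k = k_cache[blocks].reshape(-1, num_heads_kv, head_size)[:seqlen_kv]
        v = v_cache[blocks].reshape(-1, num_heads_kv, head_size)[:seqlen_kv]
        # Grouped-query attention: map each group of query heads to its KV head.
        k = k.repeat_interleave(num_heads // num_heads_kv, dim=1)
        v = v.repeat_interleave(num_heads // num_heads_kv, dim=1)
        # Bottom-right-aligned causal mask: queries are the last seqlen_q positions.
        q_left_len = seqlen_kv - seqlen_q
        i = torch.arange(seqlen_q, device=q.device)[:, None]
        j = torch.arange(seqlen_kv, device=q.device)[None, :]
        mask = i + q_left_len >= j
        out = torch.nn.functional.scaled_dot_product_attention(
            q[b].transpose(0, 1).float(),  # [num_heads, seqlen_q, head_size]
            k.transpose(0, 1).float(),     # [num_heads, seqlen_kv, head_size]
            v.transpose(0, 1).float(),
            attn_mask=mask,
        )
        outputs.append(out.transpose(0, 1).to(q.dtype))
    return torch.stack(outputs)  # fp16[batch_size, seqlen_q, num_heads, head_size]
```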

Performance on RTX 4090:

| batch_size | seqlen_q | sum_seqlen_kv | num_heads | head_size | num_heads_kv | name | latency (ms) | tflops |
|---:|---:|---:|---:|---:|---:|---|---:|---:|
| 1 | 4096 | 4096 | 32 | 128 | 8 | flash-attn | 1.130496 | 121.574026 |
| 1 | 4096 | 4096 | 32 | 128 | 8 | tilus | 0.975872 | 140.837073 |
| 1 | 1024 | 4096 | 32 | 128 | 8 | flash-attn | 0.467968 | 73.423267 |
| 1 | 1024 | 4096 | 32 | 128 | 8 | tilus | 0.437248 | 78.581809 |
| 16 | 1 | 65536 | 32 | 128 | 8 | flash-attn | 0.351680 | 1.526589 |
| 16 | 1 | 65536 | 32 | 128 | 8 | tilus | 0.445440 | 1.205260 |

We might need to implement a split-K optimization for the decode stage (e.g., seqlen_q = 1).
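
In a split-K (flash-decoding style) scheme, each block would process one slice of the KV sequence and emit an un-normalized partial output together with its softmax statistics; a final reduction then merges the slices. The sketch below illustrates only that merge step; combine_splits and its inputs are assumed names, not part of this PR.

```python
import torch


def combine_splits(partial_out, partial_max, partial_sumexp):
    # partial_out:    [num_splits, num_heads, head_size], un-normalized numerators
    #                 (sum over the split of exp(score - partial_max) * V)
    # partial_max:    [num_splits, num_heads], row max of scores within each split
    # partial_sumexp: [num_splits, num_heads], sum of exp(score - partial_max) within each split
    global_max = partial_max.max(dim=0).values             # [num_heads]
    scale = torch.exp(partial_max - global_max)             # [num_splits, num_heads]
    denom = (partial_sumexp * scale).sum(dim=0)             # [num_heads]
    numer = (partial_out * scale.unsqueeze(-1)).sum(dim=0)  # [num_heads, head_size]
    return numer / denom.unsqueeze(-1)                      # [num_heads, head_size]
```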

Minor changes needed by this example:

  1. Allow TensorElement in the grid analyzer.

@yaoyaoding yaoyaoding marked this pull request as draft August 22, 2025 21:43

copy-pr-bot bot commented Aug 22, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@yaoyaoding yaoyaoding marked this pull request as ready for review August 26, 2025 21:52
Signed-off-by: Yaoyao Ding <[email protected]>
Signed-off-by: Yaoyao Ding <[email protected]>
@yaoyaoding yaoyaoding force-pushed the yaoyao/attn-with-kvcache branch from ec71335 to 4334118 Compare August 26, 2025 21:53
@yaoyaoding yaoyaoding changed the title [WIP][Example] Add the attention example with kv-cache [Example] Add the attention example with kv-cache Aug 26, 2025
Signed-off-by: Yaoyao Ding <[email protected]>
Signed-off-by: Yaoyao Ding <[email protected]>
Signed-off-by: Yaoyao Ding <[email protected]>
Signed-off-by: Yaoyao Ding <[email protected]>
@yaoyaoding yaoyaoding requested a review from Copilot August 27, 2025 03:55

@Copilot Copilot AI left a comment


Pull Request Overview

This PR adds a new attention with KV cache example and makes several improvements to support it. The primary purpose is to demonstrate a high-performance attention mechanism implementation with key-value caching, which is crucial for efficient inference in large language models.

  • Implements a complete attention with KV cache operator similar to flash-attention
  • Updates the lambda function parameter handling in register_tensor initialization to use individual parameters instead of tuples
  • Adds support for TensorElement in grid analyzer and improves type handling in transpiler

Reviewed Changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 4 comments.

| File | Description |
|---|---|
| examples/attention_with_kvcache/attention_v1.py | New attention with KV cache implementation with performance benchmarking |
| python/tilus/lang/script.py | Updates register_tensor init parameter handling from tuple to individual parameters |
| python/tilus/lang/transpiler.py | Improves type annotation handling and adds TensorElement support |
| python/tilus/ir/analyzers/grid_analyzer.py | Adds TensorElement visitor method to grid analyzer |
| tests/lang/test_simple_mma_matmul.py | Updates lambda to use individual parameters |
| tests/kernels/matmul/test_matmul_v0.py | Updates lambda to use individual parameters |
| tests/instructions/test_reduce.py | Updates lambda to use individual parameters |
| examples/attention/flash_attention_v*.py | Updates lambda functions to use individual parameters |
| python/tilus/lang/modules/cuda.py | Refactors shared layout creation to use function-based offset calculation |
| python/tilus/lang/instantiated_script.py | Simplifies constant handling by removing Constant type requirements |
| python/tilus/ir/builders/stmt_builder.py | Improves error message for broadcast operations |
| pyproject.toml | Adds einops dependency for tensor operations |


```diff
@@ -161,7 +161,7 @@ def range(
     step: Optional[Expr | int] = None,
     /,
     *,
-    unroll: Optional[Literal["all"] | int],
+    unroll: Optional[Literal["all"] | int] = None,
```

Copilot AI Aug 27, 2025


The default value None for the unroll parameter changes the API behavior. Previously this parameter was required, but now it's optional. This could be a breaking change for existing code that relies on explicit unroll specification.

Suggested change
```diff
-    unroll: Optional[Literal["all"] | int] = None,
+    unroll: Optional[Literal["all"] | int],
```


Comment on lines +287 to +288
```python
if resolved_annotation in (int, str, float):
    rhs = resolved_annotation(rhs)
```

Copilot AI Aug 27, 2025


The type conversion resolved_annotation(rhs) could raise an exception if rhs cannot be converted to the target type. This should be wrapped in a try-catch block to provide a meaningful error message.

Suggested change
```diff
 if resolved_annotation in (int, str, float):
-    rhs = resolved_annotation(rhs)
+    try:
+        rhs = resolved_annotation(rhs)
+    except (ValueError, TypeError) as e:
+        raise TilusProgramError(
+            self,
+            lhs,
+            f"Failed to convert value '{rhs}' to type '{resolved_annotation.__name__}' for variable '{var_name}': {e}"
+        )
```


```diff
@@ -874,6 +878,8 @@ def visit_Subscript(self, expr: ast.Subscript) -> Any:
     if isinstance(base, Sequence):
         return base[indices]
     elif isinstance(base, GlobalTensor):
+        if not isinstance(indices, Sequence):
```

Copilot AI Aug 27, 2025


The check for Sequence could incorrectly identify strings as sequences since strings implement the Sequence protocol. This should explicitly exclude strings: if not isinstance(indices, Sequence) or isinstance(indices, str):.

Suggested change
```diff
-if not isinstance(indices, Sequence):
+if not isinstance(indices, Sequence) or isinstance(indices, str):
```


Comment on lines +156 to +159
```python
mask = self.register_tensor(
    dtype=boolean,
    shape=[self.block_q, self.block_kv],
    init=lambda i, j: i + q_offset + q_left_len >= j + kv_offset,
```

Copilot AI Aug 27, 2025


[nitpick] The mask calculation uses multiple variables captured from the outer scope (q_offset, q_left_len, kv_offset). Consider extracting this logic into a named function or adding a comment explaining the causal mask logic for better readability.

Suggested change
```diff
-mask = self.register_tensor(
-    dtype=boolean,
-    shape=[self.block_q, self.block_kv],
-    init=lambda i, j: i + q_offset + q_left_len >= j + kv_offset,
+# Causal mask: ensures that each query position (i + q_offset + q_left_len)
+# can only attend to key positions (j + kv_offset) that are not in the future.
+def causal_mask(i, j):
+    return i + q_offset + q_left_len >= j + kv_offset
+mask = self.register_tensor(
+    dtype=boolean,
+    shape=[self.block_q, self.block_kv],
+    init=causal_mask,
```


@yaoyaoding yaoyaoding force-pushed the yaoyao/attn-with-kvcache branch from d342973 to 1acf13f Compare August 27, 2025 15:20
@yaoyaoding yaoyaoding merged commit 6fff104 into main Aug 27, 2025
4 checks passed
@yaoyaoding yaoyaoding deleted the yaoyao/attn-with-kvcache branch September 3, 2025 02:35