Skip to content

Commit 14f0a4f

Browse files
bottlerxFormers Bot
authored andcommitted
triton_splitk physical_page_idx back to int32 (fairinternal/xformers#1464)
__original_commit__ = fairinternal/xformers@88348fb
1 parent 89d3014 commit 14f0a4f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

xformers/ops/fmha/_triton/splitk_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def _fwd_kernel_splitK(
392392
logical_page_idx = logical_block_idx // BLOCKS_IN_PAGE
393393
physical_page_idx = tl.load(
394394
block_table + stride_blocktablesl * logical_page_idx
395-
).to(tl.int64) # Cast to int64 to avoid overflow when offset > 2^31
395+
).to(tl.int32)
396396
offset = physical_page_idx * PAGE_SIZE + block_offset_in_page * BLOCK_N
397397

398398
current_block_size = min(hi - start_n, BLOCK_N)

0 commit comments

Comments
 (0)