Skip to content

Commit af7213c

Browse files
authored
Merge branch 'ModelTC:main' into pd-dp-triton
2 parents c630c39 + 5b3e319 commit af7213c

13 files changed

13 additions & 13 deletions (lines changed)

lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ def _fwd_kernel(
7676
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),
7777
mask=(start_n + offs_n) < block_end_loc,
7878
other=0,
79-
)
79+
).to(tl.int64)
8080
off_k = kv_loc[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
8181
k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)
8282

lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ def _fwd_kernel(
9191
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),
9292
mask=(start_n + offs_n) < block_end_loc,
9393
other=0,
94-
)
94+
).to(tl.int64)
9595
off_kv = kv_loc[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None] * stride_kv_d
9696
off_kv_rope = (
9797
kv_loc[None, :] * stride_kv_rope_bs

lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_fp8.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@ def _fwd_kernel_fp8(
9696
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),
9797
mask=(start_n + offs_n) < block_end_loc,
9898
other=0,
99-
)
99+
).to(tl.int64)
100100
off_kv = kv_loc[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None] * stride_kv_d
101101
off_kv_rope = (
102102
kv_loc[None, :] * stride_kv_rope_bs

lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ def _fwd_kernel_destindex_copy_kv(
3434
offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE)
3535
offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE)
3636

37-
dest_index = tl.load(Dest_loc + cur_index)
37+
dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
3838

3939
kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :]
4040
kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :]

lightllm/models/deepseek2/triton_kernel/destindex_copy_kv_fp8.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ def _fwd_kernel_destindex_copy_kv_fp8(
3636
offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE)
3737
offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE)
3838

39-
dest_index = tl.load(Dest_loc + cur_index)
39+
dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
4040

4141
kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :]
4242
kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :]

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,7 @@ def _fwd_kernel_flash_decode_stage1_padding(
105105
req_to_tokens_ptr + offs_n_new,
106106
mask=seq_n_mask,
107107
other=0,
108-
)
108+
).to(tl.int64)
109109
off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None]
110110
kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0)
111111
att_value = tl.dot(q, kv)

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1_fp8.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ def _fwd_kernel_flash_decode_stage1_padding_fp8(
108108
req_to_tokens_ptr + offs_n_new,
109109
mask=seq_n_mask,
110110
other=0,
111-
)
111+
).to(tl.int64)
112112
off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None]
113113
kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0)
114114
off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None]

lightllm/models/deepseek2/triton_kernel/sample_kv.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ def _sample_kv_kernel(
4444
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_m,
4545
mask=offs_m < block_end_loc,
4646
other=0,
47-
)
47+
).to(tl.int64)
4848
off_kv_nope = kv_loc[:, None] * stride_input_dim + offs_nope_d[None, :]
4949
off_kv_rope = kv_loc[:, None] * stride_input_dim + (offs_rope_d + BLOCK_DMODEL)[None, :]
5050
kv_nope = tl.load(KV_input + off_kv_nope, mask=offs_m[:, None] < block_end_loc, other=0.0)

lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ def _fwd_kernel(
8484
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),
8585
mask=(start_n + offs_n) < block_end_loc,
8686
other=0,
87-
)
87+
).to(tl.int64)
8888
off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
8989
k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)
9090
qk = tl.dot(q, k)

lightllm/models/llama/triton_kernel/gqa_decode_flashattention_nopad.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ def _fwd_kernel(
7070
Req_to_tokens + cur_batch_req_idx * stride_req_to_tokens_b + start_n + offs_n,
7171
mask=(start_n + offs_n) < cur_batch_seq_len,
7272
other=0,
73-
)
73+
).to(tl.int64)
7474
k = tl.load(
7575
k_ptrs + kv_loc[None, :] * stride_kbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0
7676
)

0 commit comments

Comments
 (0)