Commit c1425d1 (parent: e1a17a9)

change schema of efficient_attention_forward_ck to have optional tensor return

2 files changed, +3 −3 lines


xformers/csrc/attention/attention.cpp (1 addition, 1 deletion)

@@ -26,7 +26,7 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) {
       "xformers::efficient_attention_forward_ck(Tensor query, "
       "Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, "
       "Tensor? seqstart_k, int? max_seqlen_q, float dropout_p, "
-      "bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k, int? window_size, Tensor? block_tables, int? page_size) -> (Tensor, Tensor, int, int)"));
+      "bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k, int? window_size, Tensor? block_tables, int? page_size) -> (Tensor, Tensor?, int, int)"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::efficient_attention_forward_decoder_ck(Tensor query, "
       "Tensor key, Tensor value, Tensor? seq_positions, float scale) -> Tensor"));

xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp (2 additions, 2 deletions)

@@ -48,7 +48,7 @@ namespace {
    (Mode BMHK) With all the heads having the same seqlen
    (Mode 1MHK) `batch=1` with all tokens across batches concatenated
 */
-std::tuple<at::Tensor, at::Tensor, int64_t, int64_t>
+std::tuple<at::Tensor, std::optional<at::Tensor>, int64_t, int64_t>
 efficient_attention_forward_ck(
     const at::Tensor& query, // [b, seqlen, num_heads_q, K]
     const at::Tensor& key, // [b, seqlen, num_heads_kv, K]

@@ -473,7 +473,7 @@ efficient_attention_forward_ck(
    (Mode BMHK) With all the heads having the same seqlen
    (Mode 1MHK) `batch=1` with all tokens across batches concatenated
 */
-std::tuple<at::Tensor, at::Tensor, int64_t, int64_t>
+std::tuple<at::Tensor, std::optional<at::Tensor>, int64_t, int64_t>
 efficient_attention_forward_ck_meta(
     const at::Tensor& query, // [b, seqlen, num_heads_q, K]
     const at::Tensor& key, // [b, seqlen, num_heads_kv, K]
