Is there a way to create FLASH attention in xFormerConfig? #915
YoelShoshan started this conversation in General
Replies: 2 comments 1 reply
-
Figured it out myself. Once the following attention class is registered, you can use "memory_efficient_scaled_dot_product" as the attention name in the config:
import logging
from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn

from xformers import ops as xops
from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    register_attention,
)
from xformers.components.attention.core import scaled_dot_product_attention

logger = logging.getLogger("xformers")


@dataclass
class MemoryEfficientScaledDotProductConfig(AttentionConfig):
    causal: Optional[bool]
    seq_len: Optional[int]
    to_seq_len: Optional[int]


@register_attention("memory_efficient_scaled_dot_product", MemoryEfficientScaledDotProductConfig)
class MemoryEfficientScaledDotProduct(Attention):
    r"""
    A memory-efficient version of the Scaled Dot-Product attention proposed in
    `Attention is all you need`_, Vaswani et al.

    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
    """

    mask: Optional[AttentionMask]

    def __init__(
        self,
        dropout: float = 0.0,
        causal: bool = False,
        seq_len: Optional[int] = None,
        to_seq_len: Optional[int] = None,
        *args,
        **kwargs,
    ):
        super().__init__()

        # Dropout is handled by the fused kernel below, not by an explicit nn.Dropout
        # self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.dropout = dropout
        self.causal = causal
        self.seq_len = seq_len

        if causal and seq_len is not None:
            self.mask = AttentionMask.make_causal(seq_len, to_seq_len)
        else:
            self.mask = None

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        att_mask    A 2D or 3D mask which ignores attention at certain positions.

                    - If the mask is boolean, a value of True will keep the value,
                      while a value of False will mask the value.

                      Key padding masks (dimension: batch x sequence length) and attention masks
                      (dimension: sequence length x sequence length OR
                      batch x sequence length x sequence length) can be combined and passed in
                      here. Method maybe_merge_masks provided in the utils can be used for that
                      merging.

                    - If the mask has the float type, then an additive mask is expected
                      (masked values are -inf).
        """
        # Convenience: create an attention mask if a plain tensor was passed
        if att_mask is not None and isinstance(att_mask, torch.Tensor):
            # By default we don't know of the causality, and a check would be expensive
            att_mask = (
                AttentionMask.from_bool(att_mask)
                if att_mask.dtype == torch.bool
                else AttentionMask(att_mask, is_causal=False)
            )

        # Handle a possibly deferred causal mask
        mask = self.mask
        if self.causal and self.mask is None:
            mask = AttentionMask.make_causal(
                seq_len=q.shape[-2],
                to_seq_len=q.shape[-2],
                device=q.device,
                dtype=q.dtype,
            )

        # Merge the optional causal mask and the user-provided mask
        if mask is not None:
            mask = mask.to(dtype=q.dtype, device=q.device)
            att_mask = att_mask + mask if att_mask is not None else mask

        # Try to handle a case where the sequence is smaller than the mask
        if (
            att_mask is not None
            and q.shape[-2] == k.shape[-2]
            and q.shape[-2] < att_mask.shape[1]
        ):
            if isinstance(att_mask, AttentionMask):
                att_mask = att_mask.make_crop(seq_len=q.shape[-2])
            else:
                logger.error(
                    "Mismatching sparse attention mask and sequence length."
                    + " Please pad the inputs or adjust the attention mask"
                )
                raise NotImplementedError

        # The fused kernel expects a dense additive bias, broadcast over the batch dimension
        if att_mask is not None:
            att_mask = att_mask.values.expand(q.shape[0], -1, -1)

        # Attend with the memory-efficient kernel instead of materializing the
        # (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S) attention matrix
        y = xops.memory_efficient_attention(
            query=q, key=k, value=v, p=self.dropout, attn_bias=att_mask,
        )
        return y
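
To use it through the model factory, something along these lines should work. This is a minimal sketch only (untested): the block layout copies the HOWTO model-factory example, all sizes and hyperparameters are placeholders, and field names such as "layer_norm_style" may differ between xformers versions.

import torch

from xformers.factory.model_factory import xFormer, xFormerConfig

BATCH, SEQ, EMB, HEADS = 8, 1024, 384, 6  # placeholder sizes

encoder_config = [
    {
        "block_type": "encoder",
        "num_layers": 2,
        "dim_model": EMB,
        "layer_norm_style": "pre",
        "multi_head_config": {
            "num_heads": HEADS,
            "residual_dropout": 0.0,
            "attention": {
                # the name registered above via @register_attention
                "name": "memory_efficient_scaled_dot_product",
                "dropout": 0.0,
                "causal": False,
                "seq_len": SEQ,
            },
        },
        "feedforward_config": {
            "name": "MLP",
            "dropout": 0.0,
            "activation": "gelu",
            "hidden_layer_multiplier": 4,
        },
    }
]

model = xFormer.from_config(xFormerConfig(encoder_config))
x = torch.rand(BATCH, SEQ, EMB)
y = model(x)  # attention inside the blocks now routes through xops.memory_efficient_attention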
-
Hi,
-
Is there a way to modify the config to trigger the use of FLASH attention?
https://github.com/facebookresearch/xformers/blob/main/HOWTO.md#model-factory
For example, by modifying something here?
Also, is there any example of a simple transformer (either encoder-only or encoder-decoder) in xFormers that uses FLASH attention?
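
For reference, the fused kernel can also be called directly, without going through the config/factory path. A rough sketch, assuming an xformers build with the fused kernels available; the shapes, dtype and device below are placeholders, and the call only dispatches to a FlashAttention-style kernel when the dtype, head dimension and hardware permit:

import torch
from xformers import ops as xops

B, M, H, K = 2, 1024, 8, 64  # batch, sequence length, heads, head dim (placeholders)
q = torch.rand(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.rand(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.rand(B, M, H, K, device="cuda", dtype=torch.float16)

# Non-causal self-attention; pass attn_bias=xops.LowerTriangularMask() for the causal case
out = xops.memory_efficient_attention(q, k, v)  # -> (B, M, H, K)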