[Fix] enable_xformers_memory_efficient_attention() in Flux Pipeline #12337
base: main
Conversation
src/diffusers/models/attention.py
Outdated
```diff
@@ -241,7 +241,10 @@ def set_use_memory_efficient_attention_xformers(
     op_fw, op_bw = attention_op
     dtype, *_ = op_fw.SUPPORTED_DTYPES
     q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
     _ = xops.memory_efficient_attention(q, q, q)
+    try:
```
Possibly you could do this at the top, when it's first imported, since it looks like only this one function is used:
```python
if is_xformers_available():
    import xformers as xops

    xformers_attn_fn = getattr(xops, "memory_efficient_attention", None) or getattr(
        xops.ops, "memory_efficient_attention"
    )
else:
    xformers_attn_fn = None
```
There are already many `try` blocks and conditions in this section of code, and it looks like the function performs additional checks that are not in the code you contributed. I also noticed this breaking change in xformers; maybe there is a better way to check which version is installed?
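A minimal sketch of the import-time resolution described above, assuming `packaging` is available (diffusers already depends on it) and that the function is exposed either at `xformers.memory_efficient_attention` or at `xformers.ops.memory_efficient_attention` depending on the release; `xformers_attn_fn` is the name from the suggestion above, and `xformers_version` is added purely for illustration:

```python
import importlib.metadata

from packaging.version import Version

from diffusers.utils import is_xformers_available

if is_xformers_available():
    import xformers
    import xformers.ops

    xformers_version = Version(importlib.metadata.version("xformers"))
    # Some releases expose the function at the top level; otherwise fall back
    # to xformers.ops, where it has historically lived.
    xformers_attn_fn = getattr(xformers, "memory_efficient_attention", None) or getattr(
        xformers.ops, "memory_efficient_attention"
    )
else:
    xformers_version = None
    xformers_attn_fn = None
```

Resolving the callable once at import time would let the rest of the file use `xformers_attn_fn` without repeating version-dependent attribute lookups inside the setter.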
I fixed the `try` issue here now. I am thinking of removing the already present `try`. Can you tell me which version of xformers is suitable here, and also check the code below?
```python
else:
    import importlib.metadata

    version = importlib.metadata.version("xformers")
    if tuple(map(int, version.split("."))) < (0, 0, 32):
        raise ImportError(
            f"xformers>=0.0.32 is required for memory efficient attention, but found {version}"
        )

    # Make sure we can run the memory efficient attention
    dtype = None
    if attention_op is not None:
        op_fw, op_bw = attention_op
        dtype, *_ = op_fw.SUPPORTED_DTYPES
    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
    _ = xops.memory_efficient_attention(q, q, q)

    self.set_attention_backend("xformers")
```
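One detail worth flagging in the snippet above: `tuple(map(int, version.split(".")))` raises `ValueError` on suffixed releases such as `0.0.32.post2`. A hedged alternative for the same `>= 0.0.32` check using `packaging`, which diffusers already depends on (the minimum version is taken from the snippet, not confirmed by the maintainers):

```python
import importlib.metadata

from packaging.version import Version

# Minimum version copied from the snippet above, not confirmed by maintainers.
MIN_XFORMERS = Version("0.0.32")

installed = Version(importlib.metadata.version("xformers"))
# Version comparison tolerates suffixed releases such as "0.0.32.post2" or dev
# builds, which would make the int()-based split check raise a ValueError.
if installed < MIN_XFORMERS:
    raise ImportError(
        f"xformers>={MIN_XFORMERS} is required for memory efficient attention, but found {installed}"
    )
```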
I think both ways are supported, depending on the version of xformers.
I don't want to block, so your original approach is probably the most straightforward for now; looking more at the code, it likely needs a bit of a wider refactor.
Then it is fixed for now in the latest commit. Thank you.
Ah, you mean force people to be on up-to-date xformers?
What does this PR do?
Fixes #12334 enable_xformers_memory_efficient_attention() in Flux Pipeline
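For context, the call path this PR targets looks roughly like the following; the checkpoint, dtype, and prompt are illustrative assumptions, not copied from the issue:

```python
import torch
from diffusers import FluxPipeline

# Illustrative repro of the call this PR fixes; checkpoint, dtype, and prompt
# are assumptions, not taken from issue #12334.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# This is the call that the issue reports as failing for the Flux pipeline.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=4).images[0]
```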
Who can review?
Anyone in the community is free to review the PR once the tests have passed.
@DN6