
Conversation

@SahilCarterr (Contributor) commented Sep 16, 2025

What does this PR do?

Fixes #12334: enable_xformers_memory_efficient_attention() in the Flux pipeline.
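For context, a minimal way to exercise the fixed path might look like the sketch below; the checkpoint id, dtype, and prompt are illustrative assumptions, not taken from the issue.

import torch
from diffusers import FluxPipeline

# Hypothetical repro: checkpoint id and dtype are assumptions, not from #12334.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# The call the linked issue reports as failing for Flux; with this PR it should
# route the pipeline's attention through xformers.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of a cat", num_inference_steps=4).images[0]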

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed.
@DN6

@@ -241,7 +241,10 @@ def set_use_memory_efficient_attention_xformers(
op_fw, op_bw = attention_op
dtype, *_ = op_fw.SUPPORTED_DTYPES
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
_ = xops.memory_efficient_attention(q, q, q)
try:

@JoeGaffney Sep 16, 2025


Possibly you could do this at the top when it is first imported, as it looks like only the one function is used.

if is_xformers_available():
    import xformers as xops

    # Depending on the installed xformers version, the function may live at the
    # top level or under xformers.ops, so resolve it once here.
    xformers_attn_fn = getattr(xops, "memory_efficient_attention", None) \
                       or getattr(xops.ops, "memory_efficient_attention")
else:
    xformers_attn_fn = None

There are already many try blocks and conditions in this block of code. It also seems like there are additional checks in the function that are not in the code you contributed.

I did notice this breaking change in xformers; maybe there is also a better way to check which version is installed?
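One hedged way to cope with that breaking change is to resolve the callable once with a plain import fallback; in this sketch, the top-level location is an assumption mirroring the snippet above, while xformers.ops.memory_efficient_attention is the long-standing one.

# Sketch only: try the assumed top-level location first, then fall back to
# xformers.ops; leave the name as None when neither import works.
try:
    from xformers import memory_efficient_attention as xformers_attn_fn
except ImportError:
    try:
        from xformers.ops import memory_efficient_attention as xformers_attn_fn
    except ImportError:
        xformers_attn_fn = None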

@SahilCarterr (Contributor, Author)

I fixed the try issue here now. I am thinking of removing the already-present try. Can you tell me which version of xformers is suitable here, and also check the code below?

else:
    import importlib.metadata

    version = importlib.metadata.version("xformers")
    if tuple(map(int, version.split("."))) < (0, 0, 32):
        raise ImportError(
            f"xformers>=0.0.32 is required for memory efficient attention, but found {version}"
        )

    # Make sure we can run the memory efficient attention
    dtype = None
    if attention_op is not None:
        op_fw, op_bw = attention_op
        dtype, *_ = op_fw.SUPPORTED_DTYPES
    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
    _ = xops.memory_efficient_attention(q, q, q)

    self.set_attention_backend("xformers")
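One caveat with the version check above: tuple(map(int, version.split("."))) would raise on releases with suffixes such as 0.0.32.post1. Assuming packaging is available as a dependency (an assumption here), a comparison like the sketch below handles those cases.

import importlib.metadata

from packaging.version import Version

xformers_version = importlib.metadata.version("xformers")
# Version() understands suffixes like ".post1" or ".dev0" that int() cannot parse.
if Version(xformers_version) < Version("0.0.32"):
    raise ImportError(
        f"xformers>=0.0.32 is required for memory efficient attention, but found {xformers_version}"
    )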

@JoeGaffney


I think both ways are supported depending on the version of xformers.

I don't want to block, so your original is possibly the most straightforward approach for now; looking more at the code, it probably needs a bit of a wider refactor.

@SahilCarterr (Contributor, Author)

Then it is fixed for now in the latest commit. Thank you.


Ah, you mean force people to be on the up-to-date xformers?


Successfully merging this pull request may close these issues: "a bug when i use enable_xformers_memory_efficient_attention()" (#12334).