Generating StableHLO graph for torch.ops.xla.flash_attention on Custom XLA Device #9511

@mmanzoorTT

Description

❓ Questions and Help

I’m trying to generate a StableHLO graph from a PyTorch/XLA model that uses the operator torch.ops.xla.flash_attention (and other, similar attention ops) on a custom XLA device backend (not TPU or GPU). The goal is to export the op to StableHLO IR and run it through PJRT.
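For context, the custom backend is hooked into PyTorch/XLA as a PJRT C API plugin, roughly along these lines (a minimal sketch with a placeholder device name and library path, not the real plugin code):

import torch_xla.runtime as xr
from torch_xla.experimental import plugins

class MyDevicePlugin(plugins.DevicePlugin):
    # Placeholder path to the PJRT C API plugin shared library for the custom device
    def library_path(self):
        return "/path/to/pjrt_plugin_mydevice.so"

plugins.use_dynamic_plugins()
plugins.register_plugin("MYDEVICE", MyDevicePlugin())
xr.set_device_type("MYDEVICE")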

This is my sample code:

import torch
import torch_xla
import torch_xla.experimental.custom_kernel  # registers torch.ops.xla.flash_attention
from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder


def test_torch_flash_attn():
    class FlashAttnModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.b = StableHLOCompositeBuilder("custom.flash_attention", {})

        def forward(self, q, k, v, is_causal=True):
            # Mark inputs for composite region
            q, k, v = self.b.mark_inputs(q, k, v)
            # Call XLA flash attention op
            out = torch.ops.xla.flash_attention(q, k, v, is_causal)
            # Mark output
            out = self.b.mark_outputs(out)
            return out

    mod = FlashAttnModule()
    T, H, L, D = 2, 4, 128, 64
    q = torch.randn(T, H, L, D)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    mod = mod.to("xla")
    q = q.to("xla")
    k = k.to("xla")
    v = v.to("xla")

    exported = torch.export.export(mod, (q, k, v))
    # Convert to StableHLO
    stable = stablehlo.exported_program_to_stablehlo(exported)

    # Print StableHLO IR
    print(stable.get_stablehlo_text("forward"))

I am getting an error:

ValueError: Only interpret mode is supported on CPU backend
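For what it's worth, this is how I check which PJRT runtime the process actually picked up (a small sketch using standard torch_xla.runtime / xla_model calls):

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Print the PJRT device type the runtime selected (e.g. 'CPU' when falling back to the CPU backend)
print(xr.device_type())
# Print the torch device the model and tensors are placed on
print(xm.xla_device())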

I am using:

torch                              2.8.0.dev20250618+cpu
torch-xla                          2.8.0+gitaaff959
