❓ Questions and Help
I’m trying to generate a StableHLO graph from a PyTorch/XLA model that uses the operator torch.ops.xla.flash_attention (and other similar attention-layer ops) on a custom XLA device backend (not TPU or GPU). I want to export the op to StableHLO IR and run it through PJRT.
Here is my sample code:
import torch
import torch_xla
from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder

def test_torch_flash_attn():
    class FlashAttnModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Wrap the attention region in a composite named "custom.flash_attention"
            self.b = StableHLOCompositeBuilder("custom.flash_attention", {})

        def forward(self, q, k, v, is_causal=True):
            # Mark inputs for composite region
            q, k, v = self.b.mark_inputs(q, k, v)
            # Call XLA flash attention op
            out = torch.ops.xla.flash_attention(q, k, v, is_causal)
            # Mark output
            out = self.b.mark_outputs(out)
            return out

    mod = FlashAttnModule()
    T, H, L, D = 2, 4, 128, 64
    q = torch.randn(T, H, L, D)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Move the module and inputs to the XLA device
    mod = mod.to("xla")
    q = q.to("xla")
    k = k.to("xla")
    v = v.to("xla")

    exported = torch.export.export(mod, (q, k, v))
    # Convert to StableHLO
    stable = stablehlo.exported_program_to_stablehlo(exported)
    # Print StableHLO IR
    print(stable.get_stablehlo_text("forward"))
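Beyond printing the IR, this is roughly how I intend to feed the exported artifact to PJRT — a minimal sketch assuming the save_as_stablehlo and StableHLOGraphModule.load helpers in torch_xla.stablehlo work on my backend (the path is just an example):

from torch_xla import stablehlo

# Assumption: save_as_stablehlo serializes the ExportedProgram's StableHLO
# plus weights to a directory that a PJRT runtime can consume.
stablehlo.save_as_stablehlo(exported, "/tmp/flash_attn_stablehlo")

# Assumption: StableHLOGraphModule.load reloads the saved artifact so it can
# be executed on the current PJRT device.
reloaded = stablehlo.StableHLOGraphModule.load("/tmp/flash_attn_stablehlo")
out = reloaded(q, k, v)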
I am getting the following error:
ValueError: Only interpret mode is supported on CPU backend
I am using:
torch 2.8.0.dev20250618+cpu
torch-xla 2.8.0+gitaaff959