diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 30e160dd2408..6c28b48f06d6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -6006,9 +6006,10 @@ def __call__( query, key, value = query.float(), key.float(), value.float() - value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0) - scores = torch.matmul(value, key) - hidden_states = torch.matmul(scores, query) + with torch.autocast(device_type=hidden_states.device.type, enabled=False): + value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0) + scores = torch.matmul(value, key) + hidden_states = torch.matmul(scores, query) hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15) hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)