
Commit 3d013a8
Fix a dtype issue when evaluating the Sana transformer with a float16 autocast context
1 parent: a647682

1 file changed (+4 −3 lines)

src/diffusers/models/attention_processor.py

Lines changed: 4 additions & 3 deletions
@@ -6006,9 +6006,10 @@ def __call__(
 
         query, key, value = query.float(), key.float(), value.float()
 
-        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
-        scores = torch.matmul(value, key)
-        hidden_states = torch.matmul(scores, query)
+        with torch.autocast(device_type=hidden_states.device.type, enabled=False):
+            value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
+            scores = torch.matmul(value, key)
+            hidden_states = torch.matmul(scores, query)
 
         hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
         hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
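
Context on why the .float() casts alone were not enough (not part of the commit): inside an active float16 autocast region, PyTorch re-casts the inputs of autocast-eligible ops such as torch.matmul back to float16, silently undoing an explicit float32 cast. Wrapping the matmuls in torch.autocast(..., enabled=False) makes them actually run in float32. A minimal standalone sketch of this behavior, assuming a CUDA device (CPU autocast defaults to bfloat16):

    import torch

    device = "cuda"  # assumption for this sketch; not taken from the commit
    a = torch.randn(4, 4, device=device)
    b = torch.randn(4, 4, device=device)

    with torch.autocast(device_type=device, dtype=torch.float16):
        # Autocast re-casts matmul inputs to float16, so the explicit
        # .float() does not survive: the result comes out float16.
        out = torch.matmul(a.float(), b.float())
        print(out.dtype)  # torch.float16

        # Locally disabling autocast (the pattern used in this commit)
        # lets the float32 matmul actually execute in float32.
        with torch.autocast(device_type=device, enabled=False):
            out = torch.matmul(a.float(), b.float())
        print(out.dtype)  # torch.float32

This plausibly matters here because the subsequent normalization divides by hidden_states[:, :, -1:] + 1e-15; 1e-15 underflows to zero in float16, and the accumulated sums from the linear-attention matmuls can overflow float16's range, so keeping this region in float32 avoids both failure modes.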
