Skip to content

Commit 5600d90

Browse files
committed
minor fix.
1 parent d3badb9 commit 5600d90

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

model/cfm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def sample(
9696
):
9797
self.eval()
9898

99-
if cond.device == torch.device('cuda'):
99+
if next(self.parameters()).dtype == torch.float16:
100100
cond = cond.half()
101101

102102
# raw wave

0 commit comments

Comments
 (0)