Closed
Description
When using --alpha_mask with images whose background was removed using rembg (on commit 0d96e10), training fails with the following traceback:
Traceback (most recent call last):
  File "/mnt/900/builds/sd-scripts/train_network.py", line 1156, in <module>
    trainer.train(args)
  File "/mnt/900/builds/sd-scripts/train_network.py", line 919, in train
    loss = apply_masked_loss(loss, batch)
  File "/mnt/900/builds/sd-scripts/library/custom_train_functions.py", line 497, in apply_masked_loss
    mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
  File "/mnt/900/builds/sd-scripts/.venv/lib/python3.10/site-packages/torch/nn/functional.py", line 3961, in interpolate
    raise ValueError(
ValueError: Input and output must have the same number of spatial dimensions, but got input with spatial dimensions of [1, 576, 960] and output size of torch.Size([72, 120]). Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.
For context, I uncommented the debug print line, which outputs:
mask_image: torch.Size([2, 1, 1, 576, 960]), 0.6733689904212952
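The printed shape explains the error. torch.nn.functional.interpolate treats every dimension after (N, C) as spatial, so a 5D mask of shape (2, 1, 1, 576, 960) has three spatial dimensions, while size=loss.shape[2:] supplies only two. The minimal reproduction below assumes that batch["alpha_masks"] already arrives as (N, 1, H, W), which is what an .unsqueeze(1) producing (2, 1, 1, 576, 960) implies. The loss tensor is a placeholder with the 72x120 latent size from the traceback.

```python
import torch
import torch.nn.functional as F

# Placeholder loss in latent space: 576x960 pixels / 8 = 72x120 latents.
loss = torch.ones(2, 4, 72, 120)

# Assumed dataset output: masks that already carry a channel dim, (N, 1, H, W).
alpha_masks = torch.rand(2, 1, 576, 960)

# A 4D (N, C, H, W) tensor has two spatial dims, matching size=loss.shape[2:].
resized = F.interpolate(alpha_masks, size=loss.shape[2:], mode="area")
print(resized.shape)  # torch.Size([2, 1, 72, 120])

# Adding another channel dim yields (N, 1, 1, H, W): three spatial dims,
# so a 2-element output size is rejected, as in the traceback.
mask_5d = alpha_masks.unsqueeze(1)
print(mask_5d.shape)  # torch.Size([2, 1, 1, 576, 960])
raised = False
try:
    F.interpolate(mask_5d, size=loss.shape[2:], mode="area")
except ValueError as err:
    raised = True
    print("ValueError:", err)
```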
If I swap which of the following two lines is commented out, so that .unsqueeze(1) is no longer applied, it works:
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
# mask_image = batch["alpha_masks"].to(dtype=loss.dtype)
With the lines swapped, training does seem to work, at least to a degree.
I have seen others get it working without modifying this line, so this may be an interaction between my dataset and the alpha mask handling. I would be happy to help isolate it.
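A swap like this probably fixes my dataset but breaks one whose masks arrive as (N, H, W). A shape-tolerant alternative is sketched below. It assumes batch["alpha_masks"] is either (N, H, W) or (N, 1, H, W). prepare_alpha_mask is a hypothetical helper, not an sd-scripts function, and the project may have fixed this differently.

```python
import torch
import torch.nn.functional as F


def prepare_alpha_mask(alpha_masks: torch.Tensor, loss: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: normalize an alpha mask to (N, 1, H, W) and resize it to the loss."""
    mask = alpha_masks.to(dtype=loss.dtype)
    if mask.ndim == 3:
        # (N, H, W): add the channel dimension, which is what .unsqueeze(1) intends.
        mask = mask.unsqueeze(1)
    elif mask.ndim != 4:
        raise ValueError(f"unexpected alpha mask shape: {tuple(mask.shape)}")
    # Now (N, C, H, W): two spatial dims, matching size=loss.shape[2:].
    return F.interpolate(mask, size=loss.shape[2:], mode="area")


# Usage with a placeholder latent-space loss of shape (N, 4, 72, 120).
loss = torch.ones(2, 4, 72, 120)
mask_a = prepare_alpha_mask(torch.ones(2, 576, 960), loss)     # (N, H, W) input
mask_b = prepare_alpha_mask(torch.ones(2, 1, 576, 960), loss)  # (N, 1, H, W) input
masked_loss = loss * mask_b  # (N, 1, h, w) broadcasts over the loss channels
```

Checking ndim rather than always calling .unsqueeze(1) keeps both mask layouts working. Raising on any other rank surfaces unexpected shapes early instead of deep inside interpolate.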