Issue: Style Encoder Not Updated & Generator Gradients Leak During Discriminator Training
Affected File
Description
1. Style Encoder Not Updated During Generator Training
When training the generator with reference images:
```python
g_loss, g_losses_ref = compute_g_loss(
    nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks
)
self._reset_grad()
g_loss.backward()
optims.generator.step()
```
However, this step does not include an optimizer step for the style encoder. As a result, the style encoder's parameters are never updated, which breaks reference-based style translation.
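A minimal sketch of the proposed fix, assuming `optims` is a container that exposes a `style_encoder` optimizer alongside `optims.generator` (the exact attribute name is an assumption):

```python
# Sketch of the proposed fix: also step the style encoder after the
# reference-based generator pass (assumes optims.style_encoder exists,
# mirroring optims.generator).
g_loss, g_losses_ref = compute_g_loss(
    nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks
)
self._reset_grad()
g_loss.backward()
optims.generator.step()
optims.style_encoder.step()  # update the style encoder as well
```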
2. Generator Gradients Leak During Discriminator Training
In the same file, the discriminator is trained using:
```python
out = nets.discriminator(x_fake, y_trg)
```
This allows gradients to flow back into the generator. To prevent this, `x_fake` should be detached:

```python
out = nets.discriminator(x_fake.detach(), y_trg)
```
Without `.detach()`, the generator accumulates gradients from the discriminator's loss during discriminator training, which is incorrect.
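A self-contained sketch (plain PyTorch, with toy stand-in modules) showing why the detach matters: cutting the graph at `x_fake` keeps the generator's gradients empty during the discriminator's backward pass.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the generator and discriminator.
G = nn.Linear(4, 4)
D = nn.Linear(4, 1)

z = torch.randn(8, 4)
x_fake = G(z)

# Discriminator loss on the detached fake batch: the graph is cut at
# x_fake, so backward() never reaches G's parameters.
d_loss = D(x_fake.detach()).mean()
d_loss.backward()
print(G.weight.grad)  # None -- no gradient leaked into the generator

# Without .detach(), the same backward() would populate G.weight.grad,
# and a later optimizer step on G would move the generator in the
# direction chosen by the discriminator's loss.
```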
Proposed Fixes
1. Include the style encoder in the optimization step when training the generator with reference images (a sanity-check sketch follows this list).
2. Use `x_fake.detach()` when passing fake images to the discriminator.
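As a sanity check for the first fix, one can snapshot the style encoder before a reference-based generator step and confirm its parameters moved afterwards. A hedged sketch, assuming `nets.style_encoder` exists and `train_ref_step()` is a hypothetical wrapper around the training step shown above:

```python
import torch

# Hypothetical check: snapshot the style encoder, run one
# reference-based generator step, and confirm the parameters changed.
before = [p.detach().clone() for p in nets.style_encoder.parameters()]
train_ref_step()  # hypothetical wrapper around the code above
updated = any(not torch.equal(b, p.detach())
              for b, p in zip(before, nets.style_encoder.parameters()))
print('style encoder updated:', updated)  # expect True once fixed
```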
Please clarify whether this behavior is intentional; otherwise, these issues need correcting for proper training dynamics.