Skip to content

Commit 67f401a

Browse files
committed
set static graph flag when DDP ref kohya-ss#1363
1 parent 5d46cdf commit 67f401a

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

sdxl_train_control_net_lllite.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ def train(args):
289289
# acceleratorがなんかよろしくやってくれるらしい
290290
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
291291

292+
if isinstance(unet, DDP):
293+
unet._set_static_graph() # avoid error for multiple use of the parameter
294+
292295
if args.gradient_checkpointing:
293296
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
294297
else:

0 commit comments

Comments
 (0)