Skip to content

Preliminary ControlNet PR (WIP) #163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions diffusion/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from diffusion.callbacks.log_latent_statistics import LogLatentStatistics
from diffusion.callbacks.nan_catcher import NaNCatcher
from diffusion.callbacks.scheduled_garbage_collector import ScheduledGarbageCollector
from diffusion.callbacks.assign_controlnet_weight import AssignControlNet

__all__ = [
'AssignControlNet',
'LogAutoencoderImages',
'LogDiffusionImages',
'LogLatentStatistics',
Expand Down
49 changes: 49 additions & 0 deletions diffusion/callbacks/assign_controlnet_weight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from composer import Callback, Logger, State
from composer.core import get_precision_context
from torch.nn.parallel import DistributedDataParallel
from diffusers import ControlNetModel, UNet2DConditionModel

class AssignControlNet(Callback):
    """Copy UNet weights into the ControlNet after Composer loads a checkpoint.

    The copy runs in ``after_load`` and only when the model's
    ``load_controlnet_from_composer`` attribute is truthy. The model is
    expected to expose ``unet`` and ``controlnet`` attributes.

    Args:
        use_fsdp (bool): Whether or not the model is FSDP wrapped. When true,
            the full parameters of both networks are summoned before the
            weights are copied.
    """

    def __init__(self, use_fsdp: bool):
        self.use_fsdp = use_fsdp

    def process_controlnet(self, controlnet: ControlNetModel, unet: UNet2DConditionModel) -> None:
        """Load the state dicts of the shared UNet submodules into the ControlNet.

        Args:
            controlnet (ControlNetModel): Destination network; modified in place.
            unet (UNet2DConditionModel): Source network; only read.
        """
        controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
        controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
        controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

        # Compare against None explicitly: the embedding is optional, and
        # relying on module truthiness would consult ``__len__`` for any
        # container-like module instead of answering "is it present?".
        if controlnet.class_embedding is not None:
            controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

        # ``add_embedding`` is probed with hasattr because it may not be
        # defined on the ControlNet at all.
        if hasattr(controlnet, 'add_embedding'):
            controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())

        controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

    def after_load(self, state: State, logger: Logger) -> None:
        """Copy UNet weights into the ControlNet once the checkpoint is loaded.

        Args:
            state (State): Composer state; ``state.model`` and
                ``state.precision`` are read.
            logger (Logger): Unused.
        """
        # Unwrap DDP so the attributes of the underlying model are reachable.
        if isinstance(state.model, DistributedDataParallel):
            model = state.model.module
        else:
            model = state.model

        # Nothing to do unless the model asks for its ControlNet to be
        # initialised from the weights Composer just loaded.
        if not model.load_controlnet_from_composer:
            return

        with get_precision_context(state.precision):
            if not self.use_fsdp:
                self.process_controlnet(model.controlnet, model.unet)
                return

            # The UNet is only read, so its parameters are not written back;
            # the ControlNet is modified, so writeback=True persists the
            # copied weights after the context exits.
            with FSDP.summon_full_params(model.unet, recurse=True, writeback=False), \
                    FSDP.summon_full_params(model.controlnet, recurse=True, writeback=True):
                self.process_controlnet(model.controlnet, model.unet)

6 changes: 4 additions & 2 deletions diffusion/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"""Diffusion models."""

from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion,
discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl,
text_to_image_transformer)
discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl, stable_diffusion_2_controlnet,
stable_diffusion_xl_controlnet, text_to_image_transformer)
from diffusion.models.noop import NoOpModel
from diffusion.models.pixel_diffusion import PixelDiffusion
from diffusion.models.stable_diffusion import StableDiffusion
Expand All @@ -19,6 +19,8 @@
'PixelDiffusion',
'stable_diffusion_2',
'stable_diffusion_xl',
'stable_diffusion_2_controlnet',
'stable_diffusion_xl_controlnet',
'StableDiffusion',
'text_to_image_transformer',
]
Loading