Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import comfy
from .attention import CrossAttention as CrossAttention1f
from .attention import SpatialTransformer as SpatialTransformer1f
from .attention import SpatialVideoTransformer as SpatialVideoTransformer1f
from .linear import Linear as Linear1f
from .util import AlphaBlender as AlphaBlender1f
from .deep_cache_unet import DeepCacheUNet
from .deep_cache_unet import FastDeepCacheUNet

Expand All @@ -26,17 +28,21 @@
# Mapping from original ComfyUI (torch) module classes to their OneFlow
# re-implementations; onediff substitutes the value class for each key class
# when compiling the model.
torch2of_class_map = {
    comfy.ldm.modules.attention.CrossAttention: CrossAttention1f,
    comfy.ldm.modules.attention.SpatialTransformer: SpatialTransformer1f,
    comfy.ldm.modules.attention.SpatialVideoTransformer: SpatialVideoTransformer1f,
    comfy.ldm.modules.diffusionmodules.util.AlphaBlender: AlphaBlender1f,
    comfy_ops_Linear: Linear1f,
    AttnBlock: AttnBlock1f,
}

from .openaimodel import Upsample as Upsample1f
from .openaimodel import UNetModel as UNetModel1f
from .openaimodel import VideoResBlock as VideoResBlock1f

# Register the OpenAI-UNet replacements (Upsample / UNetModel / VideoResBlock)
# in the same class map. NOTE(review): this update runs after the base map is
# built — presumably ordered this way to avoid an import cycle; verify.
torch2of_class_map.update(
    {
        comfy.ldm.modules.diffusionmodules.openaimodel.Upsample: Upsample1f,
        comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel: UNetModel1f,
        comfy.ldm.modules.diffusionmodules.openaimodel.VideoResBlock: VideoResBlock1f,
    }
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
from typing import Optional, Any

from onediff.infer_compiler.transform import proxy_class
from onediff.infer_compiler.transform import transform_mgr
from einops import rearrange, repeat
from abc import abstractmethod

# The `comfy` package as seen through onediff's transform manager, so the
# attributes below resolve to the OneFlow-backed counterparts.
onediff_comfy = transform_mgr.transform_package("comfy")

# Operations with weight init disabled, and the timestep-embedding helper,
# re-exported from the transformed comfy package for use in this module.
ops = onediff_comfy.ops.disable_weight_init
timestep_embedding = onediff_comfy.ldm.modules.diffusionmodules.util.timestep_embedding


def exists(val):
Expand Down Expand Up @@ -214,3 +222,202 @@ def forward(self, x, context=None, value=None, mask=None):
raise NotImplementedError

return self.to_out(out)


class SpatialVideoTransformer(
    proxy_class(comfy.ldm.modules.attention.SpatialVideoTransformer)
):
    """OneFlow proxy of ComfyUI's ``SpatialVideoTransformer`` (SVD).

    Mirrors the upstream spatial/temporal transformer, with the einops
    rearranges replaced by flatten/permute/``broadcast_dim_like`` ops so that
    onediff can compile the model with dynamic input shapes. Interface is
    identical to the upstream class.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        time_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        max_time_embed_period: int = 10000,
        dtype=None,
        device=None,
        operations=ops,
    ):
        # The spatial half is built by the (proxied) SpatialTransformer parent.
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        # Temporal attention reuses the spatial head layout.
        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            time_context_dim = context_dim

        # One temporal block per spatial block; zipped together in forward().
        self.time_stack = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=time_context_dim,
                    # timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(self.depth)
            ]
        )

        assert len(self.time_stack) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        # Small MLP embedding the per-frame timestep index (in -> 4*in -> in).
        time_embed_dim = self.in_channels * 4
        self.time_pos_embed = nn.Sequential(
            operations.Linear(
                self.in_channels, time_embed_dim, dtype=dtype, device=device
            ),
            nn.SiLU(),
            operations.Linear(
                time_embed_dim, self.in_channels, dtype=dtype, device=device
            ),
        )

        # Blends spatial and temporal branches with a (possibly learned) alpha.
        self.time_mixer = AlphaBlender(
            alpha=merge_factor, merge_strategy=merge_strategy
        )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        transformer_options={},
    ) -> torch.Tensor:
        """Run interleaved spatial and temporal attention over ``x``.

        Assumes ``x`` is ``(b*t, c, h, w)`` with ``t == timesteps`` frames per
        video (the ``[::timesteps]`` stride and ``// timesteps`` splits rely on
        this); returns a tensor of the same shape with a residual added.
        NOTE: ``transformer_options={}`` keeps the upstream interface; the dict
        is written to ("block_index") on every call.
        """
        _, _, h, w = x.shape
        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context
        if self.use_spatial_context:
            assert (
                context.ndim == 3
            ), f"n dims of spatial context should be 3 but are {context.ndim}"

            if time_context is None:
                time_context = context
            # Each video attends temporally against its first frame's context.
            time_context_first_timestep = time_context[::timesteps]
            # time_context = repeat(
            #     time_context_first_timestep, "b ... -> (b n) ...", n=h * w
            # )
            # Rewrite for onediff SVD dynamic shape
            time_context = torch._C.broadcast_dim_like(
                time_context_first_timestep[None, :], x.flatten(2, 3), dim=0, like_dim=2,
            ).flatten(0, 1)

        elif time_context is not None and not self.use_spatial_context:
            # Bug fix: the previous rewrite clobbered the repeat() result with a
            # broadcast_dim_like call referencing `time_context_first_timestep`,
            # a name bound only in the spatial-context branch above — taking
            # this branch raised NameError. Restore upstream ComfyUI behavior.
            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
            if time_context.ndim == 2:
                time_context = rearrange(time_context, "b c -> b 1 c")
                # time_context = time_context.unsqueeze(1)

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        # x = rearrange(x, "b c h w -> b (h w) c")
        # Rewrite for onediff SVD dynamic shape
        x = x.permute(0, 2, 3, 1).flatten(1, 2)
        if self.use_linear:
            x = self.proj_in(x)

        # Per-frame positional timestep embedding, broadcast over the batch.
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        # num_frames = rearrange(num_frames, "b t -> (b t)")
        # Rewrite for onediff SVD dynamic shape
        num_frames = num_frames.flatten()
        t_emb = timestep_embedding(
            num_frames,
            self.in_channels,
            repeat_only=False,
            max_period=self.max_time_embed_period,
        ).to(x.dtype)
        emb = self.time_pos_embed(t_emb)
        emb = emb[:, None, :]

        for it_, (block, mix_block) in enumerate(
            zip(self.transformer_blocks, self.time_stack)
        ):
            transformer_options["block_index"] = it_
            x = block(
                x, context=spatial_context, transformer_options=transformer_options,
            )

            x_mix = x
            x_mix = x_mix + emb

            B, S, C = x_mix.shape
            # x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
            # Rewrite for onediff SVD dynamic shape
            b = B // timesteps
            x_mix = x_mix.unflatten(0, shape=(b, -1)).permute(0, 2, 1, 3).flatten(0, 1)
            x_mix = mix_block(x_mix, context=time_context)  # TODO: transformer_options
            # x_mix = rearrange(
            #     x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
            # )
            # Rewrite for onediff SVD dynamic shape
            x_mix = x_mix.unflatten(0, shape=(b, -1)).permute(0, 2, 1, 3).flatten(0, 1)

            x = self.time_mixer(
                x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator
            )

        if self.use_linear:
            x = self.proj_out(x)
        # x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        # Rewrite for onediff SVD dynamic shape
        x = x.reshape_as(x_in.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
import oneflow.nn.functional as F
from onediff.infer_compiler.transform import proxy_class
from onediff.infer_compiler.transform import transform_mgr
from einops import rearrange
from abc import abstractmethod

# The `comfy` package as seen through onediff's transform manager, so the
# attributes below resolve to the OneFlow-backed counterparts.
onediff_comfy = transform_mgr.transform_package("comfy")

# Re-exported comfy helpers (transformed versions) used by the classes below.
ops = onediff_comfy.ops.disable_weight_init
ResBlock = onediff_comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock
checkpoint = onediff_comfy.ldm.modules.diffusionmodules.util.checkpoint


class Upsample(proxy_class(comfy.ldm.modules.diffusionmodules.openaimodel.Upsample)):
# https://github.com/comfyanonymous/ComfyUI/blob/b0aab1e4ea3dfefe09c4f07de0e5237558097e22/comfy/ldm/modules/diffusionmodules/openaimodel.py#L82
Expand Down Expand Up @@ -151,3 +157,70 @@ def forward(
return self.id_predictor(h)
else:
return self.out(h)


class VideoResBlock(
    proxy_class(comfy.ldm.modules.diffusionmodules.openaimodel.VideoResBlock)
):
    """OneFlow proxy of ComfyUI's ``VideoResBlock``.

    The einops rearranges of the upstream implementation are replaced with
    unflatten/permute/flatten sequences so that onediff can compile SVD with
    dynamic input shapes. Interface is identical to the upstream class.
    """

    def _forward(self, x, emb):
        """Spatial ResBlock body: in_layers -> timestep embedding -> out_layers,
        with a skip connection. Checkpointed by forward()."""
        if self.updown:
            # Up/downsampling variant: apply the resample op between the
            # normalization/activation part and the final conv of in_layers.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        emb_out = None
        if not self.skip_t_emb:
            emb_out = self.emb_layers(emb).type(h.dtype)
            # Right-pad with singleton dims so emb_out broadcasts over h.
            while len(emb_out.shape) < len(h.shape):
                emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: h = norm(h) * (1 + scale) + shift.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            h = out_norm(h)
            if emb_out is not None:
                scale, shift = th.chunk(emb_out, 2, dim=1)
                h *= 1 + scale
                h += shift
            h = out_rest(h)
        else:
            if emb_out is not None:
                if self.exchange_temb_dims:
                    # emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
                    # Rewrite for onediff SVD dynamic shape
                    # NOTE(review): the 5-dim permute assumes h (and thus
                    # emb_out) is 5D here — confirm this branch is only taken
                    # on the temporal (b c t h w) path.
                    emb_out = emb_out.permute(0, 2, 1, 3, 4)
                h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

    def forward(
        self,
        x: th.Tensor,
        emb: th.Tensor,
        num_video_frames: int,
        image_only_indicator=None,
    ) -> th.Tensor:
        """Spatial ResBlock pass over (b*t, c, h, w), then a temporal ResBlock
        over (b, c, t, h, w), blended by time_mixer; returns (b*t, c, h, w).
        Assumes the leading dim is divisible by num_video_frames."""
        # Rewrite for onediff SVD dynamic shape
        # x = super().forward(x, emb)
        x = checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)

        # x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
        batch_frames, _, _, _ = x.shape
        batch_size = batch_frames // num_video_frames
        # (b*t, c, h, w) -> (b, t, c, h, w) -> (b, c, t, h, w)
        x_mix = x.unflatten(0, shape=(batch_size, -1)).permute(0, 2, 1, 3, 4)
        # x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
        x = x.unflatten(0, shape=(batch_size, -1)).permute(0, 2, 1, 3, 4)

        # x = self.time_stack(
        #     x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
        # )
        # Temporal ResBlock; emb is regrouped per video: (b*t, ...) -> (b, t, ...).
        x = self.time_stack(x, emb.unflatten(0, shape=(batch_size, -1)))

        # Blend the purely spatial result (x_mix) with the temporal one (x).
        x = self.time_mixer(
            x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
        )
        # x = rearrange(x, "b c t h w -> (b t) c h w")
        x = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
        return x
Loading