Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions src/infer_compiler_registry/register_diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@
from diffusers.models.resnet import Upsample2D
if diffusers_version >= version.parse("0.24.00"):
from diffusers.models.resnet import SpatioTemporalResBlock
from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel
from diffusers.models.attention import TemporalBasicTransformerBlock

if diffusers_version >= version.parse("0.26.00"):
from diffusers.models.unets.unet_spatio_temporal_condition import (
UNetSpatioTemporalConditionModel,
)
from diffusers.models.transformers.transformer_temporal import TransformerSpatioTemporalModel
else:
from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel
from diffusers.models.unet_spatio_temporal_condition import (
UNetSpatioTemporalConditionModel,
)
Expand All @@ -46,6 +47,18 @@
)
else:
from diffusers.models.autoencoder_kl_temporal_decoder import TemporalDecoder

from .spatio_temporal_oflow import SpatioTemporalResBlock as SpatioTemporalResBlockOflow
from .spatio_temporal_oflow import TemporalDecoder as TemporalDecoderOflow
from .spatio_temporal_oflow import (
TransformerSpatioTemporalModel as TransformerSpatioTemporalModelOflow,
)
from .spatio_temporal_oflow import (
TemporalBasicTransformerBlock as TemporalBasicTransformerBlockOflow,
)
from .spatio_temporal_oflow import (
UNetSpatioTemporalConditionModel as UNetSpatioTemporalConditionModelOflow,
)

from .attention_processor_oflow import Attention as AttentionOflow
from .attention_processor_oflow import AttnProcessor as AttnProcessorOflow
Expand All @@ -56,17 +69,6 @@
from .unet_2d_blocks_oflow import UpBlock2D as UpBlock2DOflow
from .resnet_oflow import Upsample2D as Upsample2DOflow
from .transformer_2d_oflow import Transformer2DModel as Transformer2DModelOflow
from .spatio_temporal_oflow import SpatioTemporalResBlock as SpatioTemporalResBlockOflow
from .spatio_temporal_oflow import TemporalDecoder as TemporalDecoderOflow
from .spatio_temporal_oflow import (
TransformerSpatioTemporalModel as TransformerSpatioTemporalModelOflow,
)
from .spatio_temporal_oflow import (
TemporalBasicTransformerBlock as TemporalBasicTransformerBlockOflow,
)
from .spatio_temporal_oflow import (
UNetSpatioTemporalConditionModel as UNetSpatioTemporalConditionModelOflow,
)

# For CI
if diffusers_version >= version.parse("0.24.00"):
Expand Down
Loading