
Commit cf5505a

virginiafdez (Virginia Fernandez) and KumoLiu authored
Diffusion Model Encoder has its output layer set in the forward method, which leads to problems (#8578)
Fixes #8577.

### Description

This pull request sets the `out` layer of `DiffusionModelEncoder` in the `__init__` method. This requires adding an `input_shape` parameter to `__init__` so that the input dimension of the last linear layer can be calculated. The derivation of the output spatial shape is a bit baroque, but it accommodates odd spatial dimensions at the input, which would otherwise be awkward to handle.

### Types of changes

- [x] Non-breaking change (fix or new feature that would cause existing functionality to change). **The tutorial that was failing will need to provide this input parameter, but `input_shape` now defaults to the notebook dims.**
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.

---------

Signed-off-by: Virginia Fernandez <[email protected]>
Co-authored-by: Virginia Fernandez <[email protected]>
Co-authored-by: YunLiu <[email protected]>
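As a side note on the shape arithmetic described above: each level in `channels` halves the spatial dimensions (rounding up), and the product of the resulting spatial shape with the last channel count gives the input size of the final linear layer. Below is a minimal, standalone sketch of that computation under those assumptions; the helper name `flattened_dim` is illustrative and not part of MONAI.

```python
import math
from functools import reduce


def flattened_dim(input_shape, channels):
    """Halve each spatial dim (ceil) once per level, then flatten together
    with the last channel count to size the final linear layer."""
    shape = list(input_shape)
    for _ in channels:
        shape = [math.ceil(s / 2) for s in shape]
    return reduce(lambda x, y: x * y, shape) * channels[-1]


# With the defaults visible in the diff: (64, 64) input, channels (32, 64, 64, 64)
print(flattened_dim((64, 64), (32, 64, 64, 64)))  # 4 * 4 * 64 = 1024
# Odd input sizes are handled by the ceil: (65, 63) -> (5, 4) after four halvings
print(flattened_dim((65, 63), (32, 64, 64, 64)))  # 5 * 4 * 64 = 1280
```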
1 parent 6327a86 commit cf5505a

File tree

1 file changed: +15, -5 lines changed


monai/networks/nets/diffusion_model_unet.py

Lines changed: 15 additions & 5 deletions
@@ -33,8 +33,10 @@
 
 import math
 from collections.abc import Sequence
+from functools import reduce
 from typing import Optional
 
+import numpy as np
 import torch
 from torch import nn
 
@@ -1882,6 +1884,7 @@ class DiffusionModelEncoder(nn.Module):
         spatial_dims: number of spatial dimensions.
         in_channels: number of input channels.
         out_channels: number of output channels.
+        input_shape: spatial shape of the input (without batch and channel dims).
         num_res_blocks: number of residual blocks (see _ResnetBlock) per level.
         channels: tuple of block output channels.
         attention_levels: list of levels to add attention.
@@ -1901,6 +1904,7 @@ def __init__(
         spatial_dims: int,
         in_channels: int,
         out_channels: int,
+        input_shape: Sequence[int] = (64, 64),
         num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
         channels: Sequence[int] = (32, 64, 64, 64),
         attention_levels: Sequence[bool] = (False, False, True, True),
@@ -2007,7 +2011,14 @@ def __init__(
 
             self.down_blocks.append(down_block)
 
-        self.out: Optional[nn.Module] = None
+        for _ in channels:
+            input_shape = [int(np.ceil(i_ / 2)) for i_ in input_shape]
+
+        last_dim_flattened = int(reduce(lambda x, y: x * y, input_shape) * channels[-1])
+
+        self.out: Optional[nn.Module] = nn.Sequential(
+            nn.Linear(last_dim_flattened, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
+        )
 
     def forward(
         self,
@@ -2052,10 +2063,9 @@ def forward(
         h = h.reshape(h.shape[0], -1)
 
         # 5. out
-        if self.out is None:
-            self.out = nn.Sequential(
-                nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
-            )
+        self.out = nn.Sequential(
+            nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
+        )
         output: torch.Tensor = self.out(h)
 
         return output
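For illustration, a hypothetical usage sketch of the updated constructor, based only on the signature visible in the diff above; the argument values are examples, and the exact forward signature and defaults depend on the installed MONAI version.

```python
import torch

from monai.networks.nets.diffusion_model_unet import DiffusionModelEncoder

# Assumed example configuration: 2D input, channels default (32, 64, 64, 64),
# and the new input_shape argument matching the spatial size of the data.
encoder = DiffusionModelEncoder(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    input_shape=(64, 64),  # spatial shape without batch/channel dims (new parameter)
)

x = torch.randn(2, 1, 64, 64)                    # (batch, channel, H, W)
timesteps = torch.randint(0, 1000, (2,)).long()  # one diffusion timestep per batch element
out = encoder(x, timesteps)                      # expected shape: (2, out_channels)
print(out.shape)
```

Because `self.out` is now built in `__init__` from `input_shape`, the layer exists (and is registered) before the first forward pass, so optimizers and `state_dict` loading see a consistent set of parameters.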
