Skip to content

Commit 717b6b6

Browse files
authored
Merge pull request #128 from Visual-Behavior/raft_refacto
Raft refacto
2 parents 67acff4 + 56b54e4 commit 717b6b6

File tree

5 files changed

+121
-83
lines changed

5 files changed

+121
-83
lines changed

alonet/common/abstract_classes.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""
2+
Inspired by https://stackoverflow.com/a/50381071/14647356
3+
4+
Tools to create abstract classes
5+
"""
6+
7+
8+
class AbstractAttribute:
    """Sentinel type marking a class attribute that subclasses must override."""

    pass


def abstract_attribute(obj=None):
    """Mark ``obj`` as an abstract attribute and return it.

    Parameters
    ----------
    obj : object, optional
        Value to tag. When omitted, a fresh :class:`AbstractAttribute`
        sentinel is created and tagged instead.

    Returns
    -------
    object
        ``obj`` (or the new sentinel) with ``__is_abstract_attribute__``
        set to ``True`` so it can be detected at instantiation time.
    """
    marked = AbstractAttribute() if obj is None else obj
    marked.__is_abstract_attribute__ = True
    return marked
17+
18+
19+
def check_abstract_attribute_instanciation(cls):
    """Raise if ``cls`` still carries un-overridden abstract attributes.

    Scans every attribute reachable on ``cls`` and collects those tagged
    with ``__is_abstract_attribute__`` (see ``abstract_attribute``).

    Parameters
    ----------
    cls : type
        The class being instantiated (typically received in ``__new__``).

    Raises
    ------
    NotImplementedError
        If at least one abstract attribute was not overridden in ``cls``.
    """
    abstract_attributes = {
        name for name in dir(cls) if getattr(getattr(cls, name), "__is_abstract_attribute__", False)
    }
    if abstract_attributes:
        # cls is already a class: use cls.__name__ directly.
        # (type(cls).__name__ would name the metaclass, e.g. "type".)
        raise NotImplementedError(
            f"Can't instantiate abstract class {cls.__name__}. "
            # sorted() keeps the message deterministic across runs
            f"The following abstract attributes should be instantiated: {', '.join(sorted(abstract_attributes))}"
        )
28+
29+
30+
def super_new(abstract_cls, cls, *args, **kwargs):
    """Invoke the ``__new__`` of the class above ``abstract_cls`` in ``cls``'s MRO.

    ``object.__new__`` rejects extra positional/keyword arguments, so it is
    called with ``cls`` alone; any other ``__new__`` receives the full
    argument list.

    Parameters
    ----------
    abstract_cls : type
        The abstract base whose parent ``__new__`` should be used.
    cls : type
        The concrete class being instantiated.

    Returns
    -------
    object
        The newly allocated instance of ``cls``.
    """
    parent_new = super(abstract_cls, cls).__new__
    if parent_new is object.__new__:
        return parent_new(cls)
    return parent_new(cls, *args, **kwargs)

alonet/raft/criterion.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,33 @@ def __init__(self):
1111

1212
# loss from RAFT implementation
1313
@staticmethod
14-
def sequence_loss(flow_preds, flow_gt, valid=None, gamma=0.8, max_flow=400, compute_per_iter=False):
14+
def sequence_loss(m_outputs, flow_gt, valid=None, gamma=0.8, max_flow=400, compute_per_iter=False):
1515
"""Loss function defined over sequence of flow predictions"""
16-
n_predictions = len(flow_preds)
16+
n_predictions = len(m_outputs)
1717
flow_loss = 0.0
1818

1919
# exclude invalid pixels and extremely large displacements
20-
mag = torch.sum(flow_gt ** 2, dim=1).sqrt()
20+
mag = torch.sum(flow_gt ** 2, dim=1, keepdim=True).sqrt()
2121
if valid is None:
2222
valid = torch.ones_like(mag, dtype=torch.bool)
2323
else:
2424
valid = (valid >= 0.5) & (mag < max_flow)
2525
for i in range(n_predictions):
26+
m_dict = m_outputs[i]
2627
i_weight = gamma ** (n_predictions - i - 1)
27-
i_loss = (flow_preds[i] - flow_gt).abs()
28+
i_loss = (m_dict["up_flow"] - flow_gt).abs()
2829
flow_loss += i_weight * (valid[:, None] * i_loss).mean()
2930

3031
if compute_per_iter:
3132
epe_per_iter = []
32-
for flow_p in flow_preds:
33-
epe = torch.sum((flow_p - flow_gt) ** 2, dim=1).sqrt()
33+
for i in range(n_predictions):
34+
m_dict = m_outputs[i]
35+
epe = torch.sum((m_dict["up_flow"] - flow_gt) ** 2, dim=1).sqrt()
3436
epe = epe.view(-1)[valid.view(-1)]
3537
epe_per_iter.append(epe)
3638
else:
3739
epe_per_iter = None
38-
39-
epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
40+
epe = torch.sum((m_outputs[-1]["up_flow"] - flow_gt) ** 2, dim=1).sqrt()
4041
epe = epe.view(-1)[valid.view(-1)]
4142

4243
metrics = {
@@ -50,7 +51,6 @@ def sequence_loss(flow_preds, flow_gt, valid=None, gamma=0.8, max_flow=400, comp
5051

5152
def forward(self, m_outputs, frame1, use_valid=True, compute_per_iter=False):
5253
assert isinstance(frame1, aloscene.Frame)
53-
flow_preds = m_outputs
5454
flow_gt = [f.batch() for f in frame1.flow["flow_forward"]]
5555
flow_gt = torch.cat(flow_gt, dim=0)
5656
# occlusion mask -- not used in raft original repo
@@ -65,6 +65,6 @@ def forward(self, m_outputs, frame1, use_valid=True, compute_per_iter=False):
6565
else:
6666
valid = None
6767
flow_loss, metrics, epe_per_iter = RAFTCriterion.sequence_loss(
68-
flow_preds, flow_gt, valid, compute_per_iter=compute_per_iter
68+
m_outputs, flow_gt, valid, compute_per_iter=compute_per_iter
6969
)
7070
return flow_loss, metrics, epe_per_iter

alonet/raft/raft.py

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from alonet.raft.corr import CorrBlock, AlternateCorrBlock
99
from alonet.raft.update import BasicUpdateBlock
1010
from alonet.raft.extractor import BasicEncoder
11+
from alonet.common.abstract_classes import abstract_attribute, check_abstract_attribute_instanciation, super_new
1112
from alonet.raft.utils.utils import coords_grid, upflow8
1213
from aloscene import Flow, Frame
1314

@@ -32,25 +33,31 @@ class RAFTBase(nn.Module):
3233
"""
3334

3435
# should be overridden in subclasses
35-
hidden_dim = None
36-
context_dim = None
37-
corr_levels = None
38-
corr_radius = None
36+
hidden_dim = abstract_attribute()
37+
context_dim = abstract_attribute()
38+
corr_levels = abstract_attribute()
39+
corr_radius = abstract_attribute()
40+
out_plane = abstract_attribute()
41+
42+
# checks that all abstract attributes are instantiated in the child class
43+
def __new__(cls, *args, **kwargs):
44+
check_abstract_attribute_instanciation(cls)
45+
return super_new(RAFTBase, cls, *args, **kwargs)
3946

4047
def __init__(
4148
self,
4249
fnet,
4350
cnet,
4451
update_block,
45-
alternate_corr=False,
4652
weights: str = None,
53+
corr_block=CorrBlock,
4754
device: torch.device = torch.device("cpu"),
4855
):
4956
super().__init__()
5057
self.fnet = fnet
5158
self.cnet = cnet
5259
self.update_block = update_block
53-
self.alternate_corr = alternate_corr
60+
self.corr_block = corr_block
5461

5562
if weights is not None:
5663
weights_from_original_repo = ["raft-things", "raft-chairs", "raft-small", "raft-kitti", "raft-sintel"]
@@ -83,7 +90,7 @@ def build_update_block(self, update_cls=BasicUpdateBlock):
8390
"""
8491
Build RAFT Update Block
8592
"""
86-
return update_cls(self.corr_levels, self.corr_radius, hidden_dim=self.hdim)
93+
return update_cls(self.corr_levels, self.corr_radius, hidden_dim=self.hdim, out_planes=self.out_plane)
8794

8895
def freeze_bn(self):
8996
for m in self.modules():
@@ -101,16 +108,28 @@ def initialize_flow(self, img):
101108

102109
def upsample_flow(self, flow, mask):
103110
"""Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
104-
N, _, H, W = flow.shape
105-
mask = mask.view(N, 1, 9, 8, 8, H, W)
106-
mask = torch.softmax(mask, dim=2)
111+
if mask is None:
112+
return upflow8(flow)
113+
else:
114+
N, _, H, W = flow.shape
115+
mask = mask.view(N, 1, 9, 8, 8, H, W)
116+
mask = torch.softmax(mask, dim=2)
117+
118+
up_flow = F.unfold(8 * flow, [3, 3], padding=1)
119+
up_flow = up_flow.view(N, self.out_plane, 9, 1, 1, H, W)
107120

108-
up_flow = F.unfold(8 * flow, [3, 3], padding=1)
109-
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
121+
up_flow = torch.sum(mask * up_flow, dim=2)
122+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
123+
return up_flow.reshape(N, self.out_plane, 8 * H, 8 * W)
110124

111-
up_flow = torch.sum(mask * up_flow, dim=2)
112-
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
113-
return up_flow.reshape(N, 2, 8 * H, 8 * W)
125+
def forward_heads(self, m_outputs, only_last=False):
126+
if not only_last:
127+
for out_dict in m_outputs:
128+
out_dict["up_flow"] = self.upsample_flow(out_dict["flow"], out_dict["up_mask"])
129+
130+
else:
131+
m_outputs[-1]["up_flow"] = self.upsample_flow(m_outputs[-1]["flow"], m_outputs[-1]["up_mask"])
132+
return m_outputs
114133

115134
def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_last=False):
116135
"""Estimate optical flow between pair of frames
@@ -140,26 +159,17 @@ def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_l
140159
frame1 = frame1.as_tensor()
141160
frame2 = frame2.as_tensor()
142161

143-
# frame1 = frame1.contiguous()
144-
# frame2 = frame2.contiguous()
145-
146-
hdim = self.hidden_dim
147-
cdim = self.context_dim
148-
149162
# run the feature network
150-
151163
fmap1, fmap2 = self.fnet([frame1, frame2])
152164

153165
fmap1 = fmap1.float()
154166
fmap2 = fmap2.float()
155-
if self.alternate_corr:
156-
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.corr_radius)
157-
else:
158-
corr_fn = CorrBlock(fmap1, fmap2, radius=self.corr_radius)
167+
168+
corr_fn = self.corr_block(fmap1, fmap2, radius=self.corr_radius)
159169

160170
# run the context network
161171
cnet = self.cnet(frame1)
162-
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
172+
net, inp = torch.split(cnet, [self.hdim, self.cdim], dim=1)
163173
net = torch.tanh(net)
164174
inp = torch.relu(inp)
165175

@@ -168,42 +178,31 @@ def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_l
168178
if flow_init is not None:
169179
coords1 = coords1 + flow_init
170180

171-
flow_predictions = []
181+
m_outputs = list()
182+
172183
for itr in range(iters):
173184
coords1 = coords1.detach()
174185
corr = corr_fn(coords1) # index correlation volume
175-
176186
flow = coords1 - coords0
177187
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
178-
179-
# F(t+1) = F(t) + \Delta(t)
180188
coords1 = coords1 + delta_flow
189+
m_outputs.append(
190+
{"flow": coords1 - coords0, "hidden_state": net, "up_mask": up_mask, "delta_flow": delta_flow}
191+
)
181192

182-
# upsample predictions
183-
if up_mask is None:
184-
flow_up = upflow8(coords1 - coords0)
185-
else:
186-
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
187-
188-
flow_predictions.append(flow_up)
189-
190-
if only_last:
191-
flow_low = coords1 - coords0
192-
return flow_low, flow_up
193-
else:
194-
return flow_predictions
193+
return self.forward_heads(m_outputs, only_last=only_last)
195194

196195
@torch.no_grad()
197-
def inference(self, forward_out, only_last=False):
196+
def inference(self, m_outputs, only_last=False):
197+
def generate_frame(out_dict):
198+
# flow_low = Flow(out_dict["flow"], names=("B", "C", "H", "W"))
199+
flow_up = Flow(out_dict["up_flow"], names=("B", "C", "H", "W"))
200+
return flow_up
201+
198202
if only_last:
199-
flow_low, flow_up = forward_out
200-
flow_low = Flow(flow_low, names=("B", "C", "H", "W"))
201-
flow_up = Flow(flow_up, names=("B", "C", "H", "W"))
202-
return flow_low, flow_up
203-
elif isinstance(forward_out, list):
204-
return [Flow(flow, names=("B", "C", "H", "W")) for flow in forward_out]
203+
return generate_frame(m_outputs[-1])
205204
else:
206-
return Flow(forward_out, names=("B", "C", "H", "W"))
205+
return [generate_frame(out_dict) for out_dict in m_outputs]
207206

208207

209208
class RAFT(RAFTBase):
@@ -234,6 +233,7 @@ class RAFT(RAFTBase):
234233
context_dim = 128
235234
corr_levels = 4
236235
corr_radius = 4
236+
out_plane = 2
237237

238238
def __init__(self, dropout=0, **kwargs):
239239
self.dropout = dropout
@@ -272,8 +272,10 @@ def __init__(self, dropout=0, **kwargs):
272272

273273
# inference
274274
with torch.no_grad():
275-
flow = raft.forward(frame1, frame2)[-1] # keep only last stage flow estimation
275+
m_outputs = raft.forward(frame1, frame2) # outputs for every refinement iteration
276+
output = raft.inference(m_outputs)
277+
278+
flow = output[-1]
276279
flow = padder.unpad(flow) # unpad to original image resolution
277-
flow = raft.inference(flow)
278280
flow = flow.detach().cpu()
279281
flow.get_view().render()

alonet/raft/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ def validation_step(self, frames, batch_idx, dataloader_idx=None):
6767
def build_criterion(self):
6868
return RAFTCriterion()
6969

70-
def build_model(self, alternate_corr=False, weights=None, device="cpu", dropout=0):
71-
return alonet.raft.RAFT(alternate_corr=alternate_corr, weights=weights, device=device, dropout=dropout)
70+
def build_model(self, weights=None, device="cpu", dropout=0):
71+
return alonet.raft.RAFT(weights=weights, device=device, dropout=dropout)
7272

7373
def configure_optimizers(self, lr=4e-4, weight_decay=1e-4, epsilon=1e-8, numsteps=100000):
7474
params = self.model.parameters()

alonet/raft/update.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55

66
class FlowHead(nn.Module):
7-
def __init__(self, input_dim=128, hidden_dim=256):
7+
def __init__(self, input_dim=128, hidden_dim=256, out_planes=2):
88
super(FlowHead, self).__init__()
99
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10-
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
10+
self.conv2 = nn.Conv2d(hidden_dim, out_planes, 3, padding=1)
1111
self.relu = nn.ReLU(inplace=True)
1212

1313
def forward(self, x):
@@ -27,7 +27,7 @@ def forward(self, h, x):
2727
z = torch.sigmoid(self.convz(hx))
2828
r = torch.sigmoid(self.convr(hx))
2929
q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
30-
30+
3131
h = (1 - z) * h + z * q
3232
return h
3333

@@ -62,11 +62,12 @@ def forward(self, h, x):
6262

6363

6464
class SmallMotionEncoder(nn.Module):
65-
def __init__(self, corr_levels, corr_radius):
65+
def __init__(self, corr_levels, corr_radius, out_planes=2):
6666
super(SmallMotionEncoder, self).__init__()
67-
cor_planes = corr_levels * (2 * corr_radius + 1) ** 2
67+
cor_planes = corr_levels * (2 * corr_radius + 1) ** out_planes
68+
6869
self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
69-
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
70+
self.convf1 = nn.Conv2d(out_planes, 64, 7, padding=3)
7071
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
7172
self.conv = nn.Conv2d(128, 80, 3, padding=1)
7273

@@ -80,14 +81,14 @@ def forward(self, flow, corr):
8081

8182

8283
class BasicMotionEncoder(nn.Module):
83-
def __init__(self, corr_levels, corr_radius):
84+
def __init__(self, corr_levels, corr_radius, out_planes=2):
8485
super(BasicMotionEncoder, self).__init__()
85-
cor_planes = corr_levels * (2 * corr_radius + 1) ** 2
86+
cor_planes = corr_levels * (2 * corr_radius + 1) ** out_planes
8687
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
8788
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
88-
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
89+
self.convf1 = nn.Conv2d(out_planes, 128, 7, padding=3)
8990
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
90-
self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
91+
self.conv = nn.Conv2d(64 + 192, 128 - out_planes, 3, padding=1)
9192

9293
def forward(self, flow, corr):
9394
cor = F.relu(self.convc1(corr))
@@ -101,11 +102,11 @@ def forward(self, flow, corr):
101102

102103

103104
class SmallUpdateBlock(nn.Module):
104-
def __init__(self, corr_levels, corr_radius, hidden_dim=96):
105+
def __init__(self, corr_levels, corr_radius, hidden_dim=96, out_planes=2):
105106
super(SmallUpdateBlock, self).__init__()
106-
self.encoder = SmallMotionEncoder(corr_levels, corr_radius)
107-
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
108-
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
107+
self.encoder = SmallMotionEncoder(corr_levels, corr_radius, out_planes=out_planes)
108+
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=hidden_dim + 49)
109+
self.flow_head = FlowHead(hidden_dim, hidden_dim=128, out_planes=out_planes)
109110

110111
def forward(self, net, inp, corr, flow):
111112
motion_features = self.encoder(flow, corr)
@@ -117,11 +118,11 @@ def forward(self, net, inp, corr, flow):
117118

118119

119120
class BasicUpdateBlock(nn.Module):
120-
def __init__(self, corr_levels, corr_radius, hidden_dim=128, input_dim=128):
121+
def __init__(self, corr_levels, corr_radius, hidden_dim=128, input_dim=128, out_planes=2):
121122
super(BasicUpdateBlock, self).__init__()
122-
self.encoder = BasicMotionEncoder(corr_levels, corr_radius)
123+
self.encoder = BasicMotionEncoder(corr_levels, corr_radius, out_planes=out_planes)
123124
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
124-
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
125+
self.flow_head = FlowHead(hidden_dim, hidden_dim=256, out_planes=out_planes)
125126

126127
self.mask = nn.Sequential(
127128
nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 64 * 9, 1, padding=0)

0 commit comments

Comments
 (0)