8
8
from alonet .raft .corr import CorrBlock , AlternateCorrBlock
9
9
from alonet .raft .update import BasicUpdateBlock
10
10
from alonet .raft .extractor import BasicEncoder
11
+ from alonet .common .abstract_classes import abstract_attribute , check_abstract_attribute_instanciation , super_new
11
12
from alonet .raft .utils .utils import coords_grid , upflow8
12
13
from aloscene import Flow , Frame
13
14
@@ -32,25 +33,31 @@ class RAFTBase(nn.Module):
32
33
"""
33
34
34
35
# should be overridden in subclasses
35
- hidden_dim = None
36
- context_dim = None
37
- corr_levels = None
38
- corr_radius = None
36
+ hidden_dim = abstract_attribute ()
37
+ context_dim = abstract_attribute ()
38
+ corr_levels = abstract_attribute ()
39
+ corr_radius = abstract_attribute ()
40
+ out_plane = abstract_attribute ()
41
+
42
+ # checks that all abstract attributes are instantiated in the child class
43
def __new__(cls, *args, **kwargs):
    """Validate the class-level abstract attributes, then create the instance.

    ``check_abstract_attribute_instanciation`` is called on ``cls`` before any
    instance exists; presumably it raises when a subclass left one of the
    ``abstract_attribute()`` placeholders in place (confirm against
    ``alonet.common.abstract_classes``). Instance creation itself is
    delegated to ``super_new``.
    """
    check_abstract_attribute_instanciation(cls)
    instance = super_new(RAFTBase, cls, *args, **kwargs)
    return instance
39
46
40
47
def __init__ (
41
48
self ,
42
49
fnet ,
43
50
cnet ,
44
51
update_block ,
45
- alternate_corr = False ,
46
52
weights : str = None ,
53
+ corr_block = CorrBlock ,
47
54
device : torch .device = torch .device ("cpu" ),
48
55
):
49
56
super ().__init__ ()
50
57
self .fnet = fnet
51
58
self .cnet = cnet
52
59
self .update_block = update_block
53
- self .alternate_corr = alternate_corr
60
+ self .corr_block = corr_block
54
61
55
62
if weights is not None :
56
63
weights_from_original_repo = ["raft-things" , "raft-chairs" , "raft-small" , "raft-kitti" , "raft-sintel" ]
@@ -83,7 +90,7 @@ def build_update_block(self, update_cls=BasicUpdateBlock):
83
90
"""
84
91
Build RAFT Update Block
85
92
"""
86
- return update_cls (self .corr_levels , self .corr_radius , hidden_dim = self .hdim )
93
+ return update_cls (self .corr_levels , self .corr_radius , hidden_dim = self .hdim , out_planes = self . out_plane )
87
94
88
95
def freeze_bn (self ):
89
96
for m in self .modules ():
@@ -101,16 +108,28 @@ def initialize_flow(self, img):
101
108
102
109
def upsample_flow(self, flow, mask):
    """Upsample a flow field [N, C, H, W] -> [N, C, 8*H, 8*W].

    ``C`` is ``self.out_plane``. When ``mask`` is ``None`` the work is
    delegated to ``upflow8``. Otherwise every output pixel is a convex
    combination of the 3x3 coarse neighbourhood, weighted by a softmax over
    the 9 mask entries for that pixel.

    Args:
        flow: coarse flow tensor of shape [N, C, H, W].
        mask: ``None``, or a tensor viewable as [N, 1, 9, 8, 8, H, W].

    Returns:
        The upsampled flow tensor.
    """
    if mask is None:
        return upflow8(flow)

    batch, _, height, width = flow.shape

    # Normalise the 9 neighbour weights of each of the 8x8 sub-pixels.
    weights = torch.softmax(mask.view(batch, 1, 9, 8, 8, height, width), dim=2)

    # Gather the 3x3 neighbourhood of every coarse pixel; the flow values are
    # scaled by 8 to match the 8x larger output grid.
    patches = F.unfold(8 * flow, [3, 3], padding=1)
    patches = patches.view(batch, self.out_plane, 9, 1, 1, height, width)

    # Weighted sum over the 9 neighbours, then interleave the 8x8 sub-pixel
    # axes with the spatial axes: [N, C, H, 8, W, 8] -> [N, C, 8*H, 8*W].
    blended = torch.sum(weights * patches, dim=2).permute(0, 1, 4, 2, 5, 3)
    return blended.reshape(batch, self.out_plane, 8 * height, 8 * width)
125
def forward_heads(self, m_outputs, only_last=False):
    """Attach the upsampled flow to the per-iteration output dicts.

    Each selected dict gets an ``"up_flow"`` entry computed by
    ``self.upsample_flow`` from its ``"flow"`` and ``"up_mask"`` entries.
    The dicts are modified in place.

    Args:
        m_outputs: list of per-iteration output dicts.
        only_last: if True, only the last dict is upsampled; otherwise all
            of them are.

    Returns:
        The same ``m_outputs`` list that was passed in.
    """
    # Slicing instead of indexing [-1] keeps an empty list a no-op in both
    # modes, rather than raising IndexError only when only_last is True.
    selected = m_outputs[-1:] if only_last else m_outputs
    for out_dict in selected:
        out_dict["up_flow"] = self.upsample_flow(out_dict["flow"], out_dict["up_mask"])
    return m_outputs
114
133
115
134
def forward (self , frame1 : Frame , frame2 : Frame , iters = 12 , flow_init = None , only_last = False ):
116
135
"""Estimate optical flow between pair of frames
@@ -140,26 +159,17 @@ def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_l
140
159
frame1 = frame1 .as_tensor ()
141
160
frame2 = frame2 .as_tensor ()
142
161
143
- # frame1 = frame1.contiguous()
144
- # frame2 = frame2.contiguous()
145
-
146
- hdim = self .hidden_dim
147
- cdim = self .context_dim
148
-
149
162
# run the feature network
150
-
151
163
fmap1 , fmap2 = self .fnet ([frame1 , frame2 ])
152
164
153
165
fmap1 = fmap1 .float ()
154
166
fmap2 = fmap2 .float ()
155
- if self .alternate_corr :
156
- corr_fn = AlternateCorrBlock (fmap1 , fmap2 , radius = self .corr_radius )
157
- else :
158
- corr_fn = CorrBlock (fmap1 , fmap2 , radius = self .corr_radius )
167
+
168
+ corr_fn = self .corr_block (fmap1 , fmap2 , radius = self .corr_radius )
159
169
160
170
# run the context network
161
171
cnet = self .cnet (frame1 )
162
- net , inp = torch .split (cnet , [hdim , cdim ], dim = 1 )
172
+ net , inp = torch .split (cnet , [self . hdim , self . cdim ], dim = 1 )
163
173
net = torch .tanh (net )
164
174
inp = torch .relu (inp )
165
175
@@ -168,42 +178,31 @@ def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_l
168
178
if flow_init is not None :
169
179
coords1 = coords1 + flow_init
170
180
171
- flow_predictions = []
181
+ m_outputs = list ()
182
+
172
183
for itr in range (iters ):
173
184
coords1 = coords1 .detach ()
174
185
corr = corr_fn (coords1 ) # index correlation volume
175
-
176
186
flow = coords1 - coords0
177
187
net , up_mask , delta_flow = self .update_block (net , inp , corr , flow )
178
-
179
- # F(t+1) = F(t) + \Delta(t)
180
188
coords1 = coords1 + delta_flow
189
+ m_outputs .append (
190
+ {"flow" : coords1 - coords0 , "hidden_state" : net , "up_mask" : up_mask , "delta_flow" : delta_flow }
191
+ )
181
192
182
- # upsample predictions
183
- if up_mask is None :
184
- flow_up = upflow8 (coords1 - coords0 )
185
- else :
186
- flow_up = self .upsample_flow (coords1 - coords0 , up_mask )
187
-
188
- flow_predictions .append (flow_up )
189
-
190
- if only_last :
191
- flow_low = coords1 - coords0
192
- return flow_low , flow_up
193
- else :
194
- return flow_predictions
193
+ return self .forward_heads (m_outputs , only_last = only_last )
195
194
196
195
@torch.no_grad()
def inference(self, m_outputs, only_last=False):
    """Wrap the upsampled flow of the model outputs into ``Flow`` objects.

    Args:
        m_outputs: list of per-iteration output dicts; the ``"up_flow"``
            entry of each dict that is read must be present.
        only_last: if True, only the last dict is converted.

    Returns:
        A single ``Flow`` when ``only_last`` is True, otherwise a list with
        one ``Flow`` per entry of ``m_outputs``.
    """
    dim_names = ("B", "C", "H", "W")
    if only_last:
        return Flow(m_outputs[-1]["up_flow"], names=dim_names)
    return [Flow(stage["up_flow"], names=dim_names) for stage in m_outputs]
207
206
208
207
209
208
class RAFT (RAFTBase ):
@@ -234,6 +233,7 @@ class RAFT(RAFTBase):
234
233
context_dim = 128
235
234
corr_levels = 4
236
235
corr_radius = 4
236
+ out_plane = 2
237
237
238
238
def __init__ (self , dropout = 0 , ** kwargs ):
239
239
self .dropout = dropout
@@ -272,8 +272,10 @@ def __init__(self, dropout=0, **kwargs):
272
272
273
273
# inference
274
274
with torch .no_grad ():
275
- flow = raft .forward (frame1 , frame2 )[- 1 ] # keep only last stage flow estimation
275
+ m_outputs = raft .forward (frame1 , frame2 ) # keep only last stage flow estimation
276
+ output = raft .inference (m_outputs )
277
+
278
+ flow = output [- 1 ]
276
279
flow = padder .unpad (flow ) # unpad to original image resolution
277
- flow = raft .inference (flow )
278
280
flow = flow .detach ().cpu ()
279
281
flow .get_view ().render ()
0 commit comments