From 78719d2822581f1903cd26b0096b70629207f256 Mon Sep 17 00:00:00 2001 From: Remi Agier Date: Sun, 3 Oct 2021 12:41:21 +0200 Subject: [PATCH 1/5] add out_plane --- alonet/raft/update.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/alonet/raft/update.py b/alonet/raft/update.py index 062f4437..1fdf1e14 100644 --- a/alonet/raft/update.py +++ b/alonet/raft/update.py @@ -4,10 +4,10 @@ class FlowHead(nn.Module): - def __init__(self, input_dim=128, hidden_dim=256): + def __init__(self, input_dim=128, hidden_dim=256, out_planes=2): super(FlowHead, self).__init__() self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) - self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, out_planes, 3, padding=1) self.relu = nn.ReLU(inplace=True) def forward(self, x): @@ -80,14 +80,14 @@ def forward(self, flow, corr): class BasicMotionEncoder(nn.Module): - def __init__(self, corr_levels, corr_radius): + def __init__(self, corr_levels, corr_radius, out_planes=2): super(BasicMotionEncoder, self).__init__() - cor_planes = corr_levels * (2 * corr_radius + 1) ** 2 + cor_planes = corr_levels * (2 * corr_radius + 1) ** out_planes self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) self.convc2 = nn.Conv2d(256, 192, 3, padding=1) - self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf1 = nn.Conv2d(out_planes, 128, 7, padding=3) self.convf2 = nn.Conv2d(128, 64, 3, padding=1) - self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - out_planes, 3, padding=1) def forward(self, flow, corr): cor = F.relu(self.convc1(corr)) @@ -117,11 +117,11 @@ def forward(self, net, inp, corr, flow): class BasicUpdateBlock(nn.Module): - def __init__(self, corr_levels, corr_radius, hidden_dim=128, input_dim=128): + def __init__(self, corr_levels, corr_radius, hidden_dim=128, input_dim=128, out_planes=2): super(BasicUpdateBlock, self).__init__() - self.encoder = BasicMotionEncoder(corr_levels, corr_radius) + self.encoder = BasicMotionEncoder(corr_levels, corr_radius, out_planes=out_planes) self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) - self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256, out_planes=out_planes) self.mask = nn.Sequential( nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 64 * 9, 1, padding=0) From 611676961c6d8edb827836baad4223f06696d116 Mon Sep 17 00:00:00 2001 From: Remi Agier Date: Sun, 3 Oct 2021 17:18:45 +0200 Subject: [PATCH 2/5] wip --- alonet/raft/criterion.py | 17 ++++--- alonet/raft/raft.py | 106 +++++++++++++++++++-------------------- alonet/raft/train.py | 4 +- alonet/raft/update.py | 17 ++++--- 4 files changed, 73 insertions(+), 71 deletions(-) diff --git a/alonet/raft/criterion.py b/alonet/raft/criterion.py index 547d7d77..3e9369f1 100644 --- a/alonet/raft/criterion.py +++ b/alonet/raft/criterion.py @@ -11,9 +11,9 @@ def __init__(self): # loss from RAFT implementation @staticmethod - def sequence_loss(flow_preds, flow_gt, valid=None, gamma=0.8, max_flow=400, compute_per_iter=False): + def sequence_loss(m_outputs, flow_gt, valid=None, gamma=0.8, max_flow=400, compute_per_iter=False): """Loss function defined over sequence of flow predictions""" - n_predictions = len(flow_preds) + n_predictions = len(m_outputs) flow_loss = 0.0 # exlude invalid pixels and extremely large diplacements @@ -23,20 +23,22 @@ def sequence_loss(flow_preds, flow_gt, valid=None, gamma=0.8, max_flow=400, comp else: valid = (valid >= 0.5) & (mag < max_flow) for i in range(n_predictions): + m_dict = m_outputs[i] i_weight = gamma ** (n_predictions - i - 1) - i_loss = (flow_preds[i] - flow_gt).abs() + i_loss = (m_dict["up_flow"] - flow_gt).abs() flow_loss += i_weight * (valid[:, None] * i_loss).mean() if compute_per_iter: epe_per_iter = [] - for flow_p in flow_preds: - epe = torch.sum((flow_p - flow_gt) ** 2, dim=1).sqrt() + for i in range(n_predictions): + m_dict = m_outputs[i] + epe = torch.sum((m_dict["up_flow"] - flow_gt) ** 2, dim=1).sqrt() epe = epe.view(-1)[valid.view(-1)] epe_per_iter.append(epe) else: epe_per_iter = None - epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt() + epe = torch.sum((m_outputs[-1]["up_flow"] - flow_gt) ** 2, dim=1).sqrt() epe = epe.view(-1)[valid.view(-1)] metrics = { @@ -50,7 +52,6 @@ def sequence_loss(flow_preds, flow_gt, valid=None, gamma=0.8, max_flow=400, comp def forward(self, m_outputs, frame1, use_valid=True, compute_per_iter=False): assert isinstance(frame1, aloscene.Frame) - flow_preds = m_outputs flow_gt = [f.batch() for f in frame1.flow["flow_forward"]] flow_gt = torch.cat(flow_gt, dim=0) # occlusion mask -- not used in raft original repo @@ -65,6 +66,6 @@ def forward(self, m_outputs, frame1, use_valid=True, compute_per_iter=False): else: valid = None flow_loss, metrics, epe_per_iter = RAFTCriterion.sequence_loss( - flow_preds, flow_gt, valid, compute_per_iter=compute_per_iter + m_outputs, flow_gt, valid, compute_per_iter=compute_per_iter ) return flow_loss, metrics, epe_per_iter diff --git a/alonet/raft/raft.py b/alonet/raft/raft.py index 05653492..d127065d 100644 --- a/alonet/raft/raft.py +++ b/alonet/raft/raft.py @@ -37,6 +37,8 @@ class RAFTBase(nn.Module): corr_levels = None corr_radius = None + out_plane = None + def __init__( self, fnet, @@ -44,13 +46,14 @@ def __init__( update_block, alternate_corr=False, weights: str = None, - device: torch.device = torch.device("cpu"), + corr_block = CorrBlock, + device: torch.device = torch.device("cpu") ): super().__init__() self.fnet = fnet self.cnet = cnet self.update_block = update_block - self.alternate_corr = alternate_corr + self.corr_block = corr_block if weights is not None: weights_from_original_repo = ["raft-things", "raft-chairs", "raft-small", "raft-kitti", "raft-sintel"] @@ -83,7 +86,7 @@ def build_update_block(self, update_cls=BasicUpdateBlock): """ Build RAFT Update Block """ - return update_cls(self.corr_levels, self.corr_radius, hidden_dim=self.hdim) + return update_cls(self.corr_levels, self.corr_radius, hidden_dim=self.hdim, out_planes=self.out_plane) def freeze_bn(self): for m in self.modules(): @@ -101,16 +104,28 @@ def initialize_flow(self, img): def upsample_flow(self, flow, mask): """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" - N, _, H, W = flow.shape - mask = mask.view(N, 1, 9, 8, 8, H, W) - mask = torch.softmax(mask, dim=2) + if mask is None: + return upflow8(flow) + else: + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, out_plane, 9, 1, 1, H, W) - up_flow = F.unfold(8 * flow, [3, 3], padding=1) - up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, out_plane, 8 * H, 8 * W) - up_flow = torch.sum(mask * up_flow, dim=2) - up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) - return up_flow.reshape(N, 2, 8 * H, 8 * W) + def forward_heads(self, m_outputs, only_last=False): + if not only_last: + for out_dict in m_outputs: + out_dict["up_flow"] = self.upsample_flow(out_dict["flow"], out_dict["up_mask"]) + + else: + m_outputs[-1]["up_flow"] = self.upsample_flow(m_outputs[-1]["flow"], m_outputs[-1]["up_mask"]) + return m_outputs def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_last=False): """Estimate optical flow between pair of frames @@ -140,26 +155,17 @@ def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_l frame1 = frame1.as_tensor() frame2 = frame2.as_tensor() - # frame1 = frame1.contiguous() - # frame2 = frame2.contiguous() - - hdim = self.hidden_dim - cdim = self.context_dim - # run the feature network - fmap1, fmap2 = self.fnet([frame1, frame2]) fmap1 = fmap1.float() fmap2 = fmap2.float() - if self.alternate_corr: - corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.corr_radius) - else: - corr_fn = CorrBlock(fmap1, fmap2, radius=self.corr_radius) + + corr_fn = self.corr_block(fmap1, fmap2, radius=self.corr_radius) # run the context network cnet = self.cnet(frame1) - net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net, inp = torch.split(cnet, [self.hdim, self.cdim], dim=1) net = torch.tanh(net) inp = torch.relu(inp) @@ -168,43 +174,35 @@ def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_l if flow_init is not None: coords1 = coords1 + flow_init - flow_predictions = [] + m_outputs = list() + for itr in range(iters): coords1 = coords1.detach() corr = corr_fn(coords1) # index correlation volume + disp = coords1 - coords0 + net, up_mask, delta_disp = self.update_block(net, inp, corr, disp) + coords1 = coords1 + delta_disp + m_outputs.append({ + "flow":coords1 - coords0, + "hidden_state":net, + "up_mask":up_mask, + "delta_disp":delta_disp + }) + + return self.forward_heads(m_outputs, only_last=only_last) - flow = coords1 - coords0 - net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) - - # F(t+1) = F(t) + \Delta(t) - coords1 = coords1 + delta_flow - - # upsample predictions - if up_mask is None: - flow_up = upflow8(coords1 - coords0) - else: - flow_up = self.upsample_flow(coords1 - coords0, up_mask) - - flow_predictions.append(flow_up) - - if only_last: - flow_low = coords1 - coords0 + @torch.no_grad() + def inference(self, m_outputs, only_last=False): + def generate_frame(out_dict): + flow_low = Flow(out_dict["flow"], names=("B", "C", "H", "W")) + flow_up = Flow(out_dict["up_flow"], names=("B", "C", "H", "W")) return flow_low, flow_up - else: - return flow_predictions - @torch.no_grad() - def inference(self, forward_out, only_last=False): if only_last: - flow_low, flow_up = forward_out - flow_low = Flow(flow_low, names=("B", "C", "H", "W")) - flow_up = Flow(flow_up, names=("B", "C", "H", "W")) - return flow_low, flow_up - elif isinstance(forward_out, list): - return [Flow(flow, names=("B", "C", "H", "W")) for flow in forward_out] + return generate_frame(m_outputs[-1]) else: - return Flow(forward_out, names=("B", "C", "H", "W")) - + [generate_frame(out_dict) for out_dict in m_outputs] + class RAFT(RAFTBase): """ @@ -235,6 +233,8 @@ class RAFT(RAFTBase): corr_levels = 4 corr_radius = 4 + out_plane = 2 + def __init__(self, dropout=0, **kwargs): self.dropout = dropout @@ -272,7 +272,7 @@ def __init__(self, dropout=0, **kwargs): # inference with torch.no_grad(): - flow = raft.forward(frame1, frame2)[-1] # keep only last stage flow estimation + flow = raft.forward(frame1, frame2)["up_flow"][-1] # keep only last stage flow estimation flow = padder.unpad(flow) # unpad to original image resolution flow = raft.inference(flow) flow = flow.detach().cpu() diff --git a/alonet/raft/train.py b/alonet/raft/train.py index d2965112..d100b3a4 100644 --- a/alonet/raft/train.py +++ b/alonet/raft/train.py @@ -67,8 +67,8 @@ def validation_step(self, frames, batch_idx, dataloader_idx=None): def build_criterion(self): return RAFTCriterion() - def build_model(self, alternate_corr=False, weights=None, device="cpu", dropout=0): - return alonet.raft.RAFT(alternate_corr=alternate_corr, weights=weights, device=device, dropout=dropout) + def build_model(self, weights=None, device="cpu", dropout=0): + return alonet.raft.RAFT(weights=weights, device=device, dropout=dropout) def configure_optimizers(self, lr=4e-4, weight_decay=1e-4, epsilon=1e-8, numsteps=100000): params = self.model.parameters() diff --git a/alonet/raft/update.py b/alonet/raft/update.py index 1fdf1e14..7db44715 100644 --- a/alonet/raft/update.py +++ b/alonet/raft/update.py @@ -27,7 +27,7 @@ def forward(self, h, x): z = torch.sigmoid(self.convz(hx)) r = torch.sigmoid(self.convr(hx)) q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) - + h = (1 - z) * h + z * q return h @@ -62,11 +62,12 @@ def forward(self, h, x): class SmallMotionEncoder(nn.Module): - def __init__(self, corr_levels, corr_radius): + def __init__(self, corr_levels, corr_radius, out_planes=2): super(SmallMotionEncoder, self).__init__() - cor_planes = corr_levels * (2 * corr_radius + 1) ** 2 + cor_planes = corr_levels * (2 * corr_radius + 1) ** out_planes + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) - self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf1 = nn.Conv2d(out_planes, 64, 7, padding=3) self.convf2 = nn.Conv2d(64, 32, 3, padding=1) self.conv = nn.Conv2d(128, 80, 3, padding=1) @@ -101,11 +102,11 @@ def forward(self, flow, corr): class SmallUpdateBlock(nn.Module): - def __init__(self, corr_levels, corr_radius, hidden_dim=96): + def __init__(self, corr_levels, corr_radius, hidden_dim=96, out_planes=2): super(SmallUpdateBlock, self).__init__() - self.encoder = SmallMotionEncoder(corr_levels, corr_radius) - self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) - self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + self.encoder = SmallMotionEncoder(corr_levels, corr_radius, out_planes=out_planes) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=hidden_dim + 49) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128, out_planes=out_planes) def forward(self, net, inp, corr, flow): motion_features = self.encoder(flow, corr) From 7b147201503b8be23b559b1dde2887af2f00ad27 Mon Sep 17 00:00:00 2001 From: Remi Agier Date: Sun, 3 Oct 2021 17:48:45 +0200 Subject: [PATCH 3/5] fix --- alonet/raft/criterion.py | 3 +-- alonet/raft/raft.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/alonet/raft/criterion.py b/alonet/raft/criterion.py index 3e9369f1..a68ffafa 100644 --- a/alonet/raft/criterion.py +++ b/alonet/raft/criterion.py @@ -17,7 +17,7 @@ def sequence_loss(m_outputs, flow_gt, valid=None, gamma=0.8, max_flow=400, compu flow_loss = 0.0 # exlude invalid pixels and extremely large diplacements - mag = torch.sum(flow_gt ** 2, dim=1).sqrt() + mag = torch.sum(flow_gt ** 2, dim=1, keepdim=True).sqrt() if valid is None: valid = torch.ones_like(mag, dtype=torch.bool) else: @@ -37,7 +37,6 @@ def sequence_loss(m_outputs, flow_gt, valid=None, gamma=0.8, max_flow=400, compu epe_per_iter.append(epe) else: epe_per_iter = None - epe = torch.sum((m_outputs[-1]["up_flow"] - flow_gt) ** 2, dim=1).sqrt() epe = epe.view(-1)[valid.view(-1)] diff --git a/alonet/raft/raft.py b/alonet/raft/raft.py index d127065d..1ddaf033 100644 --- a/alonet/raft/raft.py +++ b/alonet/raft/raft.py @@ -112,11 +112,11 @@ def upsample_flow(self, flow, mask): mask = torch.softmax(mask, dim=2) up_flow = F.unfold(8 * flow, [3, 3], padding=1) - up_flow = up_flow.view(N, out_plane, 9, 1, 1, H, W) + up_flow = up_flow.view(N, self.out_plane, 9, 1, 1, H, W) up_flow = torch.sum(mask * up_flow, dim=2) up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) - return up_flow.reshape(N, out_plane, 8 * H, 8 * W) + return up_flow.reshape(N, self.out_plane, 8 * H, 8 * W) def forward_heads(self, m_outputs, only_last=False): if not only_last: @@ -194,14 +194,14 @@ def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_l @torch.no_grad() def inference(self, m_outputs, only_last=False): def generate_frame(out_dict): - flow_low = Flow(out_dict["flow"], names=("B", "C", "H", "W")) + #flow_low = Flow(out_dict["flow"], names=("B", "C", "H", "W")) flow_up = Flow(out_dict["up_flow"], names=("B", "C", "H", "W")) - return flow_low, flow_up + return flow_up if only_last: return generate_frame(m_outputs[-1]) else: - [generate_frame(out_dict) for out_dict in m_outputs] + return [generate_frame(out_dict) for out_dict in m_outputs] class RAFT(RAFTBase): @@ -272,8 +272,10 @@ def __init__(self, dropout=0, **kwargs): # inference with torch.no_grad(): - flow = raft.forward(frame1, frame2)["up_flow"][-1] # keep only last stage flow estimation + m_outputs = raft.forward(frame1, frame2) # keep only last stage flow estimation + output = raft.inference(m_outputs) + + flow = output[-1] flow = padder.unpad(flow) # unpad to original image resolution - flow = raft.inference(flow) flow = flow.detach().cpu() flow.get_view().render() From c2c5b5f36054411f4b056d7af7afe5dba4b371fb Mon Sep 17 00:00:00 2001 From: Julien Salotti Date: Fri, 8 Oct 2021 10:59:44 +0200 Subject: [PATCH 4/5] fix disp name to flow --- alonet/raft/raft.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/alonet/raft/raft.py b/alonet/raft/raft.py index 1ddaf033..50bdafbd 100644 --- a/alonet/raft/raft.py +++ b/alonet/raft/raft.py @@ -46,8 +46,8 @@ def __init__( update_block, alternate_corr=False, weights: str = None, - corr_block = CorrBlock, - device: torch.device = torch.device("cpu") + corr_block=CorrBlock, + device: torch.device = torch.device("cpu"), ): super().__init__() self.fnet = fnet @@ -179,22 +179,19 @@ def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_l for itr in range(iters): coords1 = coords1.detach() corr = corr_fn(coords1) # index correlation volume - disp = coords1 - coords0 - net, up_mask, delta_disp = self.update_block(net, inp, corr, disp) - coords1 = coords1 + delta_disp - m_outputs.append({ - "flow":coords1 - coords0, - "hidden_state":net, - "up_mask":up_mask, - "delta_disp":delta_disp - }) + flow = coords1 - coords0 + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + coords1 = coords1 + delta_flow + m_outputs.append( + {"flow": coords1 - coords0, "hidden_state": net, "up_mask": up_mask, "delta_flow": delta_flow} + ) return self.forward_heads(m_outputs, only_last=only_last) @torch.no_grad() def inference(self, m_outputs, only_last=False): def generate_frame(out_dict): - #flow_low = Flow(out_dict["flow"], names=("B", "C", "H", "W")) + # flow_low = Flow(out_dict["flow"], names=("B", "C", "H", "W")) flow_up = Flow(out_dict["up_flow"], names=("B", "C", "H", "W")) return flow_up @@ -202,7 +199,7 @@ def generate_frame(out_dict): return generate_frame(m_outputs[-1]) else: return [generate_frame(out_dict) for out_dict in m_outputs] - + class RAFT(RAFTBase): """ @@ -274,7 +271,7 @@ def __init__(self, dropout=0, **kwargs): with torch.no_grad(): m_outputs = raft.forward(frame1, frame2) # keep only last stage flow estimation output = raft.inference(m_outputs) - + flow = output[-1] flow = padder.unpad(flow) # unpad to original image resolution flow = flow.detach().cpu() From 56b54e4a8a0c3b0f14b6bfd3a1e5f00b84c073d6 Mon Sep 17 00:00:00 2001 From: Julien Salotti Date: Wed, 10 Nov 2021 14:21:27 +0100 Subject: [PATCH 5/5] better implementation for abstract attributes --- alonet/common/abstract_classes.py | 35 +++++++++++++++++++++++++++++++ alonet/raft/raft.py | 17 ++++++++------- 2 files changed, 45 insertions(+), 7 deletions(-) create mode 100644 alonet/common/abstract_classes.py diff --git a/alonet/common/abstract_classes.py b/alonet/common/abstract_classes.py new file mode 100644 index 00000000..969a7846 --- /dev/null +++ b/alonet/common/abstract_classes.py @@ -0,0 +1,35 @@ +""" +Inspired by https://stackoverflow.com/a/50381071/14647356 + +Tools to create abstract classes +""" + + +class AbstractAttribute: + pass + + +def abstract_attribute(obj=None): + if obj is None: + obj = AbstractAttribute() + obj.__is_abstract_attribute__ = True + return obj + + +def check_abstract_attribute_instanciation(cls): + abstract_attributes = { + name for name in dir(cls) if getattr(getattr(cls, name), "__is_abstract_attribute__", False) + } + if abstract_attributes: + raise NotImplementedError( + f"Can't instantiate abstract class {type(cls).__name__}." + f"The following abstract attributes shoud be instanciated: {', '.join(abstract_attributes)}" + ) + + +def super_new(abstract_cls, cls, *args, **kwargs): + __new__ = super(abstract_cls, cls).__new__ + if __new__ is object.__new__: + return __new__(cls) + else: + return __new__(cls, *args, **kwargs) diff --git a/alonet/raft/raft.py b/alonet/raft/raft.py index 50bdafbd..6d379cdb 100644 --- a/alonet/raft/raft.py +++ b/alonet/raft/raft.py @@ -8,6 +8,7 @@ from alonet.raft.corr import CorrBlock, AlternateCorrBlock from alonet.raft.update import BasicUpdateBlock from alonet.raft.extractor import BasicEncoder +from alonet.common.abstract_classes import abstract_attribute, check_abstract_attribute_instanciation, super_new from alonet.raft.utils.utils import coords_grid, upflow8 from aloscene import Flow, Frame @@ -32,19 +33,22 @@ class RAFTBase(nn.Module): """ # should be overriden in subclasses - hidden_dim = None - context_dim = None - corr_levels = None - corr_radius = None + hidden_dim = abstract_attribute() + context_dim = abstract_attribute() + corr_levels = abstract_attribute() + corr_radius = abstract_attribute() + out_plane = abstract_attribute() - out_plane = None + # checks that all abstract attribute are instanciated in child class + def __new__(cls, *args, **kwargs): + check_abstract_attribute_instanciation(cls) + return super_new(RAFTBase, cls, *args, **kwargs) def __init__( self, fnet, cnet, update_block, - alternate_corr=False, weights: str = None, corr_block=CorrBlock, device: torch.device = torch.device("cpu"), @@ -229,7 +233,6 @@ class RAFT(RAFTBase): context_dim = 128 corr_levels = 4 corr_radius = 4 - out_plane = 2 def __init__(self, dropout=0, **kwargs):