35 changes: 35 additions & 0 deletions alonet/common/abstract_classes.py
@@ -0,0 +1,35 @@
"""
Inspired by https://stackoverflow.com/a/50381071/14647356

Tools to create abstract classes
"""


class AbstractAttribute:
pass


def abstract_attribute(obj=None):
if obj is None:
obj = AbstractAttribute()
obj.__is_abstract_attribute__ = True
return obj


def check_abstract_attribute_instanciation(cls):
abstract_attributes = {
name for name in dir(cls) if getattr(getattr(cls, name), "__is_abstract_attribute__", False)
}
if abstract_attributes:
raise NotImplementedError(
f"Can't instantiate abstract class {type(cls).__name__}."
f"The following abstract attributes shoud be instanciated: {', '.join(abstract_attributes)}"
)


def super_new(abstract_cls, cls, *args, **kwargs):
__new__ = super(abstract_cls, cls).__new__
if __new__ is object.__new__:
return __new__(cls)
else:
return __new__(cls, *args, **kwargs)
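
For reviewers, a minimal sketch of how these three helpers are meant to be combined, mirroring the RAFTBase usage later in this diff; the Base/Child names are illustrative only.

# Illustrative only: Base/Child are hypothetical names; the pattern mirrors RAFTBase below.
from alonet.common.abstract_classes import (
    abstract_attribute,
    check_abstract_attribute_instanciation,
    super_new,
)


class Base:
    # must be given a concrete value by subclasses
    some_dim = abstract_attribute()

    def __new__(cls, *args, **kwargs):
        # fails at instantiation time if any abstract attribute is still abstract
        check_abstract_attribute_instanciation(cls)
        return super_new(Base, cls, *args, **kwargs)


class Child(Base):
    some_dim = 128


Child()  # ok
# Base() would raise NotImplementedError because some_dim is still abstract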
20 changes: 10 additions & 10 deletions alonet/raft/criterion.py
@@ -11,32 +11,33 @@ def __init__(self):

# loss from RAFT implementation
@staticmethod
def sequence_loss(flow_preds, flow_gt, valid=None, gamma=0.8, max_flow=400, compute_per_iter=False):
def sequence_loss(m_outputs, flow_gt, valid=None, gamma=0.8, max_flow=400, compute_per_iter=False):
"""Loss function defined over sequence of flow predictions"""
n_predictions = len(flow_preds)
n_predictions = len(m_outputs)
flow_loss = 0.0

# exclude invalid pixels and extremely large displacements
mag = torch.sum(flow_gt ** 2, dim=1).sqrt()
mag = torch.sum(flow_gt ** 2, dim=1, keepdim=True).sqrt()
if valid is None:
valid = torch.ones_like(mag, dtype=torch.bool)
else:
valid = (valid >= 0.5) & (mag < max_flow)
for i in range(n_predictions):
m_dict = m_outputs[i]
i_weight = gamma ** (n_predictions - i - 1)
i_loss = (flow_preds[i] - flow_gt).abs()
i_loss = (m_dict["up_flow"] - flow_gt).abs()
flow_loss += i_weight * (valid * i_loss).mean()

if compute_per_iter:
epe_per_iter = []
for flow_p in flow_preds:
epe = torch.sum((flow_p - flow_gt) ** 2, dim=1).sqrt()
for i in range(n_predictions):
m_dict = m_outputs[i]
epe = torch.sum((m_dict["up_flow"] - flow_gt) ** 2, dim=1).sqrt()
epe = epe.view(-1)[valid.view(-1)]
epe_per_iter.append(epe)
else:
epe_per_iter = None

epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
epe = torch.sum((m_outputs[-1]["up_flow"] - flow_gt) ** 2, dim=1).sqrt()
epe = epe.view(-1)[valid.view(-1)]

metrics = {
@@ -50,7 +51,6 @@ def sequence_loss(flow_preds, flow_gt, valid=None, gamma=0.8, max_flow=400, comp

def forward(self, m_outputs, frame1, use_valid=True, compute_per_iter=False):
assert isinstance(frame1, aloscene.Frame)
flow_preds = m_outputs
flow_gt = [f.batch() for f in frame1.flow["flow_forward"]]
flow_gt = torch.cat(flow_gt, dim=0)
# occlusion mask -- not used in raft original repo
@@ -65,6 +65,6 @@ def forward(self, m_outputs, frame1, use_valid=True, compute_per_iter=False):
else:
valid = None
flow_loss, metrics, epe_per_iter = RAFTCriterion.sequence_loss(
flow_preds, flow_gt, valid, compute_per_iter=compute_per_iter
m_outputs, flow_gt, valid, compute_per_iter=compute_per_iter
)
return flow_loss, metrics, epe_per_iter
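
For context on the loss above: each refinement iteration contributes |up_flow_i - flow_gt| weighted by gamma ** (N - i - 1), so later, more refined estimates dominate. A small sketch with the default values (gamma=0.8 and the 12 iterations used by RAFTBase.forward):

# Sketch of the per-iteration weights used by sequence_loss
# (assumed defaults: gamma=0.8, 12 refinement iterations as in RAFTBase.forward).
gamma, n_predictions = 0.8, 12
weights = [gamma ** (n_predictions - i - 1) for i in range(n_predictions)]
print([round(w, 3) for w in weights])
# first iteration gets ~0.086, the last one 1.0: refined estimates dominate the total loss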
110 changes: 56 additions & 54 deletions alonet/raft/raft.py
@@ -8,6 +8,7 @@
from alonet.raft.corr import CorrBlock, AlternateCorrBlock
from alonet.raft.update import BasicUpdateBlock
from alonet.raft.extractor import BasicEncoder
from alonet.common.abstract_classes import abstract_attribute, check_abstract_attribute_instanciation, super_new
from alonet.raft.utils.utils import coords_grid, upflow8
from aloscene import Flow, Frame

@@ -32,25 +33,31 @@ class RAFTBase(nn.Module):
"""

# should be overridden in subclasses
hidden_dim = None
context_dim = None
corr_levels = None
corr_radius = None
hidden_dim = abstract_attribute()
context_dim = abstract_attribute()
corr_levels = abstract_attribute()
corr_radius = abstract_attribute()
out_plane = abstract_attribute()

# checks that all abstract attributes are instantiated in the child class
def __new__(cls, *args, **kwargs):
check_abstract_attribute_instanciation(cls)
return super_new(RAFTBase, cls, *args, **kwargs)

def __init__(
self,
fnet,
cnet,
update_block,
alternate_corr=False,
weights: str = None,
corr_block=CorrBlock,
device: torch.device = torch.device("cpu"),
):
super().__init__()
self.fnet = fnet
self.cnet = cnet
self.update_block = update_block
self.alternate_corr = alternate_corr
self.corr_block = corr_block

if weights is not None:
weights_from_original_repo = ["raft-things", "raft-chairs", "raft-small", "raft-kitti", "raft-sintel"]
@@ -83,7 +90,7 @@ def build_update_block(self, update_cls=BasicUpdateBlock):
"""
Build RAFT Update Block
"""
return update_cls(self.corr_levels, self.corr_radius, hidden_dim=self.hdim)
return update_cls(self.corr_levels, self.corr_radius, hidden_dim=self.hdim, out_planes=self.out_plane)

def freeze_bn(self):
for m in self.modules():
@@ -101,16 +108,28 @@ def initialize_flow(self, img):

def upsample_flow(self, flow, mask):
"""Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)
if mask is None:
return upflow8(flow)
else:
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)

up_flow = F.unfold(8 * flow, [3, 3], padding=1)
up_flow = up_flow.view(N, self.out_plane, 9, 1, 1, H, W)

up_flow = F.unfold(8 * flow, [3, 3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, self.out_plane, 8 * H, 8 * W)

up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, 8 * H, 8 * W)
def forward_heads(self, m_outputs, only_last=False):
if not only_last:
for out_dict in m_outputs:
out_dict["up_flow"] = self.upsample_flow(out_dict["flow"], out_dict["up_mask"])

else:
m_outputs[-1]["up_flow"] = self.upsample_flow(m_outputs[-1]["flow"], m_outputs[-1]["up_mask"])
return m_outputs

def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_last=False):
"""Estimate optical flow between pair of frames
@@ -140,26 +159,17 @@ def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_l
frame1 = frame1.as_tensor()
frame2 = frame2.as_tensor()

# frame1 = frame1.contiguous()
# frame2 = frame2.contiguous()

hdim = self.hidden_dim
cdim = self.context_dim

# run the feature network

fmap1, fmap2 = self.fnet([frame1, frame2])

fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.alternate_corr:
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.corr_radius)

corr_fn = self.corr_block(fmap1, fmap2, radius=self.corr_radius)

# run the context network
cnet = self.cnet(frame1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net, inp = torch.split(cnet, [self.hdim, self.cdim], dim=1)
net = torch.tanh(net)
inp = torch.relu(inp)

@@ -168,42 +178,31 @@ def forward(self, frame1: Frame, frame2: Frame, iters=12, flow_init=None, only_l
if flow_init is not None:
coords1 = coords1 + flow_init

flow_predictions = []
m_outputs = list()

for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume

flow = coords1 - coords0
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
m_outputs.append(
{"flow": coords1 - coords0, "hidden_state": net, "up_mask": up_mask, "delta_flow": delta_flow}
)

# upsample predictions
if up_mask is None:
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)

flow_predictions.append(flow_up)

if only_last:
flow_low = coords1 - coords0
return flow_low, flow_up
else:
return flow_predictions
return self.forward_heads(m_outputs, only_last=only_last)

@torch.no_grad()
def inference(self, forward_out, only_last=False):
def inference(self, m_outputs, only_last=False):
def generate_frame(out_dict):
# flow_low = Flow(out_dict["flow"], names=("B", "C", "H", "W"))
flow_up = Flow(out_dict["up_flow"], names=("B", "C", "H", "W"))
return flow_up

if only_last:
flow_low, flow_up = forward_out
flow_low = Flow(flow_low, names=("B", "C", "H", "W"))
flow_up = Flow(flow_up, names=("B", "C", "H", "W"))
return flow_low, flow_up
elif isinstance(forward_out, list):
return [Flow(flow, names=("B", "C", "H", "W")) for flow in forward_out]
return generate_frame(m_outputs[-1])
else:
return Flow(forward_out, names=("B", "C", "H", "W"))
return [generate_frame(out_dict) for out_dict in m_outputs]


class RAFT(RAFTBase):
Expand Down Expand Up @@ -234,6 +233,7 @@ class RAFT(RAFTBase):
context_dim = 128
corr_levels = 4
corr_radius = 4
out_plane = 2

def __init__(self, dropout=0, **kwargs):
self.dropout = dropout
@@ -272,8 +272,10 @@ def __init__(self, dropout=0, **kwargs):

# inference
with torch.no_grad():
flow = raft.forward(frame1, frame2)[-1] # keep only last stage flow estimation
m_outputs = raft.forward(frame1, frame2) # one dict per refinement iteration
output = raft.inference(m_outputs)

flow = output[-1] # keep only last stage flow estimation
flow = padder.unpad(flow) # unpad to original image resolution
flow = raft.inference(flow)
flow = flow.detach().cpu()
flow.get_view().render()
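
To summarize the new return contract of RAFTBase.forward / inference for readers of this diff, a hedged sketch (keys are taken from the code above; shapes are assumptions for a batch of size B, an H x W input and out_plane=2):

# Sketch of the m_outputs structure after this PR (shapes are assumptions).
m_outputs = raft(frame1, frame2, iters=12)    # list with one dict per refinement iteration
last = m_outputs[-1]
last["flow"]          # (B, 2, H/8, W/8) flow at feature resolution (coords1 - coords0)
last["delta_flow"]    # (B, 2, H/8, W/8) update predicted at this iteration
last["up_mask"]       # convex-upsampling weights (upsample_flow falls back to upflow8 if None)
last["hidden_state"]  # GRU hidden state carried between iterations
last["up_flow"]       # (B, 2, H, W) full-resolution flow added by forward_heads
flows = raft.inference(m_outputs)             # list of aloscene.Flow, one per iteration
flows[-1]                                     # final estimate, as used in the docstring example above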
4 changes: 2 additions & 2 deletions alonet/raft/train.py
@@ -67,8 +67,8 @@ def validation_step(self, frames, batch_idx, dataloader_idx=None):
def build_criterion(self):
return RAFTCriterion()

def build_model(self, alternate_corr=False, weights=None, device="cpu", dropout=0):
return alonet.raft.RAFT(alternate_corr=alternate_corr, weights=weights, device=device, dropout=dropout)
def build_model(self, weights=None, device="cpu", dropout=0):
return alonet.raft.RAFT(weights=weights, device=device, dropout=dropout)

def configure_optimizers(self, lr=4e-4, weight_decay=1e-4, epsilon=1e-8, numsteps=100000):
params = self.model.parameters()
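
Since build_model no longer exposes alternate_corr, the memory-efficient correlation lookup is presumably selected through the new corr_block argument of RAFTBase instead; a hedged sketch (it assumes RAFT forwards **kwargs to RAFTBase.__init__, which the raft.py diff suggests but does not show in full):

# Assumption: RAFT(**kwargs) forwards corr_block to RAFTBase.__init__ (see raft.py diff).
from alonet.raft.corr import AlternateCorrBlock
import alonet

model = alonet.raft.RAFT(weights="raft-things", corr_block=AlternateCorrBlock)
# default remains corr_block=CorrBlock, matching the old alternate_corr=False behaviour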
35 changes: 18 additions & 17 deletions alonet/raft/update.py
@@ -4,10 +4,10 @@


class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256):
def __init__(self, input_dim=128, hidden_dim=256, out_planes=2):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, out_planes, 3, padding=1)
self.relu = nn.ReLU(inplace=True)

def forward(self, x):
@@ -27,7 +27,7 @@ def forward(self, h, x):
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))

h = (1 - z) * h + z * q
return h

@@ -62,11 +62,12 @@ def forward(self, h, x):


class SmallMotionEncoder(nn.Module):
def __init__(self, corr_levels, corr_radius):
def __init__(self, corr_levels, corr_radius, out_planes=2):
super(SmallMotionEncoder, self).__init__()
cor_planes = corr_levels * (2 * corr_radius + 1) ** 2
cor_planes = corr_levels * (2 * corr_radius + 1) ** out_planes

self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
self.convf1 = nn.Conv2d(out_planes, 64, 7, padding=3)
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
self.conv = nn.Conv2d(128, 80, 3, padding=1)

@@ -80,14 +81,14 @@ def forward(self, flow, corr):


class BasicMotionEncoder(nn.Module):
def __init__(self, corr_levels, corr_radius):
def __init__(self, corr_levels, corr_radius, out_planes=2):
super(BasicMotionEncoder, self).__init__()
cor_planes = corr_levels * (2 * corr_radius + 1) ** 2
cor_planes = corr_levels * (2 * corr_radius + 1) ** out_planes
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
self.convf1 = nn.Conv2d(out_planes, 128, 7, padding=3)
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
self.conv = nn.Conv2d(64 + 192, 128 - out_planes, 3, padding=1)

def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
@@ -101,11 +102,11 @@ def forward(self, flow, corr):


class SmallUpdateBlock(nn.Module):
def __init__(self, corr_levels, corr_radius, hidden_dim=96):
def __init__(self, corr_levels, corr_radius, hidden_dim=96, out_planes=2):
super(SmallUpdateBlock, self).__init__()
self.encoder = SmallMotionEncoder(corr_levels, corr_radius)
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
self.encoder = SmallMotionEncoder(corr_levels, corr_radius, out_planes=out_planes)
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=hidden_dim + 49)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128, out_planes=out_planes)

def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
@@ -117,11 +118,11 @@ def forward(self, net, inp, corr, flow):


class BasicUpdateBlock(nn.Module):
def __init__(self, corr_levels, corr_radius, hidden_dim=128, input_dim=128):
def __init__(self, corr_levels, corr_radius, hidden_dim=128, input_dim=128, out_planes=2):
super(BasicUpdateBlock, self).__init__()
self.encoder = BasicMotionEncoder(corr_levels, corr_radius)
self.encoder = BasicMotionEncoder(corr_levels, corr_radius, out_planes=out_planes)
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256, out_planes=out_planes)

self.mask = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 64 * 9, 1, padding=0)
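
Finally, a short sketch of how the re-parameterized update block is wired, using the RAFT defaults introduced above (hidden_dim=128, corr_levels=4, corr_radius=4, out_plane=2); the spatial size and random tensors are placeholders.

# Sketch only: values follow the RAFT defaults; 46x62 is a placeholder 1/8-resolution grid.
import torch
from alonet.raft.update import BasicUpdateBlock

update_block = BasicUpdateBlock(corr_levels=4, corr_radius=4, hidden_dim=128, out_planes=2)

net = torch.randn(1, 128, 46, 62)           # GRU hidden state
inp = torch.randn(1, 128, 46, 62)           # context features
corr = torch.randn(1, 4 * 9 ** 2, 46, 62)   # correlation volume: corr_levels * (2*corr_radius+1)**2
flow = torch.randn(1, 2, 46, 62)            # current flow estimate, out_planes channels

net, up_mask, delta_flow = update_block(net, inp, corr, flow)
# delta_flow: (1, 2, 46, 62) from FlowHead(out_planes=2); up_mask: (1, 64*9, 46, 62)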