diff --git a/onnx2pytorch/convert/model.py b/onnx2pytorch/convert/model.py
index b3b79eb..6f00ad1 100644
--- a/onnx2pytorch/convert/model.py
+++ b/onnx2pytorch/convert/model.py
@@ -202,7 +202,11 @@ def forward(self, *input_list, **input_dict):
             in_activations = [
                 in_act for in_act in in_activations if in_act is not None
             ]
-
+            if node.op_type == "Pad":
+                # ONNX orders pads [b1..bn, e1..en]; torch.nn.functional.pad
+                # wants (begin, end) pairs starting from the last dimension.
+                from onnx2pytorch.operations.pad import preprocess_pads
+                in_activations = preprocess_pads(in_activations)
             # store activations for next layer
             if isinstance(op, Loop):
                 outputs = op((self,), activations, *in_activations)
diff --git a/onnx2pytorch/operations/pad.py b/onnx2pytorch/operations/pad.py
index e0c7fd8..6240fe8 100644
--- a/onnx2pytorch/operations/pad.py
+++ b/onnx2pytorch/operations/pad.py
@@ -3,6 +3,32 @@
 from onnx2pytorch.operations.base import Operator
 
 
+def preprocess_pads(in_activations):
+    """
+    Reorder full-rank ONNX pads into torch.nn.functional.pad order.
+
+    ONNX Pad lists pads as [b1, ..., bn, e1, ..., en] (all begin values,
+    then all end values), while torch expects (begin, end) pairs starting
+    from the last dimension: [bn, en, b(n-1), e(n-1), ..., b1, e1].
+    The reordering is applied only when pads cover every dimension of the
+    data tensor (e.g. an 8 element pads for 4-d data); otherwise
+    in_activations is returned unchanged.
+    """
+    data = in_activations[0]
+    pads = list(in_activations[1])
+    mid_idx = len(pads) // 2
+    if mid_idx == len(data.size()):
+        import torch
+
+        begins, ends = pads[:mid_idx], pads[mid_idx:]
+        new_pads = []
+        for begin, end in zip(reversed(begins), reversed(ends)):
+            new_pads.append(begin)
+            new_pads.append(end)
+        in_activations[1] = torch.tensor(new_pads)
+    return in_activations
+
+
 class Pad(Operator):
     def __init__(self, mode="constant", padding=None):
         self.mode = mode