Skip to content

Commit 3122ea1

Browse files
authored
Respect strict=False when loading detection models (#5841)
* Convert weights only if `old_key` is in `state_dict` * Fix linter
1 parent 92eb12d commit 3122ea1

File tree

4 files changed

+8
-4
lines changed

4 files changed

+8
-4
lines changed

torchvision/models/detection/mask_rcnn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,8 @@ def _load_from_state_dict(
317317
for type in ["weight", "bias"]:
318318
old_key = f"{prefix}mask_fcn{i+1}.{type}"
319319
new_key = f"{prefix}{i}.0.{type}"
320-
state_dict[new_key] = state_dict.pop(old_key)
320+
if old_key in state_dict:
321+
state_dict[new_key] = state_dict.pop(old_key)
321322

322323
super()._load_from_state_dict(
323324
state_dict,

torchvision/models/detection/retinanet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def _v1_to_v2_weights(state_dict, prefix):
4545
for type in ["weight", "bias"]:
4646
old_key = f"{prefix}conv.{2*i}.{type}"
4747
new_key = f"{prefix}conv.{i}.0.{type}"
48-
state_dict[new_key] = state_dict.pop(old_key)
48+
if old_key in state_dict:
49+
state_dict[new_key] = state_dict.pop(old_key)
4950

5051

5152
def _default_anchorgen():

torchvision/models/detection/rpn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def _load_from_state_dict(
5656
for type in ["weight", "bias"]:
5757
old_key = f"{prefix}conv.{type}"
5858
new_key = f"{prefix}conv.0.0.{type}"
59-
state_dict[new_key] = state_dict.pop(old_key)
59+
if old_key in state_dict:
60+
state_dict[new_key] = state_dict.pop(old_key)
6061

6162
super()._load_from_state_dict(
6263
state_dict,

torchvision/ops/feature_pyramid_network.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ def _load_from_state_dict(
128128
for type in ["weight", "bias"]:
129129
old_key = f"{prefix}{block}.{i}.{type}"
130130
new_key = f"{prefix}{block}.{i}.0.{type}"
131-
state_dict[new_key] = state_dict.pop(old_key)
131+
if old_key in state_dict:
132+
state_dict[new_key] = state_dict.pop(old_key)
132133

133134
super()._load_from_state_dict(
134135
state_dict,

0 commit comments

Comments
 (0)