Skip to content

Commit 9568ba6

Browse files
committed
Add conversion implementation
1 parent c6159fa commit 9568ba6

File tree

8 files changed

+342
-10
lines changed

8 files changed

+342
-10
lines changed

keras_nlp/src/models/backbone.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from keras_nlp.src.utils.preset_utils import save_metadata
3131
from keras_nlp.src.utils.preset_utils import save_serialized_object
3232
from keras_nlp.src.utils.python_utils import classproperty
33+
from keras_nlp.src.utils.timm.convert import load_timm_backbone
3334
from keras_nlp.src.utils.transformers.convert import load_transformers_backbone
3435

3536

@@ -204,6 +205,8 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
204205

205206
if format == "transformers":
206207
return load_transformers_backbone(cls, preset, load_weights)
208+
elif format == "timm":
209+
return load_timm_backbone(cls, preset, load_weights, **kwargs)
207210

208211
preset_cls = check_config_class(preset)
209212
if not issubclass(preset_cls, cls):

keras_nlp/src/models/resnet/resnet_backbone.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ class ResNetBackbone(FeaturePyramidBackbone):
4949
use_pre_activation: boolean. Whether to use pre-activation or not.
5050
`True` for ResNetV2, `False` for ResNet.
5151
include_rescaling: boolean. If `True`, rescale the input using
52-
`Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to
53-
`True`.
52+
`Rescaling` and `Normalization` layers. If `False`, do nothing.
53+
Defaults to `True`.
5454
input_image_shape: tuple. The input shape without the batch size.
5555
Defaults to `(None, None, 3)`.
5656
pooling: `None` or str. Pooling mode for feature extraction. Defaults
@@ -139,6 +139,12 @@ def __init__(
139139
image_input = layers.Input(shape=input_image_shape)
140140
if include_rescaling:
141141
x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input)
142+
x = layers.Normalization(
143+
mean=(0.485, 0.456, 0.406),
144+
variance=(0.229**2, 0.224**2, 0.225**2),
145+
dtype=dtype,
146+
name="normalization",
147+
)(x)
142148
else:
143149
x = image_input
144150

@@ -327,13 +333,14 @@ def apply_basic_block(
327333
dtype=dtype,
328334
name=f"{name}_0_conv",
329335
)(x)
330-
shortcut = layers.BatchNormalization(
331-
axis=bn_axis,
332-
epsilon=1e-5,
333-
momentum=0.9,
334-
dtype=dtype,
335-
name=f"{name}_0_bn",
336-
)(shortcut)
336+
if not use_pre_activation:
337+
shortcut = layers.BatchNormalization(
338+
axis=bn_axis,
339+
epsilon=1e-5,
340+
momentum=0.9,
341+
dtype=dtype,
342+
name=f"{name}_0_bn",
343+
)(shortcut)
337344
else:
338345
shortcut = x
339346

@@ -363,6 +370,7 @@ def apply_basic_block(
363370
name=f"{name}_1_bn",
364371
)(x)
365372
x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
373+
366374
x = layers.Conv2D(
367375
filters,
368376
kernel_size,
@@ -373,7 +381,6 @@ def apply_basic_block(
373381
dtype=dtype,
374382
name=f"{name}_2_conv",
375383
)(x)
376-
377384
if not use_pre_activation:
378385
x = layers.BatchNormalization(
379386
axis=bn_axis,

keras_nlp/src/utils/preset_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
KAGGLE_PREFIX = "kaggle://"
5151
GS_PREFIX = "gs://"
5252
HF_PREFIX = "hf://"
53+
TIMM_PREFIX = "hf://timm"
5354

5455
KAGGLE_SCHEME = "kaggle"
5556
GS_SCHEME = "gs"
@@ -544,6 +545,8 @@ def check_format(preset):
544545
if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists(
545546
preset, SAFETENSOR_CONFIG_FILE
546547
):
548+
if TIMM_PREFIX in preset:
549+
return "timm"
547550
return "transformers"
548551

549552
if not check_file_exists(preset, METADATA_FILE):

keras_nlp/src/utils/timm/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

keras_nlp/src/utils/timm/convert.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Convert timm models to KerasNLP."""
15+
16+
from keras_nlp.src.utils.timm.convert_resnet import load_resnet_backbone
17+
18+
19+
def load_timm_backbone(cls, preset, load_weights, **kwargs):
20+
"""Load a timm model config and weights as a KerasNLP backbone.
21+
22+
Args:
23+
cls (class): Keras model class.
24+
preset (str): Preset configuration name.
25+
load_weights (bool): Whether to load the weights.
26+
27+
Returns:
28+
backbone: Initialized Keras model backbone.
29+
"""
30+
if cls is None:
31+
raise ValueError("Backbone class is None")
32+
if cls.__name__ == "ResNetBackbone":
33+
return load_resnet_backbone(cls, preset, load_weights, **kwargs)
34+
raise ValueError(
35+
f"{cls} has not been ported from the Hugging Face format yet. "
36+
"Please check Hugging Face Hub for the Keras model. "
37+
)
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
16+
from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE
17+
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
18+
from keras_nlp.src.utils.preset_utils import load_config
19+
from keras_nlp.src.utils.timm.safetensor_utils import SafetensorLoader
20+
21+
22+
def convert_backbone_config(timm_config):
23+
timm_architecture = timm_config["architecture"]
24+
25+
if "resnetv2_" in timm_architecture:
26+
use_pre_activation = True
27+
else:
28+
use_pre_activation = False
29+
30+
if timm_architecture == "resnet18":
31+
stackwise_num_blocks = [2, 2, 2, 2]
32+
block_type = "basic_block"
33+
elif timm_architecture == "resnet26":
34+
stackwise_num_blocks = [2, 2, 2, 2]
35+
block_type = "bottleneck_block"
36+
elif timm_architecture == "resnet34":
37+
stackwise_num_blocks = [3, 4, 6, 3]
38+
block_type = "basic_block"
39+
elif timm_architecture in ("resnet50", "resnetv2_50"):
40+
stackwise_num_blocks = [3, 4, 6, 3]
41+
block_type = "bottleneck_block"
42+
elif timm_architecture in ("resnet101", "resnetv2_101"):
43+
stackwise_num_blocks = [3, 4, 23, 3]
44+
block_type = "bottleneck_block"
45+
elif timm_architecture in ("resnet152", "resnetv2_152"):
46+
stackwise_num_blocks = [3, 8, 36, 3]
47+
block_type = "bottleneck_block"
48+
else:
49+
raise ValueError(
50+
f"Currently, the architecture {timm_architecture} is not supported."
51+
)
52+
53+
return dict(
54+
stackwise_num_filters=[64, 128, 256, 512],
55+
stackwise_num_blocks=stackwise_num_blocks,
56+
stackwise_num_strides=[1, 2, 2, 2],
57+
block_type=block_type,
58+
use_pre_activation=use_pre_activation,
59+
)
60+
61+
62+
def convert_weights(backbone, loader, timm_config):
63+
def transpose_conv2d(x, shape):
64+
return np.transpose(x, (2, 3, 1, 0))
65+
66+
def port_conv2d(keras_layer_name, hf_weight_prefix):
67+
loader.port_weight(
68+
backbone.get_layer(keras_layer_name).kernel,
69+
hf_weight_key=f"{hf_weight_prefix}.weight",
70+
hook_fn=transpose_conv2d,
71+
)
72+
73+
def port_batch_normalization(keras_layer_name, hf_weight_prefix):
74+
loader.port_weight(
75+
backbone.get_layer(keras_layer_name).gamma,
76+
hf_weight_key=f"{hf_weight_prefix}.weight",
77+
)
78+
loader.port_weight(
79+
backbone.get_layer(keras_layer_name).beta,
80+
hf_weight_key=f"{hf_weight_prefix}.bias",
81+
)
82+
loader.port_weight(
83+
backbone.get_layer(keras_layer_name).moving_mean,
84+
hf_weight_key=f"{hf_weight_prefix}.running_mean",
85+
)
86+
loader.port_weight(
87+
backbone.get_layer(keras_layer_name).moving_variance,
88+
hf_weight_key=f"{hf_weight_prefix}.running_var",
89+
)
90+
91+
version = "v1" if not backbone.use_pre_activation else "v2"
92+
block_type = backbone.block_type
93+
94+
# Stem
95+
if version == "v1":
96+
port_conv2d("conv1_conv", "conv1")
97+
port_batch_normalization("conv1_bn", "bn1")
98+
else:
99+
port_conv2d("conv1_conv", "stem.conv")
100+
101+
# Stages
102+
num_stacks = len(backbone.stackwise_num_filters)
103+
for stack_index in range(num_stacks):
104+
for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
105+
if version == "v1":
106+
keras_name = f"v1_stack{stack_index}_block{block_idx}"
107+
hf_name = f"layer{stack_index+1}.{block_idx}"
108+
else:
109+
keras_name = f"v2_stack{stack_index}_block{block_idx}"
110+
hf_name = f"stages.{stack_index}.blocks.{block_idx}"
111+
112+
if version == "v1":
113+
if block_idx == 0 and (
114+
block_type == "bottleneck_block" or stack_index > 0
115+
):
116+
port_conv2d(
117+
f"{keras_name}_0_conv", f"{hf_name}.downsample.0"
118+
)
119+
port_batch_normalization(
120+
f"{keras_name}_0_bn", f"{hf_name}.downsample.1"
121+
)
122+
port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
123+
port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.bn1")
124+
port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")
125+
port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.bn2")
126+
if block_type == "bottleneck_block":
127+
port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3")
128+
port_batch_normalization(
129+
f"{keras_name}_3_bn", f"{hf_name}.bn3"
130+
)
131+
else:
132+
if block_idx == 0 and (
133+
block_type == "bottleneck_block" or stack_index > 0
134+
):
135+
port_conv2d(
136+
f"{keras_name}_0_conv", f"{hf_name}.downsample.conv"
137+
)
138+
port_batch_normalization(
139+
f"{keras_name}_pre_activation_bn", f"{hf_name}.norm1"
140+
)
141+
port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
142+
port_batch_normalization(
143+
f"{keras_name}_1_bn", f"{hf_name}.norm2"
144+
)
145+
port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")
146+
if block_type == "bottleneck_block":
147+
port_batch_normalization(
148+
f"{keras_name}_2_bn", f"{hf_name}.norm3"
149+
)
150+
port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3")
151+
152+
# Post
153+
if version == "v2":
154+
port_batch_normalization("post_bn", "norm")
155+
156+
# Rebuild normalization layer with pretrained mean & std
157+
mean = timm_config["pretrained_cfg"]["mean"]
158+
std = timm_config["pretrained_cfg"]["std"]
159+
normalization_layer = backbone.get_layer("normalization")
160+
normalization_layer.input_mean = mean
161+
normalization_layer.input_variance = [s**2 for s in std]
162+
normalization_layer.build(normalization_layer._build_input_shape)
163+
164+
165+
def load_resnet_backbone(cls, preset, load_weights, **kwargs):
166+
timm_config = load_config(preset, HF_CONFIG_FILE)
167+
keras_config = convert_backbone_config(timm_config)
168+
backbone = cls(**keras_config, **kwargs)
169+
if load_weights:
170+
jax_memory_cleanup(backbone)
171+
with SafetensorLoader(preset) as loader:
172+
convert_weights(backbone, loader, timm_config)
173+
return backbone
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytest
15+
from keras import ops
16+
17+
from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone
18+
from keras_nlp.src.tests.test_case import TestCase
19+
20+
21+
class TimmResNetBackboneTest(TestCase):
22+
@pytest.mark.large
23+
def test_convert_resnet18_preset(self):
24+
model = ResNetBackbone.from_preset("hf://timm/resnet18.a1_in1k")
25+
outputs = model.predict(ops.ones((1, 224, 224, 3)))
26+
self.assertEqual(outputs.shape, (1, 512))
27+
28+
# TODO: compare numerics with timm model

0 commit comments

Comments
 (0)