Skip to content

Commit 91a057e

Browse files
divyashreepathihallimattdangerw
authored andcommitted
Add VGG16 backbone (#1737)
* Agg Vgg16 backbone * update names * update tests * update test * add image classifier * incorporate review comments * Update test case * update backbone test * add image classifier * classifier cleanup * code reformat * add vgg16 image classifier * make vgg generic * update doc string * update docstring * add classifier test * update tests * update docstring * address review comments * code reformat * update the configs * address review comments * fix task saved model test * update init * code reformatted
1 parent ed732a1 commit 91a057e

File tree

8 files changed

+514
-14
lines changed

8 files changed

+514
-14
lines changed

keras_nlp/api/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157
GPTNeoXCausalLMPreprocessor,
158158
)
159159
from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
160+
from keras_nlp.src.models.image_classifier import ImageClassifier
160161
from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
161162
from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM
162163
from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
@@ -230,6 +231,8 @@
230231
from keras_nlp.src.models.text_classifier_preprocessor import (
231232
TextClassifierPreprocessor,
232233
)
234+
from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone
235+
from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier
233236
from keras_nlp.src.models.whisper.whisper_backbone import WhisperBackbone
234237
from keras_nlp.src.models.whisper.whisper_tokenizer import WhisperTokenizer
235238
from keras_nlp.src.models.xlm_roberta.xlm_roberta_backbone import (
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2023 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import keras
15+
16+
from keras_nlp.src.api_export import keras_nlp_export
17+
from keras_nlp.src.models.task import Task
18+
19+
20+
@keras_nlp_export("keras_nlp.models.ImageClassifier")
21+
class ImageClassifier(Task):
22+
"""Base class for all image classification tasks.
23+
24+
`ImageClassifier` tasks wrap a `keras_nlp.models.Backbone` and
25+
a `keras_nlp.models.Preprocessor` to create a model that can be used for
26+
image classification. `ImageClassifier` tasks take an additional
27+
`num_classes` argument, controlling the number of predicted output classes.
28+
29+
To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
30+
labels where `x` is a string and `y` is a integer from `[0, num_classes)`.
31+
32+
All `ImageClassifier` tasks include a `from_preset()` constructor which can be
33+
used to load a pre-trained config and weights.
34+
"""
35+
36+
def __init__(self, *args, **kwargs):
37+
super().__init__(*args, **kwargs)
38+
# Default compilation.
39+
self.compile()
40+
41+
def compile(
42+
self,
43+
optimizer="auto",
44+
loss="auto",
45+
*,
46+
metrics="auto",
47+
**kwargs,
48+
):
49+
"""Configures the `ImageClassifier` task for training.
50+
51+
The `ImageClassifier` task extends the default compilation signature of
52+
`keras.Model.compile` with defaults for `optimizer`, `loss`, and
53+
`metrics`. To override these defaults, pass any value
54+
to these arguments during compilation.
55+
56+
Args:
57+
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
58+
instance. Defaults to `"auto"`, which uses the default optimizer
59+
for the given model and task. See `keras.Model.compile` and
60+
`keras.optimizers` for more info on possible `optimizer` values.
61+
loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
62+
Defaults to `"auto"`, where a
63+
`keras.losses.SparseCategoricalCrossentropy` loss will be
64+
applied for the classification task. See
65+
`keras.Model.compile` and `keras.losses` for more info on
66+
possible `loss` values.
67+
metrics: `"auto"`, or a list of metrics to be evaluated by
68+
the model during training and testing. Defaults to `"auto"`,
69+
where a `keras.metrics.SparseCategoricalAccuracy` will be
70+
applied to track the accuracy of the model during training.
71+
See `keras.Model.compile` and `keras.metrics` for
72+
more info on possible `metrics` values.
73+
**kwargs: See `keras.Model.compile` for a full list of arguments
74+
supported by the compile method.
75+
"""
76+
if optimizer == "auto":
77+
optimizer = keras.optimizers.Adam(5e-5)
78+
if loss == "auto":
79+
activation = getattr(self, "activation", None)
80+
activation = keras.activations.get(activation)
81+
from_logits = activation != keras.activations.softmax
82+
loss = keras.losses.SparseCategoricalCrossentropy(from_logits)
83+
if metrics == "auto":
84+
metrics = [keras.metrics.SparseCategoricalAccuracy()]
85+
super().compile(
86+
optimizer=optimizer,
87+
loss=loss,
88+
metrics=metrics,
89+
**kwargs,
90+
)

keras_nlp/src/models/vgg/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright 2023 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import keras
15+
from keras import layers
16+
17+
from keras_nlp.src.api_export import keras_nlp_export
18+
from keras_nlp.src.models.backbone import Backbone
19+
20+
21+
@keras_nlp_export("keras_nlp.models.VGGBackbone")
22+
class VGGBackbone(Backbone):
23+
"""
24+
This class represents Keras Backbone of VGG model.
25+
26+
This class implements a VGG backbone as described in [Very Deep
27+
Convolutional Networks for Large-Scale Image Recognition](
28+
https://arxiv.org/abs/1409.1556)(ICLR 2015).
29+
30+
Args:
31+
stackwise_num_repeats: list of ints, number of repeated convolutional
32+
blocks per VGG block. For VGG16 this is [2, 2, 3, 3, 3] and for
33+
VGG19 this is [2, 2, 4, 4, 4].
34+
stackwise_num_filters: list of ints, filter size for convolutional
35+
blocks per VGG block. For both VGG16 and VGG19 this is [
36+
64, 128, 256, 512, 512].
37+
include_rescaling: bool, whether to rescale the inputs. If set to
38+
True, inputs will be passed through a `Rescaling(1/255.0)` layer.
39+
input_shape: tuple, optional shape tuple, defaults to (224, 224, 3).
40+
pooling: bool, Optional pooling mode for feature extraction
41+
when `include_top` is `False`.
42+
- `None` means that the output of the model will be
43+
the 4D tensor output of the
44+
last convolutional block.
45+
- `avg` means that global average pooling
46+
will be applied to the output of the
47+
last convolutional block, and thus
48+
the output of the model will be a 2D tensor.
49+
- `max` means that global max pooling will
50+
be applied.
51+
52+
Examples:
53+
```python
54+
input_data = np.ones((2, 224, 224, 3), dtype="float32")
55+
56+
# Pretrained VGG backbone.
57+
model = keras_nlp.models.VGGBackbone.from_preset("vgg16")
58+
model(input_data)
59+
60+
# Randomly initialized VGG backbone with a custom config.
61+
model = keras_nlp.models.VGGBackbone(
62+
stackwise_num_repeats = [2, 2, 3, 3, 3],
63+
stackwise_num_filters = [64, 128, 256, 512, 512],
64+
input_shape = (224, 224, 3),
65+
include_rescaling = False,
66+
pooling = "avg",
67+
)
68+
model(input_data)
69+
```
70+
"""
71+
72+
def __init__(
73+
self,
74+
stackwise_num_repeats,
75+
stackwise_num_filters,
76+
include_rescaling,
77+
input_image_shape=(224, 224, 3),
78+
pooling="avg",
79+
**kwargs,
80+
):
81+
82+
# === Functional Model ===
83+
img_input = keras.layers.Input(shape=input_image_shape)
84+
x = img_input
85+
86+
if include_rescaling:
87+
x = layers.Rescaling(scale=1 / 255.0)(x)
88+
for stack_index in range(len(stackwise_num_repeats) - 1):
89+
x = apply_vgg_block(
90+
x=x,
91+
num_layers=stackwise_num_repeats[stack_index],
92+
filters=stackwise_num_filters[stack_index],
93+
kernel_size=(3, 3),
94+
activation="relu",
95+
padding="same",
96+
max_pool=True,
97+
name=f"block{stack_index + 1}",
98+
)
99+
if pooling == "avg":
100+
x = layers.GlobalAveragePooling2D()(x)
101+
elif pooling == "max":
102+
x = layers.GlobalMaxPooling2D()(x)
103+
104+
super().__init__(inputs=img_input, outputs=x, **kwargs)
105+
106+
# === Config ===
107+
self.stackwise_num_repeats = stackwise_num_repeats
108+
self.stackwise_num_filters = stackwise_num_filters
109+
self.include_rescaling = include_rescaling
110+
self.input_image_shape = input_image_shape
111+
self.pooling = pooling
112+
113+
def get_config(self):
114+
return {
115+
"stackwise_num_repeats": self.stackwise_num_repeats,
116+
"stackwise_num_filters": self.stackwise_num_filters,
117+
"include_rescaling": self.include_rescaling,
118+
"input_image_shape": self.input_image_shape,
119+
"pooling": self.pooling,
120+
}
121+
122+
123+
def apply_vgg_block(
124+
x,
125+
num_layers,
126+
filters,
127+
kernel_size,
128+
activation,
129+
padding,
130+
max_pool,
131+
name,
132+
):
133+
"""
134+
Applies VGG block
135+
Args:
136+
x: Tensor, input tensor to pass through network
137+
num_layers: int, number of CNN layers in the block
138+
filters: int, filter size of each CNN layer in block
139+
kernel_size: int (or) tuple, kernel size for CNN layer in block
140+
activation: str (or) callable, activation function for each CNN layer in
141+
block
142+
padding: str (or) callable, padding function for each CNN layer in block
143+
max_pool: bool, whether to add MaxPooling2D layer at end of block
144+
name: str, name of the block
145+
146+
Returns:
147+
keras.KerasTensor
148+
"""
149+
for num in range(1, num_layers + 1):
150+
x = layers.Conv2D(
151+
filters,
152+
kernel_size,
153+
activation=activation,
154+
padding=padding,
155+
name=f"{name}_conv{num}",
156+
)(x)
157+
if max_pool:
158+
x = layers.MaxPooling2D((2, 2), (2, 2), name=f"{name}_pool")(x)
159+
return x
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2023 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pytest
17+
18+
from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone
19+
from keras_nlp.src.tests.test_case import TestCase
20+
21+
22+
class VGGBackboneTest(TestCase):
23+
def setUp(self):
24+
self.init_kwargs = {
25+
"stackwise_num_repeats": [2, 3, 3],
26+
"stackwise_num_filters": [8, 64, 64],
27+
"input_image_shape": (16, 16, 3),
28+
"include_rescaling": False,
29+
"pooling": "avg",
30+
}
31+
self.input_data = np.ones((2, 16, 16, 3), dtype="float32")
32+
33+
def test_backbone_basics(self):
34+
self.run_backbone_test(
35+
cls=VGGBackbone,
36+
init_kwargs=self.init_kwargs,
37+
input_data=self.input_data,
38+
expected_output_shape=(2, 64),
39+
run_mixed_precision_check=False,
40+
)
41+
42+
@pytest.mark.large
43+
def test_saved_model(self):
44+
self.run_model_saving_test(
45+
cls=VGGBackbone,
46+
init_kwargs=self.init_kwargs,
47+
input_data=self.input_data,
48+
)

0 commit comments

Comments
 (0)