From 103452268515b77d385dc158de59b74a2eb44964 Mon Sep 17 00:00:00 2001 From: Gowtham Paimagam Date: Mon, 19 Aug 2024 20:00:19 +0200 Subject: [PATCH 1/4] Add ResNet_vd to ResNet backbone --- .../src/models/resnet/resnet_backbone.py | 148 +++++++++++++++--- .../src/models/resnet/resnet_backbone_test.py | 33 ++-- 2 files changed, 148 insertions(+), 33 deletions(-) diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index 0f4d7c139a..a7ee625928 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -27,9 +27,10 @@ class ResNetBackbone(FeaturePyramidBackbone): This class implements a ResNet backbone as described in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)( CVPR 2016), [Identity Mappings in Deep Residual Networks]( - https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An + https://arxiv.org/abs/1603.05027)(ECCV 2016), [ResNet strikes back: An improved training procedure in timm](https://arxiv.org/abs/2110.00476)( - NeurIPS 2021 Workshop). + NeurIPS 2021 Workshop) and [Bag of Tricks for Image Classification with + Convolutional Neural Networks](https://arxiv.org/abs/1812.01187). The difference in ResNet and ResNetV2 rests in the structure of their individual building blocks. In ResNetV2, the batch normalization and @@ -37,6 +38,12 @@ class ResNetBackbone(FeaturePyramidBackbone): the batch normalization and ReLU activation are applied after the convolution layers. + ResNetVd introduces two key modifications to the standard ResNet. First, + the initial convolutional layer is replaced by a series of three + successive convolutional layers. Second, shortcut connections use an + additional pooling operation rather than performing downsampling within + the convolutional layers themselves. + Note that `ResNetBackbone` expects the inputs to be images with a value range of `[0, 255]` when `include_rescaling=True`. @@ -51,6 +58,7 @@ class ResNetBackbone(FeaturePyramidBackbone): Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. + use_vd_pooling: boolean. Whether to use ResNetVd enhancements. include_rescaling: boolean. If `True`, rescale the input using `Rescaling` and `Normalization` layers. If `False`, do nothing. Defaults to `True`. @@ -106,6 +114,7 @@ def __init__( stackwise_num_strides, block_type, use_pre_activation=False, + use_vd_pooling=False, include_rescaling=True, input_image_shape=(None, None, 3), pooling="avg", @@ -133,7 +142,12 @@ def __init__( '`block_type` must be either `"basic_block"` or ' f'`"bottleneck_block"`. Received block_type={block_type}.' ) - version = "v1" if not use_pre_activation else "v2" + if use_vd_pooling: + version = "vd" + elif use_pre_activation: + version = "v2" + else: + version = "v1" data_format = standardize_data_format(data_format) bn_axis = -1 if data_format == "channels_last" else 1 num_stacks = len(stackwise_num_filters) @@ -155,21 +169,21 @@ def __init__( # The padding between torch and tensorflow/jax differs when `strides>1`. # Therefore, we need to manually pad the tensor. x = layers.ZeroPadding2D( - 3, + 1 if use_vd_pooling else 3, data_format=data_format, dtype=dtype, name="conv1_pad", )(x) - x = layers.Conv2D( - 64, - 7, - strides=2, - data_format=data_format, - use_bias=False, - dtype=dtype, - name="conv1_conv", - )(x) - if not use_pre_activation: + if use_vd_pooling: + x = layers.Conv2D( + 32, + 3, + strides=2, + data_format=data_format, + use_bias=False, + dtype=dtype, + name="conv1_conv", + )(x) x = layers.BatchNormalization( axis=bn_axis, epsilon=1e-5, @@ -178,6 +192,57 @@ def __init__( name="conv1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x) + x = layers.Conv2D( + 32, + 3, + strides=1, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name="conv2_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name="conv2_bn", + )(x) + x = layers.Activation("relu", dtype=dtype, name="conv2_relu")(x) + x = layers.Conv2D( + 64, + 3, + strides=1, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name="conv3_conv", + )(x) + else: + x = layers.Conv2D( + 64, + 7, + strides=2, + data_format=data_format, + use_bias=False, + dtype=dtype, + name="conv1_conv", + )(x) + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name="conv3_bn" if use_vd_pooling else "conv1_bn", + )(x) + x = layers.Activation( + "relu", + dtype=dtype, + name="conv3_relu" if use_vd_pooling else "conv1_relu", + )(x) if use_pre_activation: # A workaround for ResNetV2: we need -inf padding to prevent zeros @@ -210,8 +275,11 @@ def __init__( stride=stackwise_num_strides[stack_index], block_type=block_type, use_pre_activation=use_pre_activation, + use_vd_pooling=use_vd_pooling, first_shortcut=( - block_type == "bottleneck_block" or stack_index > 0 + block_type == "bottleneck_block" + or stack_index > 0 + or use_vd_pooling ), data_format=data_format, dtype=dtype, @@ -253,6 +321,7 @@ def __init__( self.stackwise_num_strides = stackwise_num_strides self.block_type = block_type self.use_pre_activation = use_pre_activation + self.use_vd_pooling = use_vd_pooling self.include_rescaling = include_rescaling self.input_image_shape = input_image_shape self.pooling = pooling @@ -267,6 +336,7 @@ def get_config(self): "stackwise_num_strides": self.stackwise_num_strides, "block_type": self.block_type, "use_pre_activation": self.use_pre_activation, + "use_vd_pooling": self.use_vd_pooling, "include_rescaling": self.include_rescaling, "input_image_shape": self.input_image_shape, "pooling": self.pooling, @@ -282,6 +352,7 @@ def apply_basic_block( stride=1, conv_shortcut=False, use_pre_activation=False, + use_vd_pooling=False, data_format=None, dtype=None, name=None, @@ -299,6 +370,7 @@ def apply_basic_block( `False`. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + use_vd_pooling: boolean. Whether to use ResNetVd enhancements. data_format: `None` or str. the ordering of the dimensions in the inputs. Can be `"channels_last"` (`(batch_size, height, width, channels)`) or`"channels_first"` @@ -327,16 +399,27 @@ def apply_basic_block( )(x_preact) if conv_shortcut: - x = x_preact if x_preact is not None else x + if x_preact is not None: + shortcut = x_preact + elif use_vd_pooling and stride > 1: + shortcut = layers.AveragePooling2D( + 2, + strides=stride, + data_format=data_format, + dtype=dtype, + padding="same", + )(x) + else: + shortcut = x shortcut = layers.Conv2D( filters, 1, - strides=stride, + strides=1 if use_vd_pooling else stride, data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x) + )(shortcut) if not use_pre_activation: shortcut = layers.BatchNormalization( axis=bn_axis, @@ -407,6 +490,7 @@ def apply_bottleneck_block( stride=1, conv_shortcut=False, use_pre_activation=False, + use_vd_pooling=False, data_format=None, dtype=None, name=None, @@ -424,6 +508,7 @@ def apply_bottleneck_block( `False`. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + use_vd_pooling: boolean. Whether to use ResNetVd enhancements. data_format: `None` or str. the ordering of the dimensions in the inputs. Can be `"channels_last"` (`(batch_size, height, width, channels)`) or`"channels_first"` @@ -452,16 +537,27 @@ def apply_bottleneck_block( )(x_preact) if conv_shortcut: - x = x_preact if x_preact is not None else x + if x_preact is not None: + shortcut = x_preact + elif use_vd_pooling and stride > 1: + shortcut = layers.AveragePooling2D( + 2, + strides=stride, + data_format=data_format, + dtype=dtype, + padding="same", + )(x) + else: + shortcut = x shortcut = layers.Conv2D( 4 * filters, 1, - strides=stride, + strides=1 if use_vd_pooling else stride, data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x) + )(shortcut) if not use_pre_activation: shortcut = layers.BatchNormalization( axis=bn_axis, @@ -548,6 +644,7 @@ def apply_stack( stride, block_type, use_pre_activation, + use_vd_pooling=False, first_shortcut=True, data_format=None, dtype=None, @@ -565,6 +662,7 @@ def apply_stack( Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet and ResNeXt. + use_vd_pooling: boolean. Whether to use ResNetVd enhancements. first_shortcut: bool. If `True`, use a convolution shortcut. If `False`, use an identity or pooling shortcut based on the stride. Defaults to `True`. @@ -580,7 +678,12 @@ def apply_stack( Output tensor for the stacked blocks. """ if name is None: - version = "v1" if not use_pre_activation else "v2" + if use_vd_pooling: + version = "vd" + elif use_pre_activation: + version = "v2" + else: + version = "v1" name = f"{version}_stack" if block_type == "basic_block": @@ -605,6 +708,7 @@ def apply_stack( stride=stride, conv_shortcut=conv_shortcut, use_pre_activation=use_pre_activation, + use_vd_pooling=use_vd_pooling, data_format=data_format, dtype=dtype, name=f"{name}_block{str(i)}", diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index 6d3f774559..5bcd79f22f 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -34,15 +34,23 @@ def setUp(self): self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) @parameterized.named_parameters( - ("v1_basic", False, "basic_block"), - ("v1_bottleneck", False, "bottleneck_block"), - ("v2_basic", True, "basic_block"), - ("v2_bottleneck", True, "bottleneck_block"), + ("v1_basic", False, False, "basic_block"), + ("v1_bottleneck", False, False, "bottleneck_block"), + ("v2_basic", True, False, "basic_block"), + ("v2_bottleneck", True, False, "bottleneck_block"), + ("vd_basic", False, True, "basic_block"), + ("vd_bottleneck", False, True, "bottleneck_block"), ) - def test_backbone_basics(self, use_pre_activation, block_type): + def test_backbone_basics( + self, use_pre_activation, use_vd_pooling, block_type + ): init_kwargs = self.init_kwargs.copy() init_kwargs.update( - {"block_type": block_type, "use_pre_activation": use_pre_activation} + { + "block_type": block_type, + "use_pre_activation": use_pre_activation, + "use_vd_pooling": use_vd_pooling, + } ) self.run_vision_backbone_test( cls=ResNetBackbone, @@ -72,18 +80,21 @@ def test_pyramid_output_format(self): self.assertEqual(tuple(v.shape[:3]), (2, size, size)) @parameterized.named_parameters( - ("v1_basic", False, "basic_block"), - ("v1_bottleneck", False, "bottleneck_block"), - ("v2_basic", True, "basic_block"), - ("v2_bottleneck", True, "bottleneck_block"), + ("v1_basic", False, False, "basic_block"), + ("v1_bottleneck", False, False, "bottleneck_block"), + ("v2_basic", True, False, "basic_block"), + ("v2_bottleneck", True, False, "bottleneck_block"), + ("vd_basic", False, True, "basic_block"), + ("vd_bottleneck", False, True, "bottleneck_block"), ) @pytest.mark.large - def test_saved_model(self, use_pre_activation, block_type): + def test_saved_model(self, use_pre_activation, use_vd_pooling, block_type): init_kwargs = self.init_kwargs.copy() init_kwargs.update( { "block_type": block_type, "use_pre_activation": use_pre_activation, + "use_vd_pooling": use_vd_pooling, "input_image_shape": (None, None, 3), } ) From 0fe8acc434daf8560199c0a28eaeab7f1acd030f Mon Sep 17 00:00:00 2001 From: Gowtham Paimagam Date: Sat, 24 Aug 2024 18:56:57 +0200 Subject: [PATCH 2/4] Addressed requested parameter changes --- .../src/models/resnet/resnet_backbone.py | 400 ++++++++++++++---- 1 file changed, 315 insertions(+), 85 deletions(-) diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index a7ee625928..1f57770619 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -58,7 +58,6 @@ class ResNetBackbone(FeaturePyramidBackbone): Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. - use_vd_pooling: boolean. Whether to use ResNetVd enhancements. include_rescaling: boolean. If `True`, rescale the input using `Rescaling` and `Normalization` layers. If `False`, do nothing. Defaults to `True`. @@ -109,12 +108,13 @@ class ResNetBackbone(FeaturePyramidBackbone): def __init__( self, + input_conv_filters, + input_conv_kernel_sizes, stackwise_num_filters, stackwise_num_blocks, stackwise_num_strides, block_type, use_pre_activation=False, - use_vd_pooling=False, include_rescaling=True, input_image_shape=(None, None, 3), pooling="avg", @@ -137,19 +137,15 @@ def __init__( "The first element of `stackwise_num_filters` must be 64. " f"Received: stackwise_num_filters={stackwise_num_filters}" ) - if block_type not in ("basic_block", "bottleneck_block"): + if block_type not in ("basic_block", "bottleneck_block", "basic_block_vd", "bottleneck_block_vd"): raise ValueError( - '`block_type` must be either `"basic_block"` or ' - f'`"bottleneck_block"`. Received block_type={block_type}.' + '`block_type` must be either `"basic_block"`, ' + '`"bottleneck_block"`, `"basic_block_vd"` or ' + f'`"bottleneck_block_vd"`. Received block_type={block_type}.' ) - if use_vd_pooling: - version = "vd" - elif use_pre_activation: - version = "v2" - else: - version = "v1" data_format = standardize_data_format(data_format) bn_axis = -1 if data_format == "channels_last" else 1 + num_input_convs = len(input_conv_filters) num_stacks = len(stackwise_num_filters) # === Functional Model === @@ -169,79 +165,53 @@ def __init__( # The padding between torch and tensorflow/jax differs when `strides>1`. # Therefore, we need to manually pad the tensor. x = layers.ZeroPadding2D( - 1 if use_vd_pooling else 3, + (input_conv_kernel_sizes[0] - 1) // 2, data_format=data_format, dtype=dtype, name="conv1_pad", )(x) - if use_vd_pooling: - x = layers.Conv2D( - 32, - 3, - strides=2, - data_format=data_format, - use_bias=False, - dtype=dtype, - name="conv1_conv", - )(x) - x = layers.BatchNormalization( - axis=bn_axis, - epsilon=1e-5, - momentum=0.9, - dtype=dtype, - name="conv1_bn", - )(x) - x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x) - x = layers.Conv2D( - 32, - 3, - strides=1, - padding="same", - data_format=data_format, - use_bias=False, - dtype=dtype, - name="conv2_conv", - )(x) + x = layers.Conv2D( + input_conv_filters[0], + input_conv_kernel_sizes[0], + strides=2, + data_format=data_format, + use_bias=False, + padding="valid", + dtype=dtype, + name="conv1_conv", + )(x) + for conv_index in range(1,num_input_convs): x = layers.BatchNormalization( axis=bn_axis, epsilon=1e-5, momentum=0.9, dtype=dtype, - name="conv2_bn", + name=f"conv{conv_index}_bn", )(x) - x = layers.Activation("relu", dtype=dtype, name="conv2_relu")(x) + x = layers.Activation("relu", dtype=dtype, name=f"conv{conv_index}_relu")(x) x = layers.Conv2D( - 64, - 3, + input_conv_filters[conv_index], + input_conv_kernel_sizes[conv_index], strides=1, - padding="same", data_format=data_format, use_bias=False, + padding=1, dtype=dtype, - name="conv3_conv", - )(x) - else: - x = layers.Conv2D( - 64, - 7, - strides=2, - data_format=data_format, - use_bias=False, - dtype=dtype, - name="conv1_conv", + name=f"conv{conv_index+1}_conv", )(x) + if not use_pre_activation: x = layers.BatchNormalization( axis=bn_axis, epsilon=1e-5, momentum=0.9, dtype=dtype, - name="conv3_bn" if use_vd_pooling else "conv1_bn", + name=f"conv{num_input_convs}_bn", )(x) x = layers.Activation( "relu", dtype=dtype, - name="conv3_relu" if use_vd_pooling else "conv1_relu", + name=f"conv{num_input_convs}_relu", )(x) if use_pre_activation: @@ -275,15 +245,13 @@ def __init__( stride=stackwise_num_strides[stack_index], block_type=block_type, use_pre_activation=use_pre_activation, - use_vd_pooling=use_vd_pooling, first_shortcut=( - block_type == "bottleneck_block" + block_type != "basic_block" or stack_index > 0 - or use_vd_pooling ), data_format=data_format, dtype=dtype, - name=f"{version}_stack{stack_index}", + name=f"stack{stack_index}", ) pyramid_outputs[f"P{stack_index + 2}"] = x @@ -321,7 +289,6 @@ def __init__( self.stackwise_num_strides = stackwise_num_strides self.block_type = block_type self.use_pre_activation = use_pre_activation - self.use_vd_pooling = use_vd_pooling self.include_rescaling = include_rescaling self.input_image_shape = input_image_shape self.pooling = pooling @@ -336,7 +303,6 @@ def get_config(self): "stackwise_num_strides": self.stackwise_num_strides, "block_type": self.block_type, "use_pre_activation": self.use_pre_activation, - "use_vd_pooling": self.use_vd_pooling, "include_rescaling": self.include_rescaling, "input_image_shape": self.input_image_shape, "pooling": self.pooling, @@ -352,7 +318,6 @@ def apply_basic_block( stride=1, conv_shortcut=False, use_pre_activation=False, - use_vd_pooling=False, data_format=None, dtype=None, name=None, @@ -370,7 +335,6 @@ def apply_basic_block( `False`. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. Defaults to `False`. - use_vd_pooling: boolean. Whether to use ResNetVd enhancements. data_format: `None` or str. the ordering of the dimensions in the inputs. Can be `"channels_last"` (`(batch_size, height, width, channels)`) or`"channels_first"` @@ -401,7 +365,279 @@ def apply_basic_block( if conv_shortcut: if x_preact is not None: shortcut = x_preact - elif use_vd_pooling and stride > 1: + else: + shortcut = x + shortcut = layers.Conv2D( + filters, + 1, + strides=stride, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_0_conv", + )(shortcut) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", + )(shortcut) + else: + shortcut = x + + x = x_preact if x_preact is not None else x + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_1_pad", + )(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=stride, + padding="valid" if stride > 1 else "same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + + x = layers.Conv2D( + filters, + kernel_size, + strides=1, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", + )(x) + x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_bottleneck_block( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a bottleneck residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the residual block. + """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_pre_activation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" + )(x_preact) + + if conv_shortcut: + if x_preact is not None: + shortcut = x_preact + else: + shortcut = x + shortcut = layers.Conv2D( + 4 * filters, + 1, + strides=stride, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_0_conv", + )(shortcut) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", + )(shortcut) + else: + shortcut = x + + x = x_preact if x_preact is not None else x + x = layers.Conv2D( + filters, + 1, + strides=1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_2_pad", + )(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=stride, + padding="valid" if stride > 1 else "same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x) + + x = layers.Conv2D( + 4 * filters, + 1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_3_conv", + )(x) + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_3_bn", + )(x) + x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_basic_block_vd( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a basic residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the basic residual block. + """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_pre_activation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" + )(x_preact) + + if conv_shortcut: + if x_preact is not None: + shortcut = x_preact + elif stride > 1: shortcut = layers.AveragePooling2D( 2, strides=stride, @@ -414,7 +650,7 @@ def apply_basic_block( shortcut = layers.Conv2D( filters, 1, - strides=1 if use_vd_pooling else stride, + strides=1, data_format=data_format, use_bias=False, dtype=dtype, @@ -483,14 +719,13 @@ def apply_basic_block( return x -def apply_bottleneck_block( +def apply_bottleneck_block_vd( x, filters, kernel_size=3, stride=1, conv_shortcut=False, use_pre_activation=False, - use_vd_pooling=False, data_format=None, dtype=None, name=None, @@ -508,7 +743,6 @@ def apply_bottleneck_block( `False`. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. Defaults to `False`. - use_vd_pooling: boolean. Whether to use ResNetVd enhancements. data_format: `None` or str. the ordering of the dimensions in the inputs. Can be `"channels_last"` (`(batch_size, height, width, channels)`) or`"channels_first"` @@ -539,7 +773,7 @@ def apply_bottleneck_block( if conv_shortcut: if x_preact is not None: shortcut = x_preact - elif use_vd_pooling and stride > 1: + elif stride > 1: shortcut = layers.AveragePooling2D( 2, strides=stride, @@ -552,7 +786,7 @@ def apply_bottleneck_block( shortcut = layers.Conv2D( 4 * filters, 1, - strides=1 if use_vd_pooling else stride, + strides=1, data_format=data_format, use_bias=False, dtype=dtype, @@ -644,7 +878,6 @@ def apply_stack( stride, block_type, use_pre_activation, - use_vd_pooling=False, first_shortcut=True, data_format=None, dtype=None, @@ -662,7 +895,6 @@ def apply_stack( Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet and ResNeXt. - use_vd_pooling: boolean. Whether to use ResNetVd enhancements. first_shortcut: bool. If `True`, use a convolution shortcut. If `False`, use an identity or pooling shortcut based on the stride. Defaults to `True`. @@ -678,22 +910,21 @@ def apply_stack( Output tensor for the stacked blocks. """ if name is None: - if use_vd_pooling: - version = "vd" - elif use_pre_activation: - version = "v2" - else: - version = "v1" - name = f"{version}_stack" + name = "stack" if block_type == "basic_block": block_fn = apply_basic_block elif block_type == "bottleneck_block": block_fn = apply_bottleneck_block + if block_type == "basic_block_vd": + block_fn = apply_basic_block_vd + elif block_type == "bottleneck_block_vd": + block_fn = apply_bottleneck_block_vd else: raise ValueError( - '`block_type` must be either `"basic_block"` or ' - f'`"bottleneck_block"`. Received block_type={block_type}.' + '`block_type` must be either `"basic_block"`, ' + '`"bottleneck_block"`, `"basic_block_vd"` or ' + f'`"bottleneck_block_vd"`. Received block_type={block_type}.' ) for i in range(blocks): if i == 0: @@ -708,7 +939,6 @@ def apply_stack( stride=stride, conv_shortcut=conv_shortcut, use_pre_activation=use_pre_activation, - use_vd_pooling=use_vd_pooling, data_format=data_format, dtype=dtype, name=f"{name}_block{str(i)}", From 745766c72d3b26092cc03efd2e862ef7932115cd Mon Sep 17 00:00:00 2001 From: Gowtham Paimagam Date: Sat, 24 Aug 2024 19:29:20 +0200 Subject: [PATCH 3/4] Fixed tests and updated comments --- .../src/models/resnet/resnet_backbone.py | 49 +++++++++++----- .../src/models/resnet/resnet_backbone_test.py | 56 +++++++++++-------- .../resnet/resnet_image_classifier_test.py | 2 + 3 files changed, 71 insertions(+), 36 deletions(-) diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index 7539c39b89..2ad2b517d7 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -54,8 +54,11 @@ class ResNetBackbone(FeaturePyramidBackbone): stackwise_num_strides: list of ints. The number of strides for each stack. block_type: str. The block type to stack. One of `"basic_block"` or - `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. - Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + `"bottleneck_block"`, `"basic_block_vd"` or + `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and + ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and + ResNet152 and the `"_vd"` prefix for the respective ResNet_vd + variants. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. include_rescaling: boolean. If `True`, rescale the input using @@ -95,6 +98,8 @@ class ResNetBackbone(FeaturePyramidBackbone): # Randomly initialized ResNetV2 backbone with a custom config. model = keras_nlp.models.ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], stackwise_num_filters=[64, 64, 64], stackwise_num_blocks=[2, 2, 2], stackwise_num_strides=[1, 2, 2], @@ -122,6 +127,13 @@ def __init__( dtype=None, **kwargs, ): + if len(input_conv_filters) != len(input_conv_kernel_sizes): + raise ValueError( + "The length of `input_conv_filters` and" + "`input_conv_kernel_sizes` must be the same. " + f"Received: input_conv_filters={input_conv_filters}, " + f"input_conv_kernel_sizes={input_conv_kernel_sizes}." + ) if len(stackwise_num_filters) != len(stackwise_num_blocks) or len( stackwise_num_filters ) != len(stackwise_num_strides): @@ -137,7 +149,12 @@ def __init__( "The first element of `stackwise_num_filters` must be 64. " f"Received: stackwise_num_filters={stackwise_num_filters}" ) - if block_type not in ("basic_block", "bottleneck_block", "basic_block_vd", "bottleneck_block_vd"): + if block_type not in ( + "basic_block", + "bottleneck_block", + "basic_block_vd", + "bottleneck_block_vd", + ): raise ValueError( '`block_type` must be either `"basic_block"`, ' '`"bottleneck_block"`, `"basic_block_vd"` or ' @@ -180,7 +197,7 @@ def __init__( dtype=dtype, name="conv1_conv", )(x) - for conv_index in range(1,num_input_convs): + for conv_index in range(1, num_input_convs): x = layers.BatchNormalization( axis=bn_axis, epsilon=1e-5, @@ -188,14 +205,16 @@ def __init__( dtype=dtype, name=f"conv{conv_index}_bn", )(x) - x = layers.Activation("relu", dtype=dtype, name=f"conv{conv_index}_relu")(x) + x = layers.Activation( + "relu", dtype=dtype, name=f"conv{conv_index}_relu" + )(x) x = layers.Conv2D( input_conv_filters[conv_index], input_conv_kernel_sizes[conv_index], strides=1, data_format=data_format, use_bias=False, - padding=1, + padding="same", dtype=dtype, name=f"conv{conv_index+1}_conv", )(x) @@ -245,10 +264,7 @@ def __init__( stride=stackwise_num_strides[stack_index], block_type=block_type, use_pre_activation=use_pre_activation, - first_shortcut=( - block_type != "basic_block" - or stack_index > 0 - ), + first_shortcut=(block_type != "basic_block" or stack_index > 0), data_format=data_format, dtype=dtype, name=f"stack{stack_index}", @@ -284,6 +300,8 @@ def __init__( ) # === Config === + self.input_conv_filters = input_conv_filters + self.input_conv_kernel_sizes = input_conv_kernel_sizes self.stackwise_num_filters = stackwise_num_filters self.stackwise_num_blocks = stackwise_num_blocks self.stackwise_num_strides = stackwise_num_strides @@ -298,6 +316,8 @@ def get_config(self): config = super().get_config() config.update( { + "input_conv_filters": self.input_conv_filters, + "input_conv_kernel_sizes": self.input_conv_kernel_sizes, "stackwise_num_filters": self.stackwise_num_filters, "stackwise_num_blocks": self.stackwise_num_blocks, "stackwise_num_strides": self.stackwise_num_strides, @@ -891,8 +911,11 @@ def apply_stack( blocks: int. The number of blocks in the stack. stride: int. The stride length of the first layer in the first block. block_type: str. The block type to stack. One of `"basic_block"` or - `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. - Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + `"bottleneck_block"`, `"basic_block_vd"` or + `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and + ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and + ResNet152 and the `"_vd"` prefix for the respective ResNet_vd + variants. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet and ResNeXt. first_shortcut: bool. If `True`, use a convolution shortcut. If `False`, @@ -916,7 +939,7 @@ def apply_stack( block_fn = apply_basic_block elif block_type == "bottleneck_block": block_fn = apply_bottleneck_block - if block_type == "basic_block_vd": + elif block_type == "basic_block_vd": block_fn = apply_basic_block_vd elif block_type == "bottleneck_block_vd": block_fn = apply_bottleneck_block_vd diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index 8b738ce8a8..f52800801f 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -24,6 +24,8 @@ class ResNetBackboneTest(TestCase): def setUp(self): self.init_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], "stackwise_num_filters": [64, 64, 64], "stackwise_num_blocks": [2, 2, 2], "stackwise_num_strides": [1, 2, 2], @@ -34,30 +36,36 @@ def setUp(self): self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) @parameterized.named_parameters( - ("v1_basic", False, False, "basic_block"), - ("v1_bottleneck", False, False, "bottleneck_block"), - ("v2_basic", True, False, "basic_block"), - ("v2_bottleneck", True, False, "bottleneck_block"), - ("vd_basic", False, True, "basic_block"), - ("vd_bottleneck", False, True, "bottleneck_block"), + ("v1_basic", False, "basic_block"), + ("v1_bottleneck", False, "bottleneck_block"), + ("v2_basic", True, "basic_block"), + ("v2_bottleneck", True, "bottleneck_block"), + ("vd_basic", False, "basic_block_vd"), + ("vd_bottleneck", False, "bottleneck_block_vd"), ) - def test_backbone_basics( - self, use_pre_activation, use_vd_pooling, block_type - ): + def test_backbone_basics(self, use_pre_activation, block_type): init_kwargs = self.init_kwargs.copy() init_kwargs.update( { "block_type": block_type, "use_pre_activation": use_pre_activation, - "use_vd_pooling": use_vd_pooling, } ) + if block_type in ("basic_block_vd", "bottleneck_block_vd"): + init_kwargs.update( + { + "input_conv_filters": [32, 32, 64], + "input_conv_kernel_sizes": [3, 3, 3], + } + ) self.run_vision_backbone_test( cls=ResNetBackbone, init_kwargs=init_kwargs, input_data=self.input_data, expected_output_shape=( - (2, 64) if block_type == "basic_block" else (2, 256) + (2, 64) + if block_type in ("basic_block", "basic_block_vd") + else (2, 256) ), ) @@ -80,28 +88,30 @@ def test_pyramid_output_format(self): self.assertEqual(tuple(v.shape[:3]), (2, size, size)) @parameterized.named_parameters( - ("v1_basic", False, False, "basic_block"), - ("v1_bottleneck", False, False, "bottleneck_block"), - ("v2_basic", True, False, "basic_block"), - ("v2_bottleneck", True, False, "bottleneck_block"), - ("vd_basic", False, True, "basic_block"), - ("vd_bottleneck", False, True, "bottleneck_block"), + ("v1_basic", False, "basic_block"), + ("v1_bottleneck", False, "bottleneck_block"), + ("v2_basic", True, "basic_block"), + ("v2_bottleneck", True, "bottleneck_block"), + ("vd_basic", False, "basic_block_vd"), + ("vd_bottleneck", False, "bottleneck_block_vd"), ) @pytest.mark.large - def test_saved_model(self, use_pre_activation, use_vd_pooling, block_type): + def test_saved_model(self, use_pre_activation, block_type): init_kwargs = self.init_kwargs.copy() init_kwargs.update( { "block_type": block_type, "use_pre_activation": use_pre_activation, -<<<<<<< HEAD - "use_vd_pooling": use_vd_pooling, - "input_image_shape": (None, None, 3), -======= "image_shape": (None, None, 3), ->>>>>>> upstream/keras-hub } ) + if block_type in ("basic_block_vd", "bottleneck_block_vd"): + init_kwargs.update( + { + "input_conv_filters": [32, 32, 64], + "input_conv_kernel_sizes": [3, 3, 3], + } + ) self.run_model_saving_test( cls=ResNetBackbone, init_kwargs=init_kwargs, diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py index 893ec42487..da06c80320 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -26,6 +26,8 @@ def setUp(self): self.images = ops.ones((2, 16, 16, 3)) self.labels = [0, 3] self.backbone = ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], stackwise_num_filters=[64, 64, 64], stackwise_num_blocks=[2, 2, 2], stackwise_num_strides=[1, 2, 2], From c28a56dd1e17700c9109e90d80ff804ddabd759d Mon Sep 17 00:00:00 2001 From: Gowtham Paimagam Date: Sat, 24 Aug 2024 19:35:18 +0200 Subject: [PATCH 4/4] Added new parameters to docstring --- keras_nlp/src/models/resnet/resnet_backbone.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index 2ad2b517d7..ca1de9b090 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -48,6 +48,10 @@ class ResNetBackbone(FeaturePyramidBackbone): range of `[0, 255]` when `include_rescaling=True`. Args: + input_conv_filters: list of ints. The number of filters of the initial + convolution(s). + input_conv_kernel_sizes: list of ints. The kernel sizes of the initial + convolution(s). stackwise_num_filters: list of ints. The number of filters for each stack. stackwise_num_blocks: list of ints. The number of blocks for each stack.