From 35dc5850cbe4c9d3dde001f13e96deb34bdc67ef Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Thu, 7 Dec 2023 20:11:13 +0000 Subject: [PATCH 1/4] Fix segformer presets --- .../segmentation/segformer/segformer.py | 24 +++++++++++++ .../segmentation/segformer/segformer_test.py | 35 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/keras_cv/models/segmentation/segformer/segformer.py b/keras_cv/models/segmentation/segformer/segformer.py index ccd033bd9f..f770ae4039 100644 --- a/keras_cv/models/segmentation/segformer/segformer.py +++ b/keras_cv/models/segmentation/segformer/segformer.py @@ -16,6 +16,7 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras +from keras_cv.models import MiTBackbone from keras_cv.models.segmentation.segformer.segformer_presets import ( # noqa: E501 presets, ) @@ -165,6 +166,29 @@ def get_config(self): ) return config + @classmethod + def from_preset( + cls, + preset, + num_classes, + load_weights=None, + **kwargs, + ): + aliases = { + "segformer_b0": "mit_b0", + "segformer_b1": "mit_b1", + "segformer_b2": "mit_b2", + "segformer_b3": "mit_b3", + "segformer_b4": "mit_b4", + "segformer_b5": "mit_b5", + } + if preset in aliases: + preset = aliases[preset] + backbone = MiTBackbone.from_preset( + preset, load_weights=load_weights, **kwargs + ) + return cls(backbone=backbone, num_classes=num_classes) + @classproperty def presets(cls): """Dictionary of preset names and configurations.""" diff --git a/keras_cv/models/segmentation/segformer/segformer_test.py b/keras_cv/models/segmentation/segformer/segformer_test.py index b30786ac0b..e78de33d22 100644 --- a/keras_cv/models/segmentation/segformer/segformer_test.py +++ b/keras_cv/models/segmentation/segformer/segformer_test.py @@ -36,6 +36,16 @@ def test_segformer_construction(self): metrics=["accuracy"], ) + def test_segformer_preset_construction(self): + model = SegFormer.from_preset( + "segformer_b0", num_classes=1, input_shape=[512, 512, 3] + ) + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(), + metrics=["accuracy"], + ) + @pytest.mark.large def test_segformer_call(self): backbone = MiTBackbone.from_preset("mit_b0", input_shape=[512, 512, 3]) @@ -94,3 +104,28 @@ def test_saved_model(self): # Check that output matches. restored_output = restored_model(input_batch) self.assertAllClose(model_output, restored_output) + + @pytest.mark.large # Saving is slow, so mark these large. + def test_preset_saved_model(self): + target_size = [512, 512, 3] + + model = SegFormer.from_preset( + "segformer_b0", num_classes=1, input_shape=[512, 512, 3] + ) + + input_batch = np.ones(shape=[2] + target_size) + model_output = model(input_batch) + + save_path = os.path.join(self.get_temp_dir(), "model.keras") + if keras_3(): + model.save(save_path) + else: + model.save(save_path, save_format="keras_v3") + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, SegFormer) + + # Check that output matches. + restored_output = restored_model(input_batch) + self.assertAllClose(model_output, restored_output) From 0fa84bca22b0ed4e97b1ed6f5ee807d630163969 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Thu, 7 Dec 2023 21:56:47 +0000 Subject: [PATCH 2/4] Add extra SegFormer tests --- .../segmentation/segformer/segformer_test.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/keras_cv/models/segmentation/segformer/segformer_test.py b/keras_cv/models/segmentation/segformer/segformer_test.py index e78de33d22..e8bbd332a6 100644 --- a/keras_cv/models/segmentation/segformer/segformer_test.py +++ b/keras_cv/models/segmentation/segformer/segformer_test.py @@ -46,13 +46,27 @@ def test_segformer_preset_construction(self): metrics=["accuracy"], ) + def test_segformer_preset_error(self): + with self.assertRaises(TypeError): + model = SegFormer.from_preset("segformer_b0") + @pytest.mark.large def test_segformer_call(self): backbone = MiTBackbone.from_preset("mit_b0", input_shape=[512, 512, 3]) - model = SegFormer(backbone=backbone, num_classes=1) + mit_model = SegFormer(backbone=backbone, num_classes=1) + images = np.random.uniform(size=(2, 512, 512, 3)) - _ = model(images) - _ = model.predict(images) + mit_output = mit_model(images) + mit_pred = mit_model.predict(images) + + seg_model = SegFormer.from_preset( + "segformer_b0", num_classes=1, input_shape=[512, 512, 3] + ) + seg_output = seg_model(images) + seg_pred = seg_model.predict(images) + + self.assertAllClose(mit_output, seg_output) + self.assertAllClose(mit_pred, seg_pred) @pytest.mark.large def test_weights_change(self): From c905b195b253f36fea22fe99eff24564ec409edf Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Thu, 7 Dec 2023 22:33:21 +0000 Subject: [PATCH 3/4] Change flow, add backbone presets to SegFormer presets --- .../models/segmentation/segformer/segformer.py | 14 +++++++++++--- .../segmentation/segformer/segformer_presets.py | 1 + keras_cv/models/task.py | 7 ++++++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/keras_cv/models/segmentation/segformer/segformer.py b/keras_cv/models/segmentation/segformer/segformer.py index f770ae4039..59e4e38f56 100644 --- a/keras_cv/models/segmentation/segformer/segformer.py +++ b/keras_cv/models/segmentation/segformer/segformer.py @@ -172,6 +172,7 @@ def from_preset( preset, num_classes, load_weights=None, + input_shape=None, **kwargs, ): aliases = { @@ -184,10 +185,13 @@ def from_preset( } if preset in aliases: preset = aliases[preset] - backbone = MiTBackbone.from_preset( - preset, load_weights=load_weights, **kwargs + return super().from_preset( + preset, + load_weights=load_weights, + num_classes=num_classes, + input_shape=input_shape, + **kwargs, ) - return cls(backbone=backbone, num_classes=num_classes) @classproperty def presets(cls): @@ -199,3 +203,7 @@ def presets_with_weights(cls): """Dictionary of preset names and configurations that include weights.""" return copy.deepcopy(presets_with_weights) + + @classproperty + def backbone_presets(cls): + return copy.deepcopy(MiTBackbone.presets) diff --git a/keras_cv/models/segmentation/segformer/segformer_presets.py b/keras_cv/models/segmentation/segformer/segformer_presets.py index e19e2ec9ba..6f01a82d4d 100644 --- a/keras_cv/models/segmentation/segformer/segformer_presets.py +++ b/keras_cv/models/segmentation/segformer/segformer_presets.py @@ -100,6 +100,7 @@ } presets = { + **backbone_presets, # Add MiTBackbone presets **presets_no_weights, **presets_with_weights, } diff --git a/keras_cv/models/task.py b/keras_cv/models/task.py index 906b02d3ad..684a479778 100644 --- a/keras_cv/models/task.py +++ b/keras_cv/models/task.py @@ -137,7 +137,12 @@ def from_preset( backbone_cls = keras.saving.get_registered_object( metadata["class_name"] ) - backbone = backbone_cls.from_preset(preset, load_weights) + backbone_kwargs = {} + if input_shape is not None: + backbone_kwargs["input_shape"] = input_shape + backbone = backbone_cls.from_preset( + preset, load_weights, **backbone_kwargs + ) return cls(backbone, **kwargs) # Otherwise must be one of class presets From c7020bee25d501ea19004e7e782161351dcb6aea Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Thu, 7 Dec 2023 22:59:23 +0000 Subject: [PATCH 4/4] Fix small formatting issue --- keras_cv/models/segmentation/segformer/segformer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/models/segmentation/segformer/segformer_test.py b/keras_cv/models/segmentation/segformer/segformer_test.py index e8bbd332a6..233df20b9c 100644 --- a/keras_cv/models/segmentation/segformer/segformer_test.py +++ b/keras_cv/models/segmentation/segformer/segformer_test.py @@ -48,7 +48,7 @@ def test_segformer_preset_construction(self): def test_segformer_preset_error(self): with self.assertRaises(TypeError): - model = SegFormer.from_preset("segformer_b0") + _ = SegFormer.from_preset("segformer_b0") @pytest.mark.large def test_segformer_call(self):