Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions keras_cv/models/segmentation/segformer/segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.models import MiTBackbone
from keras_cv.models.segmentation.segformer.segformer_presets import ( # noqa: E501
presets,
)
Expand Down Expand Up @@ -165,6 +166,33 @@ def get_config(self):
)
return config

@classmethod
def from_preset(
cls,
preset,
num_classes,
load_weights=None,
input_shape=None,
**kwargs,
):
aliases = {
"segformer_b0": "mit_b0",
"segformer_b1": "mit_b1",
"segformer_b2": "mit_b2",
"segformer_b3": "mit_b3",
"segformer_b4": "mit_b4",
"segformer_b5": "mit_b5",
}
if preset in aliases:
preset = aliases[preset]
return super().from_preset(
preset,
load_weights=load_weights,
num_classes=num_classes,
input_shape=input_shape,
**kwargs,
)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
Expand All @@ -175,3 +203,7 @@ def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return copy.deepcopy(presets_with_weights)

@classproperty
def backbone_presets(cls):
return copy.deepcopy(MiTBackbone.presets)
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
}

presets = {
**backbone_presets, # Add MiTBackbone presets
**presets_no_weights,
**presets_with_weights,
}
55 changes: 52 additions & 3 deletions keras_cv/models/segmentation/segformer/segformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,37 @@ def test_segformer_construction(self):
metrics=["accuracy"],
)

def test_segformer_preset_construction(self):
model = SegFormer.from_preset(
"segformer_b0", num_classes=1, input_shape=[512, 512, 3]
)
model.compile(
optimizer="adam",
loss=keras.losses.BinaryCrossentropy(),
metrics=["accuracy"],
)

def test_segformer_preset_error(self):
with self.assertRaises(TypeError):
_ = SegFormer.from_preset("segformer_b0")

@pytest.mark.large
def test_segformer_call(self):
backbone = MiTBackbone.from_preset("mit_b0", input_shape=[512, 512, 3])
model = SegFormer(backbone=backbone, num_classes=1)
mit_model = SegFormer(backbone=backbone, num_classes=1)

images = np.random.uniform(size=(2, 512, 512, 3))
_ = model(images)
_ = model.predict(images)
mit_output = mit_model(images)
mit_pred = mit_model.predict(images)

seg_model = SegFormer.from_preset(
"segformer_b0", num_classes=1, input_shape=[512, 512, 3]
)
seg_output = seg_model(images)
seg_pred = seg_model.predict(images)

self.assertAllClose(mit_output, seg_output)
self.assertAllClose(mit_pred, seg_pred)

@pytest.mark.large
def test_weights_change(self):
Expand Down Expand Up @@ -94,3 +118,28 @@ def test_saved_model(self):
# Check that output matches.
restored_output = restored_model(input_batch)
self.assertAllClose(model_output, restored_output)

@pytest.mark.large # Saving is slow, so mark these large.
def test_preset_saved_model(self):
target_size = [512, 512, 3]

model = SegFormer.from_preset(
"segformer_b0", num_classes=1, input_shape=[512, 512, 3]
)

input_batch = np.ones(shape=[2] + target_size)
model_output = model(input_batch)

save_path = os.path.join(self.get_temp_dir(), "model.keras")
if keras_3():
model.save(save_path)
else:
model.save(save_path, save_format="keras_v3")
restored_model = keras.models.load_model(save_path)

# Check we got the real object back.
self.assertIsInstance(restored_model, SegFormer)

# Check that output matches.
restored_output = restored_model(input_batch)
self.assertAllClose(model_output, restored_output)
7 changes: 6 additions & 1 deletion keras_cv/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,12 @@ def from_preset(
backbone_cls = keras.saving.get_registered_object(
metadata["class_name"]
)
backbone = backbone_cls.from_preset(preset, load_weights)
backbone_kwargs = {}
if input_shape is not None:
backbone_kwargs["input_shape"] = input_shape
backbone = backbone_cls.from_preset(
preset, load_weights, **backbone_kwargs
)
return cls(backbone, **kwargs)

# Otherwise must be one of class presets
Expand Down