Skip to content

Commit ccfc7e2

Browse files
authored
Update the comment of the num_classes parameter of deeplab v3 (keras-team#2071)
* Update deeplab_v3_plus.py Update the comment of the `num_classes`parameter which contains the background class and the classes from the data. * Update deeplab_v3_plus_test.py Update the tests following the updating of 'num_classes' parameter (now including the background class)
1 parent 0202b30 commit ccfc7e2

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

keras_cv/models/segmentation/deeplab_v3_plus/deeplab_v3_plus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class DeepLabV3Plus(Task):
4646
somewhat sensible backbone to use in many cases is the
4747
`keras_cv.models.ResNet50V2Backbone.from_preset("resnet50_v2_imagenet")`.
4848
num_classes: int, the number of classes for the detection model. Note
49-
that the `num_classes` doesn't contain the background class, and the
49+
that the `num_classes` contains the background class, and the
5050
classes from the data should be represented by integers with range
5151
[0, `num_classes`).
5252
projection_filters: int, number of filters in the convolution layer

keras_cv/models/segmentation/deeplab_v3_plus/deeplab_v3_plus_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
class DeepLabV3PlusTest(TestCase):
3333
def test_deeplab_v3_plus_construction(self):
3434
backbone = ResNet18V2Backbone(input_shape=[512, 512, 3])
35-
model = DeepLabV3Plus(backbone=backbone, num_classes=1)
35+
model = DeepLabV3Plus(backbone=backbone, num_classes=2)
3636
model.compile(
3737
optimizer="adam",
3838
loss=keras.losses.BinaryCrossentropy(),
@@ -42,7 +42,7 @@ def test_deeplab_v3_plus_construction(self):
4242
@pytest.mark.large
4343
def test_deeplab_v3_plus_call(self):
4444
backbone = ResNet18V2Backbone(input_shape=[512, 512, 3])
45-
model = DeepLabV3Plus(backbone=backbone, num_classes=1)
45+
model = DeepLabV3Plus(backbone=backbone, num_classes=2)
4646
images = np.random.uniform(size=(2, 512, 512, 3))
4747
_ = model(images)
4848
_ = model.predict(images)
@@ -83,7 +83,7 @@ def test_saved_model(self, save_format, filename):
8383
target_size = [512, 512, 3]
8484

8585
backbone = ResNet18V2Backbone(input_shape=target_size)
86-
model = DeepLabV3Plus(backbone=backbone, num_classes=1)
86+
model = DeepLabV3Plus(backbone=backbone, num_classes=2)
8787

8888
input_batch = np.ones(shape=[2] + target_size)
8989
model_output = model(input_batch)

0 commit comments

Comments
 (0)