Skip to content

Commit d04fbcc

Browse files
authored
Fix CI Test for Basnet OOM and PyCoCo Test Failure for JAX (#2322)
1 parent 813d43d commit d04fbcc

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

.kokoro/github/ubuntu/gpu/build.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pip install --no-deps -e "." --progress-bar off
5151
# Run Extra Large Tests for Continuous builds
5252
if [ "${RUN_XLARGE:-0}" == "1" ]
5353
then
54-
pytest --check_gpu --run_large --run_extra_large --durations 0 \
54+
pytest --cache-clear --check_gpu --run_large --run_extra_large --durations 0 \
5555
keras_cv/bounding_box \
5656
keras_cv/callbacks \
5757
keras_cv/losses \
@@ -65,7 +65,7 @@ then
6565
keras_cv/models/segmentation \
6666
keras_cv/models/stable_diffusion
6767
else
68-
pytest --check_gpu --run_large --durations 0 \
68+
pytest --cache-clear --check_gpu --run_large --durations 0 \
6969
keras_cv/bounding_box \
7070
keras_cv/callbacks \
7171
keras_cv/losses \

keras_cv/metrics/coco/pycoco_wrapper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ def _convert_predictions_to_coco_annotations(predictions):
125125
num_batches = len(predictions["source_id"])
126126
for i in range(num_batches):
127127
batch_size = predictions["source_id"][i].shape[0]
128+
predictions["detection_boxes"][i] = predictions["detection_boxes"][
129+
i
130+
].copy()
128131
for j in range(batch_size):
129132
max_num_detections = predictions["num_detections"][i][j]
130133
predictions["detection_boxes"][i][j] = _yxyx_to_xywh(

keras_cv/models/segmentation/basnet/basnet_test.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import gc
1516
import os
1617

1718
import numpy as np
@@ -23,13 +24,13 @@
2324
from keras_cv.backend import ops
2425
from keras_cv.backend.config import keras_3
2526
from keras_cv.models import BASNet
26-
from keras_cv.models import ResNet34Backbone
27+
from keras_cv.models import ResNet18Backbone
2728
from keras_cv.tests.test_case import TestCase
2829

2930

3031
class BASNetTest(TestCase):
3132
def test_basnet_construction(self):
32-
backbone = ResNet34Backbone()
33+
backbone = ResNet18Backbone()
3334
model = BASNet(
3435
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
3536
)
@@ -41,7 +42,7 @@ def test_basnet_construction(self):
4142

4243
@pytest.mark.large
4344
def test_basnet_call(self):
44-
backbone = ResNet34Backbone()
45+
backbone = ResNet18Backbone()
4546
model = BASNet(
4647
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
4748
)
@@ -61,7 +62,7 @@ def test_weights_change(self):
6162
ds = ds.repeat(2)
6263
ds = ds.batch(2)
6364

64-
backbone = ResNet34Backbone()
65+
backbone = ResNet18Backbone()
6566
model = BASNet(
6667
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
6768
)
@@ -99,7 +100,7 @@ def test_with_model_preset_forward_pass(self):
99100
def test_saved_model(self):
100101
target_size = [288, 288, 3]
101102

102-
backbone = ResNet34Backbone()
103+
backbone = ResNet18Backbone()
103104
model = BASNet(
104105
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
105106
)
@@ -112,6 +113,9 @@ def test_saved_model(self):
112113
model.save(save_path)
113114
else:
114115
model.save(save_path, save_format="keras_v3")
116+
# Free up model memory
117+
del model
118+
gc.collect()
115119
restored_model = keras.models.load_model(save_path)
116120

117121
# Check we got the real object back.

0 commit comments

Comments
 (0)