Skip to content

Commit 7b4a78c

Browse files
Solving issue #20221 (#20237)
* added validation checks and raised error if an invalid input shape is passed to compute_output_shape func in UnitNormalization Layer * updated my change to check if the input is int or an iterable before iterating * Update unit_normalization.py --------- Co-authored-by: François Chollet <[email protected]>
1 parent d60dd6c commit 7b4a78c

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

keras/src/layers/normalization/unit_normalization.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,18 @@ def call(self, inputs):
4343
return ops.normalize(inputs, axis=self.axis, order=2, epsilon=1e-12)
4444

4545
def compute_output_shape(self, input_shape):
46+
# Ensure axis is always treated as a list
47+
if isinstance(self.axis, int):
48+
axes = [self.axis]
49+
else:
50+
axes = self.axis
51+
52+
for axis in axes:
53+
if axis >= len(input_shape) or axis < -len(input_shape):
54+
raise ValueError(
55+
f"Axis {self.axis} is out of bounds for "
56+
f"input shape {input_shape}."
57+
)
4658
return input_shape
4759

4860
def get_config(self):

0 commit comments

Comments
 (0)