Skip to content

Commit a4188f1

Browse files
add comments for clarification
1 parent a285aa7 commit a4188f1

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

keras/src/backend/openvino/core.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,8 @@ def slice_update(inputs, start_indices, updates):
817817
"`slice_update` is not supported by openvino backend"
818818
" for `start_indices` of type {}".format(type(start_indices))
819819
)
820+
821+
# Convert each start index to int32 scalar tensor and collect them
820822
processed_start_indices = []
821823
for idx in start_indices:
822824
val = get_ov_output(idx)
@@ -828,6 +830,8 @@ def slice_update(inputs, start_indices, updates):
828830
)
829831
if val_type != Type.i32:
830832
val = ov_opset.convert(val, Type.i32).output(0)
833+
834+
# Unsqueeze scalar values to 1D for concat later
831835
if len(val.get_partial_shape()) == 0:
832836
val = ov_opset.unsqueeze(
833837
val, ov_opset.constant(0, Type.i32)
@@ -836,6 +840,8 @@ def slice_update(inputs, start_indices, updates):
836840
start_indices_tensor = ov_opset.concat(processed_start_indices, axis=0)
837841

838842
rank = len(updates.shape)
843+
844+
# Generate coordinate ranges for each dimension in the updates
839845
ranges = []
840846
for dim in updates.shape:
841847
r = ov_opset.range(
@@ -846,43 +852,58 @@ def slice_update(inputs, start_indices, updates):
846852
)
847853
ranges.append(r)
848854

855+
# Broadcast ranges to match shape of updates
849856
broadcasted_ranges = []
850857
for i, r in enumerate(ranges):
851858
shape = [1] * rank
859+
860+
# Expand range in the i-th dimension
852861
shape[i] = updates.shape[i]
853862
r_reshaped = ov_opset.reshape(
854863
r, ov_opset.constant(shape, Type.i32), special_zero=False
855864
).output(0)
865+
866+
# Broadcast range to the full shape of updates
856867
target_shape = ov_opset.constant(list(updates.shape), Type.i32)
857868
r_broadcasted = ov_opset.broadcast(r_reshaped, target_shape).output(0)
858869
broadcasted_ranges.append(r_broadcasted)
859870

871+
# Stack all broadcasted coordinate grids into shape (rank, ...)
860872
indices_stack = ov_opset.concat(broadcasted_ranges, axis=0).output(0)
861873

874+
# Flatten to shape (rank, num_updates)
862875
num_updates = 1
863876
for dim in updates.shape:
864877
num_updates *= dim
865878
new_shape = ov_opset.constant([rank, num_updates], Type.i32)
866879
indices_reshaped = ov_opset.reshape(
867880
indices_stack, new_shape, special_zero=False
868881
).output(0)
882+
883+
# Transpose to shape (num_updates, rank)
869884
absolute_indices = ov_opset.transpose(
870885
indices_reshaped, ov_opset.constant([1, 0], Type.i32)
871886
).output(0)
872887

888+
# Broadcast start_indices to (num_updates, rank)
873889
start_indices_expanded = ov_opset.broadcast(
874890
start_indices_tensor, ov_opset.constant([num_updates, rank], Type.i32)
875891
).output(0)
892+
893+
# Compute absolute indices = offset + relative indices
876894
absolute_indices = ov_opset.add(
877895
absolute_indices, start_indices_expanded
878896
).output(0)
879897

898+
# Flatten the updates tensor to (num_updates,)
880899
updates_tensor = get_ov_output(updates)
881900
updates_flat = ov_opset.reshape(
882901
updates_tensor,
883902
ov_opset.constant([num_updates], Type.i32),
884903
special_zero=False,
885904
).output(0)
905+
906+
# Apply the update via scatter_nd_update
886907
updated = ov_opset.scatter_nd_update(
887908
inputs, absolute_indices, updates_flat
888909
).output(0)

0 commit comments

Comments
 (0)