@@ -817,6 +817,8 @@ def slice_update(inputs, start_indices, updates):
817
817
"`slice_update` is not supported by openvino backend"
818
818
" for `start_indices` of type {}" .format (type (start_indices ))
819
819
)
820
+
821
+ # Convert each start index to int32 scalar tensor and collect them
820
822
processed_start_indices = []
821
823
for idx in start_indices :
822
824
val = get_ov_output (idx )
@@ -828,6 +830,8 @@ def slice_update(inputs, start_indices, updates):
828
830
)
829
831
if val_type != Type .i32 :
830
832
val = ov_opset .convert (val , Type .i32 ).output (0 )
833
+
834
+ # Unsqueeze scalar values to 1D for concat later
831
835
if len (val .get_partial_shape ()) == 0 :
832
836
val = ov_opset .unsqueeze (
833
837
val , ov_opset .constant (0 , Type .i32 )
@@ -836,6 +840,8 @@ def slice_update(inputs, start_indices, updates):
836
840
start_indices_tensor = ov_opset .concat (processed_start_indices , axis = 0 )
837
841
838
842
rank = len (updates .shape )
843
+
844
+ # Generate coordinate ranges for each dimension in the updates
839
845
ranges = []
840
846
for dim in updates .shape :
841
847
r = ov_opset .range (
@@ -846,43 +852,58 @@ def slice_update(inputs, start_indices, updates):
846
852
)
847
853
ranges .append (r )
848
854
855
+ # Broadcast ranges to match shape of updates
849
856
broadcasted_ranges = []
850
857
for i , r in enumerate (ranges ):
851
858
shape = [1 ] * rank
859
+
860
+ # Expand range in the i-th dimension
852
861
shape [i ] = updates .shape [i ]
853
862
r_reshaped = ov_opset .reshape (
854
863
r , ov_opset .constant (shape , Type .i32 ), special_zero = False
855
864
).output (0 )
865
+
866
+ # Broadcast range to the full shape of updates
856
867
target_shape = ov_opset .constant (list (updates .shape ), Type .i32 )
857
868
r_broadcasted = ov_opset .broadcast (r_reshaped , target_shape ).output (0 )
858
869
broadcasted_ranges .append (r_broadcasted )
859
870
871
+ # Stack all broadcasted coordinate grids into shape (rank, ...)
860
872
indices_stack = ov_opset .concat (broadcasted_ranges , axis = 0 ).output (0 )
861
873
874
+ # Flatten to shape (rank, num_updates)
862
875
num_updates = 1
863
876
for dim in updates .shape :
864
877
num_updates *= dim
865
878
new_shape = ov_opset .constant ([rank , num_updates ], Type .i32 )
866
879
indices_reshaped = ov_opset .reshape (
867
880
indices_stack , new_shape , special_zero = False
868
881
).output (0 )
882
+
883
+ # Transpose to shape (num_updates, rank)
869
884
absolute_indices = ov_opset .transpose (
870
885
indices_reshaped , ov_opset .constant ([1 , 0 ], Type .i32 )
871
886
).output (0 )
872
887
888
+ # Broadcast start_indices to (num_updates, rank)
873
889
start_indices_expanded = ov_opset .broadcast (
874
890
start_indices_tensor , ov_opset .constant ([num_updates , rank ], Type .i32 )
875
891
).output (0 )
892
+
893
+ # Compute absolute indices = offset + relative indices
876
894
absolute_indices = ov_opset .add (
877
895
absolute_indices , start_indices_expanded
878
896
).output (0 )
879
897
898
+ # Flatten the updates tensor to (num_updates,)
880
899
updates_tensor = get_ov_output (updates )
881
900
updates_flat = ov_opset .reshape (
882
901
updates_tensor ,
883
902
ov_opset .constant ([num_updates ], Type .i32 ),
884
903
special_zero = False ,
885
904
).output (0 )
905
+
906
+ # Apply the update via scatter_nd_update
886
907
updated = ov_opset .scatter_nd_update (
887
908
inputs , absolute_indices , updates_flat
888
909
).output (0 )
0 commit comments