-
Notifications
You must be signed in to change notification settings - Fork 19.6k
[OpenVINO Backend] support ops.slice_update #21362
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
aef0743
0f78f8e
a285aa7
62d02fb
b7c27dd
bc166fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -810,10 +810,170 @@ def prepare_slice_index(val): | |
|
||
|
||
def slice_update(inputs, start_indices, updates): | ||
raise NotImplementedError( | ||
"`slice_update` is not supported with openvino backend" | ||
inputs = get_ov_output(inputs) | ||
updates_tensor = get_ov_output(updates) | ||
|
||
if isinstance(start_indices, (list, np.ndarray)): | ||
start_indices = tuple(start_indices) | ||
assert isinstance(start_indices, tuple), ( | ||
"`slice_update` is not supported by openvino backend" | ||
" for `start_indices` of type {}".format(type(start_indices)) | ||
) | ||
|
||
zero_scalar = ov_opset.constant(0, Type.i32) | ||
one_scalar = ov_opset.constant(1, Type.i32) | ||
zero_tensor = ov_opset.constant([0], Type.i32) | ||
one_tensor = ov_opset.constant([1], Type.i32) | ||
|
||
def process_index(idx): | ||
val = get_ov_output(idx) | ||
if not val.get_element_type().is_integral(): | ||
raise ValueError("`slice_update` requires integral start_indices") | ||
if val.get_element_type() != Type.i32: | ||
val = ov_opset.convert(val, Type.i32).output(0) | ||
if len(val.get_partial_shape()) == 0: | ||
val = ov_opset.unsqueeze(val, zero_scalar).output(0) | ||
return val | ||
|
||
def create_meshgrid_indices(): | ||
all_indices = [] | ||
|
||
for dim_idx in range(rank): | ||
dim_size = ov_opset.gather( | ||
updates_shape, | ||
ov_opset.constant([dim_idx], Type.i32), | ||
zero_scalar, | ||
).output(0) | ||
dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0) | ||
|
||
# Create range for this dimension | ||
dim_range = ov_opset.range( | ||
zero_scalar, dim_size_scalar, one_scalar, output_type=Type.i32 | ||
).output(0) | ||
|
||
# Calculate factors for meshgrid | ||
if dim_idx < rank - 1: | ||
# Product of dimensions after current | ||
remaining_shape = ov_opset.slice( | ||
updates_shape, | ||
ov_opset.constant([dim_idx + 1], Type.i32), | ||
ov_opset.constant([rank], Type.i32), | ||
one_tensor, | ||
zero_tensor, | ||
).output(0) | ||
repeat_factor = ov_opset.reduce_prod( | ||
remaining_shape, zero_tensor, keep_dims=False | ||
).output(0) | ||
else: | ||
repeat_factor = one_scalar | ||
|
||
if dim_idx > 0: | ||
# Product of dimensions before current | ||
preceding_shape = ov_opset.slice( | ||
updates_shape, | ||
zero_tensor, | ||
ov_opset.constant([dim_idx], Type.i32), | ||
one_tensor, | ||
zero_tensor, | ||
).output(0) | ||
tile_factor = ov_opset.reduce_prod( | ||
preceding_shape, zero_tensor, keep_dims=False | ||
).output(0) | ||
else: | ||
tile_factor = one_scalar | ||
|
||
# Apply meshgrid transformations | ||
# Step 1: Repeat elements | ||
dim_expanded = ov_opset.unsqueeze(dim_range, one_scalar).output(0) | ||
repeat_shape = ov_opset.concat( | ||
[ | ||
one_tensor, | ||
ov_opset.unsqueeze(repeat_factor, zero_scalar).output(0), | ||
], | ||
axis=0, | ||
).output(0) | ||
dim_repeated = ov_opset.tile(dim_expanded, repeat_shape).output(0) | ||
|
||
# Step 2: Flatten | ||
flat_size = ov_opset.multiply( | ||
dim_size_scalar, repeat_factor | ||
).output(0) | ||
dim_flat = ov_opset.reshape( | ||
dim_repeated, | ||
ov_opset.unsqueeze(flat_size, zero_scalar).output(0), | ||
special_zero=False, | ||
).output(0) | ||
|
||
# Step 3: Tile entire sequence | ||
final_indices = ov_opset.tile( | ||
dim_flat, ov_opset.unsqueeze(tile_factor, zero_scalar).output(0) | ||
).output(0) | ||
|
||
all_indices.append(final_indices) | ||
|
||
return all_indices | ||
Comment on lines
+838
to
+914
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
||
processed_start_indices = [process_index(idx) for idx in start_indices] | ||
|
||
updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0) | ||
rank = updates_tensor.get_partial_shape().rank.get_length() | ||
num_updates = ov_opset.reduce_prod( | ||
updates_shape, zero_tensor, keep_dims=False | ||
).output(0) | ||
Comment on lines
+918
to
+922
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The implementation assumes that the rank of the |
||
|
||
all_indices = create_meshgrid_indices() | ||
|
||
# Stack and reshape indices | ||
indices_stack = ov_opset.concat(all_indices, axis=0).output(0) | ||
indices_matrix = ov_opset.reshape( | ||
indices_stack, | ||
ov_opset.concat( | ||
[ | ||
ov_opset.constant([rank], Type.i32), | ||
ov_opset.unsqueeze(num_updates, zero_scalar).output(0), | ||
], | ||
axis=0, | ||
).output(0), | ||
special_zero=False, | ||
).output(0) | ||
|
||
# Transpose to [num_updates, rank] | ||
relative_indices = ov_opset.transpose( | ||
indices_matrix, ov_opset.constant([1, 0], Type.i32) | ||
).output(0) | ||
|
||
# Add start indices offset | ||
start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(0) | ||
start_broadcast = ov_opset.tile( | ||
ov_opset.reshape( | ||
start_tensor, | ||
ov_opset.constant([1, rank], Type.i32), | ||
special_zero=False, | ||
).output(0), | ||
ov_opset.concat( | ||
[ | ||
ov_opset.unsqueeze(num_updates, zero_scalar).output(0), | ||
one_tensor, | ||
], | ||
axis=0, | ||
).output(0), | ||
).output(0) | ||
|
||
absolute_indices = ov_opset.add(relative_indices, start_broadcast).output(0) | ||
|
||
# Flatten updates to a 1D tensor for scatter_nd_update compatibility. | ||
updates_flat = ov_opset.reshape( | ||
updates_tensor, | ||
ov_opset.unsqueeze(num_updates, zero_scalar).output(0), | ||
special_zero=False, | ||
).output(0) | ||
|
||
# Scatter updates into input tensor at target indices. | ||
result = ov_opset.scatter_nd_update( | ||
inputs, absolute_indices, updates_flat | ||
).output(0) | ||
return OpenVINOKerasTensor(result) | ||
|
||
|
||
def while_loop( | ||
cond, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add comments for each block