[OpenVINO Backend] support ops.slice_update #21362
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #21362 +/- ##
==========================================
+ Coverage 82.73% 82.75% +0.01%
==========================================
Files 565 565
Lines 55215 55354 +139
Branches 8606 8631 +25
==========================================
+ Hits 45682 45807 +125
- Misses 7424 7440 +16
+ Partials 2109 2107 -2
let us wait until
@rkazants
@@ -810,9 +810,83 @@ def prepare_slice_index(val):


def slice_update(inputs, start_indices, updates):
please add comments for each block
@rkazants
/gemini review
Code Review
This pull request implements the `slice_update` operation for the OpenVINO backend. The implementation generates indices for all elements in the `updates` tensor and uses `scatter_nd_update` to perform the update. The changes also enable the corresponding tests for this operation. The implementation assumes a static rank for the input tensors, and the index generation function could be refactored for better readability.
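The approach summarized above (build one coordinate per element of `updates`, offset it by the start position, then scatter) can be illustrated outside OpenVINO. The following is a minimal NumPy sketch, not the PR's implementation: the name `slice_update_reference` is hypothetical, and it assumes `updates` has rank 1 or higher and fits inside `inputs` at `start_indices`.

```python
import numpy as np


def slice_update_reference(inputs, start_indices, updates):
    """Hypothetical NumPy reference: write `updates` into `inputs` at an offset."""
    inputs = np.asarray(inputs)
    updates = np.asarray(updates)

    # One coordinate row per element of `updates`, in row-major order.
    local = np.stack(
        np.meshgrid(*[np.arange(n) for n in updates.shape], indexing="ij"),
        axis=-1,
    ).reshape(-1, updates.ndim)

    # Shift the local coordinates by the start offset to address `inputs`.
    absolute = local + np.asarray(start_indices)

    # Scatter step: assign each flattened update value to its coordinate.
    result = inputs.copy()
    result[tuple(absolute.T)] = updates.reshape(-1)
    return result


base = np.zeros((4, 5), dtype=np.int32)
patch = np.array([[1, 2], [3, 4]], dtype=np.int32)
updated = slice_update_reference(base, (1, 2), patch)
```

Here `updated[1:3, 2:4]` holds `patch` and every other element of `base` is unchanged, which is the behavior the scatter-based implementation is meant to reproduce.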
    updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0)
    rank = updates_tensor.get_partial_shape().rank.get_length()
    num_updates = ov_opset.reduce_prod(
        updates_shape, zero_tensor, keep_dims=False
    ).output(0)
The implementation assumes that the rank of the `updates` tensor is statically known. The call to `.get_length()` on this line will raise an exception if the rank is dynamic. Keras layers and backends should ideally support dynamic shapes, including dynamic ranks. The logic within `create_meshgrid_indices` and `process_index` also relies on a static rank. Please consider how to support dynamic ranks, which might involve using OpenVINO's control flow operations (like `ov.opset.loop`) instead of a Python `for` loop over the rank. If supporting dynamic rank is out of scope for this PR, this limitation should be documented in a docstring for the function.
    def create_meshgrid_indices():
        all_indices = []

        for dim_idx in range(rank):
            dim_size = ov_opset.gather(
                updates_shape,
                ov_opset.constant([dim_idx], Type.i32),
                zero_scalar,
            ).output(0)
            dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0)

            # Create range for this dimension
            dim_range = ov_opset.range(
                zero_scalar, dim_size_scalar, one_scalar, output_type=Type.i32
            ).output(0)

            # Calculate factors for meshgrid
            if dim_idx < rank - 1:
                # Product of dimensions after current
                remaining_shape = ov_opset.slice(
                    updates_shape,
                    ov_opset.constant([dim_idx + 1], Type.i32),
                    ov_opset.constant([rank], Type.i32),
                    one_tensor,
                    zero_tensor,
                ).output(0)
                repeat_factor = ov_opset.reduce_prod(
                    remaining_shape, zero_tensor, keep_dims=False
                ).output(0)
            else:
                repeat_factor = one_scalar

            if dim_idx > 0:
                # Product of dimensions before current
                preceding_shape = ov_opset.slice(
                    updates_shape,
                    zero_tensor,
                    ov_opset.constant([dim_idx], Type.i32),
                    one_tensor,
                    zero_tensor,
                ).output(0)
                tile_factor = ov_opset.reduce_prod(
                    preceding_shape, zero_tensor, keep_dims=False
                ).output(0)
            else:
                tile_factor = one_scalar

            # Apply meshgrid transformations
            # Step 1: Repeat elements
            dim_expanded = ov_opset.unsqueeze(dim_range, one_scalar).output(0)
            repeat_shape = ov_opset.concat(
                [
                    one_tensor,
                    ov_opset.unsqueeze(repeat_factor, zero_scalar).output(0),
                ],
                axis=0,
            ).output(0)
            dim_repeated = ov_opset.tile(dim_expanded, repeat_shape).output(0)

            # Step 2: Flatten
            flat_size = ov_opset.multiply(
                dim_size_scalar, repeat_factor
            ).output(0)
            dim_flat = ov_opset.reshape(
                dim_repeated,
                ov_opset.unsqueeze(flat_size, zero_scalar).output(0),
                special_zero=False,
            ).output(0)

            # Step 3: Tile entire sequence
            final_indices = ov_opset.tile(
                dim_flat, ov_opset.unsqueeze(tile_factor, zero_scalar).output(0)
            ).output(0)

            all_indices.append(final_indices)

        return all_indices
The `create_meshgrid_indices` function is complex, which can make it difficult to read and maintain. Consider refactoring this function by extracting parts of the logic into smaller helper functions. For example, the logic for calculating `repeat_factor` and `tile_factor`, and the meshgrid transformation steps (repeating, flattening, tiling) could each be in their own helper functions. This would make the main loop in `create_meshgrid_indices` more declarative and easier to understand.
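To show the shape such a refactoring could take, here is a plain-Python sketch of the same repeat-then-tile scheme split into small helpers. It works on Python lists rather than OpenVINO nodes, and the names `repeat_each`, `tile_sequence`, and `meshgrid_indices` are hypothetical; it is meant only to illustrate the decomposition, not to replace the graph-building code.

```python
from itertools import product
from math import prod


def repeat_each(values, times):
    """Repeat every element consecutively: [0, 1] with times=2 -> [0, 0, 1, 1]."""
    return [v for v in values for _ in range(times)]


def tile_sequence(values, times):
    """Repeat the whole sequence: [0, 1] with times=2 -> [0, 1, 0, 1]."""
    return list(values) * times


def meshgrid_indices(shape):
    """Return one flat index list per dimension, in row-major order."""
    all_indices = []
    for dim_idx, dim_size in enumerate(shape):
        # Product of the dimensions after the current one.
        repeat_factor = prod(shape[dim_idx + 1:])
        # Product of the dimensions before the current one.
        tile_factor = prod(shape[:dim_idx])
        repeated = repeat_each(range(dim_size), repeat_factor)
        all_indices.append(tile_sequence(repeated, tile_factor))
    return all_indices


indices = meshgrid_indices((2, 3))
coords = list(zip(*indices))
```

For shape `(2, 3)`, zipping the per-dimension lists yields the same coordinates as `itertools.product(range(2), range(3))`, which is a convenient check for any refactored version of the loop.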
Hi @rkazants,
I've added support for `ops.slice_update` as part of my GSoC project, but I can't enable the tests for it until `__getitem__` is merged. It could be implemented more easily if we used `opset15` instead of `opset14`.