[OpenVINO Backend] support ops.slice_update #21362

Open. Wants to merge 6 commits into base: master.
164 changes: 162 additions & 2 deletions keras/src/backend/openvino/core.py
@@ -810,10 +810,170 @@ def prepare_slice_index(val):


def slice_update(inputs, start_indices, updates):
Review comment (Contributor): please add comments for each block.

(The previous stub, which raised NotImplementedError stating that `slice_update` is not supported with the openvino backend, is removed and replaced by the implementation below.)

inputs = get_ov_output(inputs)
updates_tensor = get_ov_output(updates)

if isinstance(start_indices, (list, np.ndarray)):
start_indices = tuple(start_indices)
assert isinstance(start_indices, tuple), (
"`slice_update` is not supported by openvino backend"
" for `start_indices` of type {}".format(type(start_indices))
)

zero_scalar = ov_opset.constant(0, Type.i32)
one_scalar = ov_opset.constant(1, Type.i32)
zero_tensor = ov_opset.constant([0], Type.i32)
one_tensor = ov_opset.constant([1], Type.i32)

def process_index(idx):
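# Normalize one start index: require an integral type, cast to i32,
# and unsqueeze scalars to 1-D so they can be concatenated later.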
val = get_ov_output(idx)
if not val.get_element_type().is_integral():
raise ValueError("`slice_update` requires integral start_indices")
if val.get_element_type() != Type.i32:
val = ov_opset.convert(val, Type.i32).output(0)
if len(val.get_partial_shape()) == 0:
val = ov_opset.unsqueeze(val, zero_scalar).output(0)
return val

def create_meshgrid_indices():
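# For each dimension of `updates`, build the flattened coordinate vector
# (the rows of np.indices(updates_shape).reshape(rank, -1)) via repeat/tile.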
all_indices = []

for dim_idx in range(rank):
dim_size = ov_opset.gather(
updates_shape,
ov_opset.constant([dim_idx], Type.i32),
zero_scalar,
).output(0)
dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0)

# Create range for this dimension
dim_range = ov_opset.range(
zero_scalar, dim_size_scalar, one_scalar, output_type=Type.i32
).output(0)

# Calculate factors for meshgrid
if dim_idx < rank - 1:
# Product of dimensions after current
remaining_shape = ov_opset.slice(
updates_shape,
ov_opset.constant([dim_idx + 1], Type.i32),
ov_opset.constant([rank], Type.i32),
one_tensor,
zero_tensor,
).output(0)
repeat_factor = ov_opset.reduce_prod(
remaining_shape, zero_tensor, keep_dims=False
).output(0)
else:
repeat_factor = one_scalar

if dim_idx > 0:
# Product of dimensions before current
preceding_shape = ov_opset.slice(
updates_shape,
zero_tensor,
ov_opset.constant([dim_idx], Type.i32),
one_tensor,
zero_tensor,
).output(0)
tile_factor = ov_opset.reduce_prod(
preceding_shape, zero_tensor, keep_dims=False
).output(0)
else:
tile_factor = one_scalar

# Apply meshgrid transformations
# Step 1: Repeat elements
dim_expanded = ov_opset.unsqueeze(dim_range, one_scalar).output(0)
repeat_shape = ov_opset.concat(
[
one_tensor,
ov_opset.unsqueeze(repeat_factor, zero_scalar).output(0),
],
axis=0,
).output(0)
dim_repeated = ov_opset.tile(dim_expanded, repeat_shape).output(0)

# Step 2: Flatten
flat_size = ov_opset.multiply(
dim_size_scalar, repeat_factor
).output(0)
dim_flat = ov_opset.reshape(
dim_repeated,
ov_opset.unsqueeze(flat_size, zero_scalar).output(0),
special_zero=False,
).output(0)

# Step 3: Tile entire sequence
final_indices = ov_opset.tile(
dim_flat, ov_opset.unsqueeze(tile_factor, zero_scalar).output(0)
).output(0)

all_indices.append(final_indices)

return all_indices
Review comment on lines +838 to +914 (severity: medium): The create_meshgrid_indices function is complex, which makes it difficult to read and maintain. Consider extracting parts of the logic into smaller helpers: the repeat_factor and tile_factor calculations, and the meshgrid transformation steps (repeat, flatten, tile), could each live in their own helper function, leaving the main loop more declarative and easier to follow.
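For reference while reviewing, a small NumPy sketch (not part of this PR; the helper name meshgrid_indices_reference is hypothetical) of what the repeat/tile logic computes for each dimension of `updates`:

import numpy as np

def meshgrid_indices_reference(shape):
    # For dimension d, each coordinate value repeats prod(shape[d+1:]) times
    # and the whole sequence is tiled prod(shape[:d]) times.
    indices = []
    for d, dim_size in enumerate(shape):
        repeat = int(np.prod(shape[d + 1:], dtype=np.int64))
        tile = int(np.prod(shape[:d], dtype=np.int64))
        flat = np.tile(np.repeat(np.arange(dim_size, dtype=np.int32), repeat), tile)
        indices.append(flat)
    return indices

# Stacking the per-dimension vectors matches np.indices in C order.
shape = (2, 3, 4)
assert np.array_equal(
    np.stack(meshgrid_indices_reference(shape)),
    np.indices(shape).reshape(len(shape), -1),
)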


processed_start_indices = [process_index(idx) for idx in start_indices]

updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0)
rank = updates_tensor.get_partial_shape().rank.get_length()
num_updates = ov_opset.reduce_prod(
updates_shape, zero_tensor, keep_dims=False
).output(0)
Review comment on lines +918 to +922 (severity: high): The implementation assumes the rank of the updates tensor is statically known; the .get_length() call on this line will raise an exception if the rank is dynamic. Keras layers and backends should ideally support dynamic shapes, including dynamic ranks, and the logic in create_meshgrid_indices and process_index also relies on a static rank. Please consider how to support dynamic ranks, which might require OpenVINO control-flow operations (such as ov.opset.loop) instead of a Python for loop over the rank. If supporting dynamic rank is out of scope for this PR, document the limitation in a docstring for the function.
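If dynamic-rank support stays out of scope, one way to record the limitation (a sketch only, not code from this PR) is a short docstring on the function:

def slice_update(inputs, start_indices, updates):
    """Return a copy of `inputs` with `updates` written at `start_indices`.

    Note: this OpenVINO implementation requires the rank of `updates` to be
    statically known; tensors with a dynamic rank are not supported.
    """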


all_indices = create_meshgrid_indices()

# Stack and reshape indices
indices_stack = ov_opset.concat(all_indices, axis=0).output(0)
indices_matrix = ov_opset.reshape(
indices_stack,
ov_opset.concat(
[
ov_opset.constant([rank], Type.i32),
ov_opset.unsqueeze(num_updates, zero_scalar).output(0),
],
axis=0,
).output(0),
special_zero=False,
).output(0)

# Transpose to [num_updates, rank]
relative_indices = ov_opset.transpose(
indices_matrix, ov_opset.constant([1, 0], Type.i32)
).output(0)

# Add start indices offset
start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(0)
start_broadcast = ov_opset.tile(
ov_opset.reshape(
start_tensor,
ov_opset.constant([1, rank], Type.i32),
special_zero=False,
).output(0),
ov_opset.concat(
[
ov_opset.unsqueeze(num_updates, zero_scalar).output(0),
one_tensor,
],
axis=0,
).output(0),
).output(0)

absolute_indices = ov_opset.add(relative_indices, start_broadcast).output(0)

# Flatten updates to a 1D tensor for scatter_nd_update compatibility.
updates_flat = ov_opset.reshape(
updates_tensor,
ov_opset.unsqueeze(num_updates, zero_scalar).output(0),
special_zero=False,
).output(0)

# Scatter updates into input tensor at target indices.
result = ov_opset.scatter_nd_update(
inputs, absolute_indices, updates_flat
).output(0)
return OpenVINOKerasTensor(result)
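For context while reviewing, the graph assembled above matches this NumPy sketch end to end (illustrative only; slice_update_reference is a hypothetical name, not code from this PR):

import numpy as np

def slice_update_reference(inputs, start_indices, updates):
    # Copy the input, then overwrite the block whose top-left corner is
    # `start_indices` and whose extent is `updates.shape`.
    out = np.array(inputs, copy=True)
    region = tuple(
        slice(start, start + size)
        for start, size in zip(start_indices, updates.shape)
    )
    out[region] = updates
    return out

x = np.zeros((4, 5), dtype=np.float32)
u = np.ones((2, 3), dtype=np.float32)
print(slice_update_reference(x, (1, 2), u))  # ones in rows 1-2, cols 2-4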


def while_loop(
cond,
3 changes: 0 additions & 3 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
@@ -164,17 +164,14 @@ CoreOpsCallsTests::test_map_basic_call
CoreOpsCallsTests::test_scan_basic_call
CoreOpsCallsTests::test_scatter_basic_call
CoreOpsCallsTests::test_scatter_update_basic_call
CoreOpsCallsTests::test_slice_update_basic_call
CoreOpsCallsTests::test_switch_basic_call
CoreOpsCallsTests::test_unstack_basic_functionality
CoreOpsCorrectnessTest::test_associative_scan
CoreOpsCorrectnessTest::test_cond
CoreOpsCorrectnessTest::test_dynamic_slice
CoreOpsCorrectnessTest::test_fori_loop
CoreOpsCorrectnessTest::test_map
CoreOpsCorrectnessTest::test_scan
CoreOpsCorrectnessTest::test_scatter
CoreOpsCorrectnessTest::test_slice_update
CoreOpsCorrectnessTest::test_switch
CoreOpsCorrectnessTest::test_unstack
CoreOpsCorrectnessTest::test_vectorized_map
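
The two exclusions dropped above (CoreOpsCallsTests::test_slice_update_basic_call and CoreOpsCorrectnessTest::test_slice_update) re-enable the slice_update tests on the openvino backend. For orientation, they exercise calls of roughly this shape (illustrative, not the actual test bodies):

from keras import ops

inputs = ops.zeros((4, 4))
updates = ops.ones((2, 2))
# Writes a 2x2 block of ones into `inputs` starting at row 1, column 1.
out = ops.slice_update(inputs, (1, 1), updates)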