Skip to content

[OpenVINO backend] support categorical #21437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,61 @@ MathOpsCorrectnessTest::test_stft3
MathOpsCorrectnessTest::test_stft4
MathOpsCorrectnessTest::test_stft5
MathOpsCorrectnessTest::test_stft6
RandomCorrectnessTest::test_beta0
RandomCorrectnessTest::test_beta1
RandomCorrectnessTest::test_beta2
RandomCorrectnessTest::test_binomial0
RandomCorrectnessTest::test_binomial1
RandomCorrectnessTest::test_binomial2
RandomCorrectnessTest::test_dropout
RandomCorrectnessTest::test_dropout_noise_shape
RandomCorrectnessTest::test_gamma0
RandomCorrectnessTest::test_gamma1
RandomCorrectnessTest::test_gamma2
RandomCorrectnessTest::test_randint0
RandomCorrectnessTest::test_randint1
RandomCorrectnessTest::test_randint2
RandomCorrectnessTest::test_randint3
RandomCorrectnessTest::test_randint4
RandomCorrectnessTest::test_shuffle
RandomCorrectnessTest::test_truncated_normal0
RandomCorrectnessTest::test_truncated_normal1
RandomCorrectnessTest::test_truncated_normal2
RandomCorrectnessTest::test_truncated_normal3
RandomCorrectnessTest::test_truncated_normal4
RandomCorrectnessTest::test_truncated_normal5
RandomCorrectnessTest::test_uniform0
RandomCorrectnessTest::test_uniform1
RandomCorrectnessTest::test_uniform2
RandomCorrectnessTest::test_uniform3
RandomCorrectnessTest::test_uniform4
RandomBehaviorTest::test_beta_tf_data_compatibility
RandomDTypeTest::test_beta_bfloat16
RandomDTypeTest::test_beta_float16
RandomDTypeTest::test_beta_float32
RandomDTypeTest::test_beta_float64
RandomDTypeTest::test_binomial_bfloat16
RandomDTypeTest::test_binomial_float16
RandomDTypeTest::test_binomial_float32
RandomDTypeTest::test_binomial_float64
RandomDTypeTest::test_dropout_bfloat16
RandomDTypeTest::test_dropout_float16
RandomDTypeTest::test_dropout_float32
RandomDTypeTest::test_dropout_float64
RandomDTypeTest::test_gamma_bfloat16
RandomDTypeTest::test_gamma_float16
RandomDTypeTest::test_gamma_float32
RandomDTypeTest::test_gamma_float64
RandomDTypeTest::test_normal_bfloat16
RandomDTypeTest::test_randint_int16
RandomDTypeTest::test_randint_int32
RandomDTypeTest::test_randint_int64
RandomDTypeTest::test_randint_int8
RandomDTypeTest::test_randint_uint16
RandomDTypeTest::test_randint_uint32
RandomDTypeTest::test_randint_uint8
RandomDTypeTest::test_truncated_normal_bfloat16
RandomDTypeTest::test_uniform_bfloat16
SegmentSumTest::test_segment_sum_call
SegmentMaxTest::test_segment_max_call
TestMathErrors::test_invalid_fft_length
Expand Down
2 changes: 1 addition & 1 deletion keras/src/backend/openvino/excluded_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ keras/src/ops/linalg_test.py
keras/src/ops/nn_test.py
keras/src/optimizers
keras/src/quantizers
keras/src/random
keras/src/random/seed_generator_test.py
keras/src/regularizers
keras/src/saving
keras/src/trainers
Expand Down
59 changes: 56 additions & 3 deletions keras/src/backend/openvino/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from keras.src.backend.openvino.core import OPENVINO_DTYPES
from keras.src.backend.openvino.core import OpenVINOKerasTensor
from keras.src.backend.openvino.core import convert_to_numpy
from keras.src.backend.openvino.core import get_ov_output
from keras.src.random.seed_generator import SeedGenerator
from keras.src.random.seed_generator import draw_seed
from keras.src.random.seed_generator import make_default_seed
Expand Down Expand Up @@ -39,9 +40,61 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):


def categorical(logits, num_samples, dtype="int64", seed=None):
raise NotImplementedError(
"`categorical` is not supported with openvino backend"
)
dtype = dtype or "int64"
ov_dtype = OPENVINO_DTYPES[dtype]
logits = get_ov_output(logits)

zero_const = ov_opset.constant(0, Type.i32).output(0)
one_const = ov_opset.constant(1, Type.i32).output(0)
neg_one_const = ov_opset.constant(-1, Type.i32).output(0)

# Compute probabilities and cumulative sum
probs = ov_opset.softmax(logits, axis=-1).output(0)
cumsum_probs = ov_opset.cumsum(probs, neg_one_const).output(0)

# Get shape and compute batch dimensions
logits_shape = ov_opset.shape_of(logits, Type.i32).output(0)
rank = ov_opset.shape_of(logits_shape, Type.i32).output(0)
rank_scalar = ov_opset.squeeze(rank, zero_const).output(0)
rank_minus_1 = ov_opset.subtract(rank_scalar, one_const).output(0)

# Extract batch shape (all dimensions except last)
batch_indices = ov_opset.range(
zero_const, rank_minus_1, one_const, output_type=Type.i32
).output(0)
batch_shape = ov_opset.gather(logits_shape, batch_indices, axis=0).output(0)

# Create final shape [batch_dims..., num_samples]
num_samples_const = ov_opset.constant([num_samples], Type.i32).output(0)
final_shape = ov_opset.concat(
[batch_shape, num_samples_const], axis=0
).output(0)

seed_tensor = draw_seed(seed)
if isinstance(seed_tensor, OpenVINOKerasTensor):
seed1, seed2 = convert_to_numpy(seed_tensor)
else:
seed1, seed2 = seed_tensor.data

probs_dtype = probs.get_element_type()
zero_float = ov_opset.constant(0.0, probs_dtype).output(0)
one_float = ov_opset.constant(1.0, probs_dtype).output(0)

rand = ov_opset.random_uniform(
final_shape, zero_float, one_float, probs_dtype, seed1, seed2
).output(0)

rand_unsqueezed = ov_opset.unsqueeze(rand, neg_one_const).output(0)
cumsum_unsqueezed = ov_opset.unsqueeze(cumsum_probs, one_const).output(0)

# Count how many cumulative probabilities each random number exceeds
greater = ov_opset.greater(rand_unsqueezed, cumsum_unsqueezed).output(0)
samples = ov_opset.reduce_sum(
ov_opset.convert(greater, Type.i32).output(0), neg_one_const
).output(0)

result = ov_opset.convert(samples, ov_dtype).output(0)
return OpenVINOKerasTensor(result)


def randint(shape, minval, maxval, dtype="int32", seed=None):
Expand Down