Skip to content

Commit c825e0e

Browse files
[OpenVIno backend] support categorical
1 parent be9b002 commit c825e0e

File tree

3 files changed

+112
-4
lines changed

3 files changed

+112
-4
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,58 @@ CoreOpsDtypeTest::test_convert_to_tensor12
188188
CoreOpsDtypeTest::test_convert_to_tensor14
189189
CoreOpsDtypeTest::test_convert_to_tensor25
190190
CoreOpsDtypeTest::test_convert_to_tensor37
191+
RandomCorrectnessTest::test_beta0
192+
RandomCorrectnessTest::test_beta1
193+
RandomCorrectnessTest::test_beta2
194+
RandomCorrectnessTest::test_binomial0
195+
RandomCorrectnessTest::test_binomial1
196+
RandomCorrectnessTest::test_binomial2
197+
RandomCorrectnessTest::test_dropout
198+
RandomCorrectnessTest::test_dropout_noise_shape
199+
RandomCorrectnessTest::test_gamma0
200+
RandomCorrectnessTest::test_gamma1
201+
RandomCorrectnessTest::test_gamma2
202+
RandomCorrectnessTest::test_randint0
203+
RandomCorrectnessTest::test_randint1
204+
RandomCorrectnessTest::test_randint2
205+
RandomCorrectnessTest::test_randint3
206+
RandomCorrectnessTest::test_randint4
207+
RandomCorrectnessTest::test_shuffle
208+
RandomCorrectnessTest::test_truncated_normal0
209+
RandomCorrectnessTest::test_truncated_normal1
210+
RandomCorrectnessTest::test_truncated_normal2
211+
RandomCorrectnessTest::test_truncated_normal3
212+
RandomCorrectnessTest::test_truncated_normal4
213+
RandomCorrectnessTest::test_truncated_normal5
214+
RandomCorrectnessTest::test_uniform0
215+
RandomCorrectnessTest::test_uniform1
216+
RandomCorrectnessTest::test_uniform2
217+
RandomCorrectnessTest::test_uniform3
218+
RandomCorrectnessTest::test_uniform4
219+
RandomBehaviorTest::test_beta_tf_data_compatibility
220+
RandomDTypeTest::test_beta_bfloat16
221+
RandomDTypeTest::test_beta_float16
222+
RandomDTypeTest::test_beta_float32
223+
RandomDTypeTest::test_beta_float64
224+
RandomDTypeTest::test_binomial_bfloat16
225+
RandomDTypeTest::test_binomial_float16
226+
RandomDTypeTest::test_binomial_float32
227+
RandomDTypeTest::test_binomial_float64
228+
RandomDTypeTest::test_dropout_bfloat16
229+
RandomDTypeTest::test_dropout_float16
230+
RandomDTypeTest::test_dropout_float32
231+
RandomDTypeTest::test_dropout_float64
232+
RandomDTypeTest::test_gamma_bfloat16
233+
RandomDTypeTest::test_gamma_float16
234+
RandomDTypeTest::test_gamma_float32
235+
RandomDTypeTest::test_gamma_float64
236+
RandomDTypeTest::test_normal_bfloat16
237+
RandomDTypeTest::test_randint_int16
238+
RandomDTypeTest::test_randint_int32
239+
RandomDTypeTest::test_randint_int64
240+
RandomDTypeTest::test_randint_int8
241+
RandomDTypeTest::test_randint_uint16
242+
RandomDTypeTest::test_randint_uint32
243+
RandomDTypeTest::test_randint_uint8
244+
RandomDTypeTest::test_truncated_normal_bfloat16
245+
RandomDTypeTest::test_uniform_bfloat16

keras/src/backend/openvino/excluded_tests.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ keras/src/ops/math_test.py
3333
keras/src/ops/nn_test.py
3434
keras/src/optimizers
3535
keras/src/quantizers
36-
keras/src/random
36+
keras/src/random/seed_generator_test.py
3737
keras/src/regularizers
3838
keras/src/saving
3939
keras/src/trainers

keras/src/backend/openvino/random.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from keras.src.backend.openvino.core import OPENVINO_DTYPES
77
from keras.src.backend.openvino.core import OpenVINOKerasTensor
88
from keras.src.backend.openvino.core import convert_to_numpy
9+
from keras.src.backend.openvino.core import get_ov_output
910
from keras.src.random.seed_generator import SeedGenerator
1011
from keras.src.random.seed_generator import draw_seed
1112
from keras.src.random.seed_generator import make_default_seed
@@ -39,9 +40,61 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
3940

4041

4142
def categorical(logits, num_samples, dtype="int64", seed=None):
42-
raise NotImplementedError(
43-
"`categorical` is not supported with openvino backend"
44-
)
43+
dtype = dtype or "int64"
44+
ov_dtype = OPENVINO_DTYPES[dtype]
45+
logits = get_ov_output(logits)
46+
47+
zero_const = ov_opset.constant(0, Type.i32).output(0)
48+
one_const = ov_opset.constant(1, Type.i32).output(0)
49+
neg_one_const = ov_opset.constant(-1, Type.i32).output(0)
50+
51+
# Compute probabilities and cumulative sum
52+
probs = ov_opset.softmax(logits, axis=-1).output(0)
53+
cumsum_probs = ov_opset.cumsum(probs, neg_one_const).output(0)
54+
55+
# Get shape and compute batch dimensions efficiently
56+
logits_shape = ov_opset.shape_of(logits, Type.i32).output(0)
57+
rank = ov_opset.shape_of(logits_shape, Type.i32).output(0)
58+
rank_scalar = ov_opset.squeeze(rank, zero_const).output(0)
59+
rank_minus_1 = ov_opset.subtract(rank_scalar, one_const).output(0)
60+
61+
# Extract batch shape (all dimensions except last)
62+
batch_indices = ov_opset.range(
63+
zero_const, rank_minus_1, one_const, output_type=Type.i32
64+
).output(0)
65+
batch_shape = ov_opset.gather(logits_shape, batch_indices, axis=0).output(0)
66+
67+
# Create final shape [batch_dims..., num_samples]
68+
num_samples_const = ov_opset.constant([num_samples], Type.i32).output(0)
69+
final_shape = ov_opset.concat(
70+
[batch_shape, num_samples_const], axis=0
71+
).output(0)
72+
73+
seed_tensor = draw_seed(seed)
74+
if isinstance(seed_tensor, OpenVINOKerasTensor):
75+
seed1, seed2 = convert_to_numpy(seed_tensor)
76+
else:
77+
seed1, seed2 = seed_tensor.data
78+
79+
probs_dtype = probs.get_element_type()
80+
zero_float = ov_opset.constant(0.0, probs_dtype).output(0)
81+
one_float = ov_opset.constant(1.0, probs_dtype).output(0)
82+
83+
rand = ov_opset.random_uniform(
84+
final_shape, zero_float, one_float, probs_dtype, seed1, seed2
85+
).output(0)
86+
87+
rand_unsqueezed = ov_opset.unsqueeze(rand, neg_one_const).output(0)
88+
cumsum_unsqueezed = ov_opset.unsqueeze(cumsum_probs, one_const).output(0)
89+
90+
# Count how many cumulative probabilities each random number exceeds
91+
greater = ov_opset.greater(rand_unsqueezed, cumsum_unsqueezed).output(0)
92+
samples = ov_opset.reduce_sum(
93+
ov_opset.convert(greater, Type.i32).output(0), neg_one_const
94+
).output(0)
95+
96+
result = ov_opset.convert(samples, ov_dtype).output(0)
97+
return OpenVINOKerasTensor(result)
4598

4699

47100
def randint(shape, minval, maxval, dtype="int32", seed=None):

0 commit comments

Comments
 (0)