@@ -1442,10 +1442,119 @@ def take(x, indices, axis=None):
1442
1442
1443
1443
1444
1444
def take_along_axis (x , indices , axis = None ):
1445
- raise NotImplementedError (
1446
- "`take_along_axis` is not supported with openvino backend"
1445
+ x = get_ov_output (x )
1446
+ indices = get_ov_output (indices )
1447
+
1448
+ if axis is None :
1449
+ target_shape = ov_opset .constant ([- 1 ], dtype = Type .i32 ).output (0 )
1450
+ x_flat = ov_opset .reshape (x , target_shape , False ).output (0 )
1451
+ indices_flat = ov_opset .reshape (indices , target_shape , False ).output (0 )
1452
+ result = ov_opset .gather_elements (x_flat , indices_flat , 0 ).output (0 )
1453
+ return OpenVINOKerasTensor (result )
1454
+
1455
+ x_rank = len (x .get_partial_shape ())
1456
+ if axis < 0 :
1457
+ axis += x_rank
1458
+
1459
+ x_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1460
+ indices_shape = ov_opset .shape_of (indices , Type .i32 ).output (0 )
1461
+
1462
+ # fix negative indices by adding dimension size
1463
+ axis_index = ov_opset .constant ([axis ], dtype = Type .i32 ).output (0 )
1464
+ zero_const = ov_opset .constant (0 , dtype = Type .i32 ).output (0 )
1465
+ dim_size = ov_opset .gather (x_shape , axis_index , zero_const ).output (0 )
1466
+ dim_size = ov_opset .squeeze (dim_size , zero_const ).output (0 )
1467
+
1468
+ zero_scalar = ov_opset .constant (0 , indices .get_element_type ()).output (0 )
1469
+ is_neg = ov_opset .less (indices , zero_scalar ).output (0 )
1470
+ dim_size_cast = ov_opset .convert (
1471
+ dim_size , indices .get_element_type ()
1472
+ ).output (0 )
1473
+ adjusted_indices = ov_opset .add (indices , dim_size_cast ).output (0 )
1474
+ indices = ov_opset .select (is_neg , adjusted_indices , indices ).output (0 )
1475
+
1476
+ indices = ov_opset .convert (indices , Type .i32 ).output (0 )
1477
+
1478
+ one_const = ov_opset .constant (1 , dtype = Type .i32 ).output (0 )
1479
+
1480
+ # Create modified shapes with axis dimension set to 1
1481
+ x_shape_modified = []
1482
+ indices_shape_modified = []
1483
+
1484
+ for i in range (x_rank ):
1485
+ dim_index = ov_opset .constant ([i ], dtype = Type .i32 ).output (0 )
1486
+ if i == axis :
1487
+ x_shape_modified .append (
1488
+ ov_opset .unsqueeze (one_const , zero_const ).output (0 )
1489
+ )
1490
+ indices_shape_modified .append (
1491
+ ov_opset .unsqueeze (one_const , zero_const ).output (0 )
1492
+ )
1493
+ else :
1494
+ x_dim = ov_opset .gather (x_shape , dim_index , zero_const ).output (0 )
1495
+ indices_dim = ov_opset .gather (
1496
+ indices_shape , dim_index , zero_const
1497
+ ).output (0 )
1498
+ x_shape_modified .append (x_dim )
1499
+ indices_shape_modified .append (indices_dim )
1500
+
1501
+ x_shape_mod = ov_opset .concat (x_shape_modified , axis = 0 ).output (0 )
1502
+ indices_shape_mod = ov_opset .concat (indices_shape_modified , axis = 0 ).output (
1503
+ 0
1447
1504
)
1448
1505
1506
+ # Compute broadcast shape (maximum of each dimension)
1507
+ broadcast_shape_parts = []
1508
+ for i in range (x_rank ):
1509
+ dim_index = ov_opset .constant ([i ], dtype = Type .i32 ).output (0 )
1510
+ x_dim = ov_opset .gather (x_shape_mod , dim_index , zero_const ).output (0 )
1511
+ indices_dim = ov_opset .gather (
1512
+ indices_shape_mod , dim_index , zero_const
1513
+ ).output (0 )
1514
+ max_dim = ov_opset .maximum (x_dim , indices_dim ).output (0 )
1515
+ broadcast_shape_parts .append (max_dim )
1516
+
1517
+ broadcast_shape = ov_opset .concat (broadcast_shape_parts , axis = 0 ).output (0 )
1518
+
1519
+ # Create target shapes: broadcast shape but with original axis dimensions
1520
+ x_target_shape_parts = []
1521
+ indices_target_shape_parts = []
1522
+
1523
+ for i in range (x_rank ):
1524
+ dim_index = ov_opset .constant ([i ], dtype = Type .i32 ).output (0 )
1525
+ if i == axis :
1526
+ x_orig_dim = ov_opset .gather (x_shape , dim_index , zero_const ).output (
1527
+ 0
1528
+ )
1529
+ indices_orig_dim = ov_opset .gather (
1530
+ indices_shape , dim_index , zero_const
1531
+ ).output (0 )
1532
+ x_target_shape_parts .append (x_orig_dim )
1533
+ indices_target_shape_parts .append (indices_orig_dim )
1534
+ else :
1535
+ broadcast_dim = ov_opset .gather (
1536
+ broadcast_shape , dim_index , zero_const
1537
+ ).output (0 )
1538
+ x_target_shape_parts .append (broadcast_dim )
1539
+ indices_target_shape_parts .append (broadcast_dim )
1540
+
1541
+ x_target_shape = ov_opset .concat (x_target_shape_parts , axis = 0 ).output (0 )
1542
+ indices_target_shape = ov_opset .concat (
1543
+ indices_target_shape_parts , axis = 0
1544
+ ).output (0 )
1545
+
1546
+ # Broadcast to target shapes
1547
+ x_broadcasted = ov_opset .broadcast (x , x_target_shape ).output (0 )
1548
+ indices_broadcasted = ov_opset .broadcast (
1549
+ indices , indices_target_shape
1550
+ ).output (0 )
1551
+
1552
+ # Use gather_elements for element-wise selection
1553
+ result = ov_opset .gather_elements (
1554
+ x_broadcasted , indices_broadcasted , axis
1555
+ ).output (0 )
1556
+ return OpenVINOKerasTensor (result )
1557
+
1449
1558
1450
1559
def tan (x ):
1451
1560
x = get_ov_output (x )
0 commit comments