@@ -1590,9 +1590,56 @@ def pixel_shuffle(input, upscale_factor):
1590
1590
def pixel_unshuffle (input , downscale_factor ):
1591
1591
return ops .pixel_unshuffle (input , downscale_factor )
1592
1592
1593
+ def getWH (input ):
1594
+ """Get [W, H] tensor from input"""
1595
+ H , W = input .size ()[- 2 :]
1596
+ return core .tensor ([[W , H ]], dtype = core .float32 , device = input .device )
1597
+
1598
+ def center_of (input ):
1599
+ """return [(W-1)/2, (H-1)/2] tensor of input img"""
1600
+ if input .dim () == 4 :
1601
+ H , W = input .size ()[- 2 :]
1602
+ shape = [[W , H ]]
1603
+ else :
1604
+ D , H , W = input .size ()[- 3 :]
1605
+ shape = [[W , H , D ]]
1606
+ return core .tensor (shape , dtype = core .float32 , device = input .device ).sub_ (1 ).div_ (2 )
1607
+
1608
+ def u (s , a : float = - 0.75 ):
1609
+ s2 , s3 = s ** 2 , s ** 3
1610
+ l1 = (a + 2 )* s3 - (a + 3 )* s2 + 1
1611
+ l2 = a * s3 - (5 * a )* s2 + (8 * a )* s - 4 * a
1612
+ return l1 .where (s <= 1 , l2 )
1613
+
1614
+ def bicubic_grid_sample (input , grid , padding_mode : str = 'zeros' , align_corners : bool = False ):
1615
+ """bicubic_grid_sample"""
1616
+ kernel_size = 4
1617
+ if not align_corners :
1618
+ grid = grid * getWH (input ) / getWH (input ).sub_ (1 )
1619
+ center = center_of (input )
1620
+ abs_loc = ((grid + 1 ) * center ).unsqueeze (- 1 )
1621
+
1622
+ locs = abs_loc .floor () + core .tensor ([- 1 , 0 , 1 , 2 ], device = grid .device )
1623
+
1624
+ loc_w , loc_h = locs .detach ().flatten (0 , 2 ).unbind (dim = - 2 )
1625
+ loc_w = loc_w .reshape (- 1 , 1 , kernel_size ).expand (- 1 , kernel_size , - 1 )
1626
+ loc_h = loc_h .reshape (- 1 , kernel_size , 1 ).expand (- 1 , - 1 , kernel_size )
1627
+ loc_grid = core .stack ([loc_w , loc_h ], dim = - 1 )
1628
+ loc_grid = loc_grid .view (grid .size (0 ), - 1 , 1 , 2 )/ center - 1
1629
+
1630
+ selected = grid_sample (input , loc_grid .detach (), mode = 'nearest' ,
1631
+ padding_mode = padding_mode , align_corners = True )
1632
+ patch = selected .view (input .size ()[:2 ]+ grid .size ()[1 :3 ]+ (kernel_size ,)* 2 )
1633
+
1634
+ mat_r , mat_l = u (core .abs (abs_loc - locs .detach ())).unbind (dim = - 2 )
1635
+ output = core .einsum ('bhwl,bchwlr,bhwr->bchw' , mat_l , patch , mat_r )
1636
+ return output
1637
+
1593
1638
def grid_sample (input , grid , mode = 'bilinear' , padding_mode = 'zeros' , align_corners = False ):
1594
1639
align_corners = False if align_corners is None else align_corners
1595
1640
if input .ndim == 4 :
1641
+ if mode == 'bicubic' :
1642
+ return bicubic_grid_sample (input , grid , padding_mode , align_corners )
1596
1643
return execute ('grid_sampler_2d' , input , grid , mode , padding_mode , align_corners )
1597
1644
return execute ('grid_sampler_3d' , input , grid , mode , padding_mode , align_corners )
1598
1645
0 commit comments