@@ -322,11 +322,11 @@ def forward(self, input: Tensor) -> Tensor:
322
322
class _ConvTransposeNd (_ConvNd ):
323
323
def __init__ (self , in_channels , out_channels , kernel_size , stride ,
324
324
padding , dilation , transposed , output_padding ,
325
- groups , bias , padding_mode , dtype = None ) -> None :
325
+ groups , bias , padding_mode , dtype = None , device = None ) -> None :
326
326
if padding_mode != 'zeros' :
327
327
raise ValueError (f'Only "zeros" padding mode is supported for { self .__class__ .__name__ } ' )
328
328
329
- factory_kwargs = {'dtype' : dtype }
329
+ factory_kwargs = {'dtype' : dtype , 'device' : device }
330
330
super ().__init__ (
331
331
in_channels , out_channels , kernel_size , stride ,
332
332
padding , dilation , transposed , output_padding ,
@@ -426,62 +426,71 @@ class ConvTranspose1d(_ConvTransposeNd):
426
426
bias (Tensor): the learnable bias of the module of shape (out_channels)
427
427
"""
428
428
429
- def __init__ (self , in_channels , out_channels , kernel_size , stride = 1 ,
430
- padding = 0 , output_padding = 0 , groups = 1 , bias = True , dilation = 1 , padding_mode : str = 'zeros' ):
429
+ def __init__ (
430
+ self ,
431
+ in_channels : int ,
432
+ out_channels : int ,
433
+ kernel_size : _size_1_t ,
434
+ stride : _size_1_t = 1 ,
435
+ padding : _size_1_t = 0 ,
436
+ output_padding : _size_1_t = 0 ,
437
+ groups : int = 1 ,
438
+ bias : bool = True ,
439
+ dilation : _size_1_t = 1 ,
440
+ padding_mode : str = "zeros" ,
441
+ device = None ,
442
+ dtype = None ,
443
+ ) -> None :
444
+ factory_kwargs = {"device" : device , "dtype" : dtype }
431
445
kernel_size = _single (kernel_size )
432
446
stride = _single (stride )
433
447
padding = _single (padding )
434
448
dilation = _single (dilation )
435
449
output_padding = _single (output_padding )
436
- super (ConvTranspose1d , self ).__init__ (
437
- in_channels , out_channels , kernel_size , stride , padding , dilation ,
438
- True , output_padding , groups , bias , padding_mode )
439
-
440
- pad_mode = 'pad'
441
- pad = padding
442
- if isinstance (padding , tuple ):
443
- pad = (0 , 0 , padding [0 ], padding [0 ])
444
- elif isinstance (padding , int ):
445
- pad = (0 , 0 ) + (padding ,) * 2
446
- if not isinstance (padding , (int , tuple )):
447
- pad_mode = padding
448
- pad = (0 ,) * 4
449
-
450
- # cause Conv2DTranspose's out_channel refers to Conv2D's out_channel.
451
- self .conv2d_transpose = mops .Conv2DTranspose (out_channel = self .out_channels ,
452
- kernel_size = (1 ,) + self .kernel_size ,
453
- mode = 1 ,
454
- pad_mode = pad_mode ,
455
- pad = pad ,
456
- stride = (1 ,) + self .stride ,
457
- dilation = (1 ,) + self .dilation ,
458
- group = self .groups )
459
- self .h_add = _deconv_output_length (pad_mode , 1 , 1 , 1 , pad [0 ] + pad [1 ])
460
- self .w_add = _deconv_output_length (pad_mode , kernel_size [0 ], stride [0 ], dilation [0 ], pad [2 ] + pad [3 ])
461
-
462
- def forward (self , input , output_size = None ):
463
- if self .padding_mode != 'zeros' :
464
- raise ValueError ('Only `zeros` padding mode is supported for ConvTranspose2d' )
450
+ super ().__init__ (
451
+ in_channels ,
452
+ out_channels ,
453
+ kernel_size ,
454
+ stride ,
455
+ padding ,
456
+ dilation ,
457
+ True ,
458
+ output_padding ,
459
+ groups ,
460
+ bias ,
461
+ padding_mode ,
462
+ ** factory_kwargs ,
463
+ )
464
+
465
+ def forward (self , input : Tensor , output_size : Optional [list [int ]] = None ) -> Tensor :
466
+ if self .padding_mode != "zeros" :
467
+ raise ValueError (
468
+ "Only `zeros` padding mode is supported for ConvTranspose1d"
469
+ )
465
470
466
471
assert isinstance (self .padding , tuple )
467
472
# One cannot replace List by Tuple or Sequence in "_output_padding" because
468
473
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
469
474
num_spatial_dims = 1
470
475
output_padding = self ._output_padding (
471
- input , output_size , self .stride , self .padding , self .kernel_size , # type: ignore[arg-type]
472
- num_spatial_dims , self .dilation ) # type: ignore[arg-type]
473
- input = mops .expand_dims (input , 2 )
474
- n , _ , h , w = input .shape
475
- conv2d_trans_ret = self .conv2d_transpose (input , self .weight .expand_dims (2 ),
476
- (n , self .out_channels ,
477
- h + self .h_add ,
478
- w * self .stride [0 ] + self .w_add ))
479
- if self .bias is not None :
480
- conv2d_trans_ret = mops .bias_add (conv2d_trans_ret , self .bias )
481
-
482
- conv2d_trans_ret = conv2d_trans_ret .squeeze (2 )
483
- conv2d_trans_ret = ops .pad (conv2d_trans_ret , (0 ,) + output_padding , value = 0. )
484
- return conv2d_trans_ret
476
+ input ,
477
+ output_size ,
478
+ self .stride , # type: ignore[arg-type]
479
+ self .padding , # type: ignore[arg-type]
480
+ self .kernel_size , # type: ignore[arg-type]
481
+ num_spatial_dims ,
482
+ self .dilation , # type: ignore[arg-type]
483
+ )
484
+ return F .conv_transpose1d (
485
+ input ,
486
+ self .weight ,
487
+ self .bias ,
488
+ self .stride ,
489
+ self .padding ,
490
+ output_padding ,
491
+ self .groups ,
492
+ self .dilation ,
493
+ )
485
494
486
495
487
496
def _deconv_output_length (pad_mode , filter_size , stride_size , dilation_size , padding ):
@@ -582,66 +591,80 @@ class ConvTranspose2d(_ConvTransposeNd):
582
591
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
583
592
"""
584
593
585
- def __init__ (self , in_channels , out_channels , kernel_size , stride = 1 ,
586
- padding = 0 , output_padding = 0 , groups = 1 , bias = True , dilation = 1 ,
587
- padding_mode = 'zeros' , dtype = None ):
588
- factory_kwargs = {'dtype' : dtype }
594
+ def __init__ (
595
+ self ,
596
+ in_channels : int ,
597
+ out_channels : int ,
598
+ kernel_size : _size_2_t ,
599
+ stride : _size_2_t = 1 ,
600
+ padding : _size_2_t = 0 ,
601
+ output_padding : _size_2_t = 0 ,
602
+ groups : int = 1 ,
603
+ bias : bool = True ,
604
+ dilation : _size_2_t = 1 ,
605
+ padding_mode : str = "zeros" ,
606
+ device = None ,
607
+ dtype = None ,
608
+ ) -> None :
609
+ factory_kwargs = {"device" : device , "dtype" : dtype }
589
610
kernel_size = _pair (kernel_size )
590
611
stride = _pair (stride )
591
612
padding = _pair (padding )
592
613
dilation = _pair (dilation )
593
614
output_padding = _pair (output_padding )
594
615
super ().__init__ (
595
- in_channels , out_channels , kernel_size , stride , padding , dilation ,
596
- True , output_padding , groups , bias , padding_mode , ** factory_kwargs )
597
-
598
- pad_mode = 'pad'
599
- pad = padding
600
- if isinstance (padding , tuple ):
601
- pad = (padding [0 ], padding [0 ], padding [1 ], padding [1 ])
602
- elif isinstance (padding , int ):
603
- pad = (padding ,) * 4
604
- if not isinstance (padding , (int , tuple )):
605
- pad_mode = padding
606
- pad = (0 ,) * 4
607
-
608
- # cause Conv2DTranspose's out_channel refers to Conv2D's out_channel.
609
- self .conv2d_transpose = mops .Conv2DTranspose (out_channel = in_channels ,
610
- kernel_size = kernel_size ,
611
- mode = 1 ,
612
- pad_mode = pad_mode ,
613
- pad = pad ,
614
- stride = stride ,
615
- dilation = dilation ,
616
- group = groups )
617
-
618
- self .h_add = _deconv_output_length (pad_mode , kernel_size [0 ], stride [0 ], dilation [0 ], pad [0 ] + pad [1 ])
619
- self .w_add = _deconv_output_length (pad_mode , kernel_size [1 ], stride [1 ], dilation [1 ], pad [2 ] + pad [3 ])
620
-
621
- def forward (self , input , output_size = None ):
622
- if self .padding_mode != 'zeros' :
623
- raise ValueError ('Only `zeros` padding mode is supported for ConvTranspose2d' )
616
+ in_channels ,
617
+ out_channels ,
618
+ kernel_size ,
619
+ stride ,
620
+ padding ,
621
+ dilation ,
622
+ True ,
623
+ output_padding ,
624
+ groups ,
625
+ bias ,
626
+ padding_mode ,
627
+ ** factory_kwargs ,
628
+ )
629
+
630
+ def forward (self , input : Tensor , output_size : Optional [list [int ]] = None ) -> Tensor :
631
+ """
632
+ Performs the forward pass.
633
+
634
+ Attributes:
635
+ input (Tensor): The input tensor.
636
+ output_size (list[int], optional): A list of integers representing
637
+ the size of the output tensor. Default is None.
638
+ """
639
+ if self .padding_mode != "zeros" :
640
+ raise ValueError (
641
+ "Only `zeros` padding mode is supported for ConvTranspose2d"
642
+ )
624
643
625
644
assert isinstance (self .padding , tuple )
626
645
# One cannot replace List by Tuple or Sequence in "_output_padding" because
627
646
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
628
647
num_spatial_dims = 2
629
648
output_padding = self ._output_padding (
630
- input , output_size , self .stride , self .padding , self .kernel_size , # type: ignore[arg-type]
631
- num_spatial_dims , self .dilation ) # type: ignore[arg-type]
632
-
633
- n , _ , h , w = input .shape
634
- conv2d_trans_ret = self .conv2d_transpose (input , self .weight ,
635
- (n , self .out_channels ,
636
- h * self .stride [0 ] + self .h_add ,
637
- w * self .stride [1 ] + self .w_add ))
638
- if self .bias is not None :
639
- conv2d_trans_ret = mops .bias_add (conv2d_trans_ret , self .bias )
640
-
641
- conv2d_trans_ret = ops .pad (conv2d_trans_ret , output_padding , value = 0. )
642
-
643
- return conv2d_trans_ret
649
+ input ,
650
+ output_size ,
651
+ self .stride , # type: ignore[arg-type]
652
+ self .padding , # type: ignore[arg-type]
653
+ self .kernel_size , # type: ignore[arg-type]
654
+ num_spatial_dims ,
655
+ self .dilation , # type: ignore[arg-type]
656
+ )
644
657
658
+ return F .conv_transpose2d (
659
+ input ,
660
+ self .weight ,
661
+ self .bias ,
662
+ self .stride ,
663
+ self .padding ,
664
+ output_padding ,
665
+ self .groups ,
666
+ self .dilation ,
667
+ )
645
668
646
669
# class ConvTranspose3d(_ConvTransposeNd):
647
670
# r"""Applies a 3D transposed convolution operator over an input image composed of several input
0 commit comments