 from typing import Tuple
 from torch import Tensor
 
-from .cextension import COMPILED_WITH_CUDA, lib
+from .cextension import COMPILED_WITH_CUDA, lib, HIP_ENVIRONMENT
 
 # Remark: for AMD GPU we need to disable blocksize == 64
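The new HIP_ENVIRONMENT flag gates every ROCm-specific branch in this patch. Its actual definition lives in .cextension and is not shown in this diff; a minimal sketch of how such a flag can be derived, assuming a ROCm build of PyTorch (where torch.version.hip is a version string rather than None):

    import torch

    # Hypothetical stand-in for the flag imported from .cextension above:
    # ROCm builds of PyTorch expose a HIP version string, CUDA builds expose None.
    HIP_ENVIRONMENT = getattr(torch.version, "hip", None) is not None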
@@ -458,7 +458,11 @@ def get_transform_buffer(
     state = (shape[::-1], to_order)
 
     if to_order == "row" or to_order == "col":
-        return init_func(shape, dtype=dtype, device=device), state
+        if HIP_ENVIRONMENT and to_order == "col":
+            # row to col transformation transposes the output shape, so change the buffer allocation accordingly
+            return init_func(shape[::-1], dtype=dtype, device=device), state
+        else:
+            return init_func(shape, dtype=dtype, device=device), state
     elif to_order == "col32":
         # blocks of 32 columns (padded)
         cols = 32 * ((cols + 31) // 32)
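The hunk above only changes the allocation: under HIP a row-to-col transform writes its output transposed, so the buffer must be allocated with the flipped shape while state keeps recording (shape[::-1], to_order) as before. A toy illustration of the shape bookkeeping, with torch.empty standing in for init_func:

    import torch

    shape = (4, 16)                       # logical (rows, cols) of the source tensor
    state = (shape[::-1], "col")          # state stores the transposed shape plus target order

    buf_cuda = torch.empty(shape)         # CUDA path: buffer keeps the original shape
    buf_hip = torch.empty(shape[::-1])    # HIP path: buffer is allocated transposed
    print(buf_cuda.shape, buf_hip.shape)  # torch.Size([4, 16]) torch.Size([16, 4])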
@@ -486,6 +490,10 @@ def nvidia_transform(
     state=None,
     ld=None,
 ):
+    if HIP_ENVIRONMENT:
+        to_order = "col" if to_order in ["col32", "col_turing", "col_ampere"] else to_order
+        from_order = "col" if from_order in ["col32", "col_turing", "col_ampere"] else from_order
+
     if state is None:
         state = (A.shape, from_order)
     else:
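ROCm has no equivalent of the tiled CUDA layouts, so any of col32, col_turing, or col_ampere collapses to plain column-major before the rest of nvidia_transform runs. The same normalization as a standalone helper, for clarity only (the name normalize_order is ours, not library API):

    _TILED_ORDERS = {"col32", "col_turing", "col_ampere"}

    def normalize_order(order: str) -> str:
        # Map any CUDA tiled layout onto plain column-major for HIP.
        return "col" if order in _TILED_ORDERS else order

    assert normalize_order("col32") == "col"
    assert normalize_order("row") == "row"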
@@ -1715,23 +1723,38 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
         return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
 
     if dimsA == 2 and out is None:
-        out, Sout = get_transform_buffer(
-            (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
-        )
+        if HIP_ENVIRONMENT:
+            out, Sout = get_transform_buffer(
+                (shapeA[0], shapeB[0]), dtype, A.device, "col", "row"
+            )
+        else:
+            out, Sout = get_transform_buffer(
+                (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
+            )
     elif dimsA == 3 and out is None:
-        out, Sout = get_transform_buffer(
-            (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
-        )
+        if HIP_ENVIRONMENT:
+            out, Sout = get_transform_buffer(
+                (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col", "row"
+            )
+        else:
+            out, Sout = get_transform_buffer(
+                (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
+            )
 
     assert dimsB != 3, "len(B.shape)==3 not supported"
     assert A.device.type == "cuda"
     assert B.device.type == "cuda"
     assert A.dtype == torch.int8
     assert B.dtype == torch.int8
     assert out.dtype == dtype
-    assert SA[1] == "col32"
-    assert SB[1] in ["col_turing", "col_ampere"]
-    assert Sout[1] == "col32"
+    if HIP_ENVIRONMENT:
+        assert SA[1] == "col"
+        assert SB[1] == "col"
+        assert Sout[1] == "col"
+    else:
+        assert SA[1] == "col32"
+        assert SB[1] in ["col_turing", "col_ampere"]
+        assert Sout[1] == "col32"
     assert (
         shapeA[-1] == shapeB[-1]
     ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
@@ -1745,25 +1768,29 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
     ptrC = get_ptr(out)
 
     k = shapeA[-1]
-    lda = ct.c_int32(m * 32)
-    if formatB == "col_turing":
-        # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
-        # n = rows
-        ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
+    if HIP_ENVIRONMENT:
+        lda = ct.c_int32(m)
+        ldb = ct.c_int32(shapeB[0])
+        ldc = ct.c_int32(m)
     else:
-        # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
-        # n = rows
-        ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
-
-    ldc = ct.c_int32(m * 32)
+        lda = ct.c_int32(m * 32)
+        if formatB == "col_turing":
+            # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
+            # n = rows
+            ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
+        else:
+            # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
+            # n = rows
+            ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
+        ldc = ct.c_int32(m * 32)
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)
 
     has_error = 0
     ptrRowScale = get_ptr(None)
     is_on_gpu([A, B, out])
-    if formatB == 'col_turing':
+    if formatB == 'col_turing' or HIP_ENVIRONMENT:
         if dtype == torch.int32:
             has_error = lib.cigemmlt_turing_32(
                 ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
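The leading dimensions mirror the physical layouts: plain column-major under HIP (lda = m, ldb = shapeB[0], ldc = m), versus the CUDA tiled formats, where rows are padded up to the tile height and then multiplied by the 32-column tile width. A worked check of the padding arithmetic for rows = 13:

    rows = 13

    # turing: pad rows to a multiple of 8; tiles are 8 rows x 32 columns
    ldb_turing = ((rows + 7) // 8) * 8 * 32     # (20 // 8) * 8 * 32 = 2 * 8 * 32 = 512

    # ampere: pad rows to a multiple of 32; tiles are 32 rows x 32 columns
    ldb_ampere = ((rows + 31) // 32) * 32 * 32  # (44 // 32) * 32 * 32 = 1 * 32 * 32 = 1024

    print(ldb_turing, ldb_ampere)  # 512 1024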
@@ -2072,6 +2099,9 @@ def double_quant(
 
 
 def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+    if HIP_ENVIRONMENT:
+        return nvidia_transform(A, to_order, from_order, out, transpose, state, ld)
+
     prev_device = pre_call(A.device)
     if state is None: state = (A.shape, from_order)
     else: from_order = state[1]
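With this early return, every transform call on ROCm is routed through nvidia_transform, which (per the hunk at line 490 above) first collapses the tiled orders to plain col, so callers can keep passing CUDA layout names unchanged. A hedged usage sketch, assuming a ROCm build where HIP_ENVIRONMENT is True (ROCm PyTorch still uses the "cuda" device string):

    import torch
    from bitsandbytes.functional import transform

    A = torch.randint(-128, 127, (8, 32), dtype=torch.int8, device="cuda")

    # Caller passes the CUDA layout name; under HIP the chain is
    # transform -> nvidia_transform, which rewrites "col32" to "col".
    out, Sout = transform(A, to_order="col32")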