Commit 7198d6b

Merge pull request bitsandbytes-foundation#3 from pnunna93/fix_igemmlt_int

Enable igemmlt int test on rocm

2 parents 71bf2df + 42b860f

4 files changed: +71 -35 lines changed


bitsandbytes/functional.py

Lines changed: 52 additions & 22 deletions
```diff
@@ -16,7 +16,7 @@
 from typing import Tuple
 from torch import Tensor
 
-from .cextension import COMPILED_WITH_CUDA, lib
+from .cextension import COMPILED_WITH_CUDA, lib, HIP_ENVIRONMENT
 
 # Remark: for AMD GPU we need to disable blocksize == 64
 
```
```diff
@@ -458,7 +458,11 @@ def get_transform_buffer(
         state = (shape[::-1], to_order)
 
     if to_order == "row" or to_order == "col":
-        return init_func(shape, dtype=dtype, device=device), state
+        if HIP_ENVIRONMENT and to_order == "col":
+            # row to col transformation transposes output shape, so change buffer allocation accordingly
+            return init_func(shape[::-1], dtype=dtype, device=device), state
+        else:
+            return init_func(shape, dtype=dtype, device=device), state
     elif to_order == "col32":
         # blocks of 32 columns (padded)
         cols = 32 * ((cols + 31) // 32)
```
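Why the reversed shape: on ROCm the only supported layout change here is row to col, which is effectively a transpose, so the output buffer must swap its dimensions. A minimal standalone sketch of that allocation rule (the helper name is hypothetical, not part of the patch):

```python
def hip_buffer_shape(shape, to_order, hip_environment=True):
    # On HIP, a "col" target is the transpose of the row-major input,
    # so the allocated buffer reverses the dimensions; "row" is unchanged.
    if hip_environment and to_order == "col":
        return tuple(reversed(shape))
    return tuple(shape)

assert hip_buffer_shape((4, 8), "col") == (8, 4)   # transposed allocation on HIP
assert hip_buffer_shape((4, 8), "row") == (4, 8)   # row target keeps the shape
```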
```diff
@@ -486,6 +490,10 @@ def nvidia_transform(
     state=None,
     ld=None,
 ):
+    if HIP_ENVIRONMENT:
+        to_order = "col" if to_order in ["col32","col_turing","col_ampere"] else to_order
+        from_order = "col" if from_order in ["col32","col_turing","col_ampere"] else from_order
+
     if state is None:
         state = (A.shape, from_order)
     else:
```
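The remap at the top of nvidia_transform means any caller requesting the CUDA-only tiled layouts silently falls back to plain column-major on ROCm. A tiny sketch of the same mapping (hypothetical helper, same order list as the diff):

```python
def remap_hip_order(order, hip_environment):
    # ROCm has no col32/col_turing/col_ampere tilings; all collapse to "col".
    tiled = ["col32", "col_turing", "col_ampere"]
    return "col" if hip_environment and order in tiled else order

assert remap_hip_order("col32", True) == "col"
assert remap_hip_order("col_ampere", True) == "col"
assert remap_hip_order("col32", False) == "col32"  # CUDA path untouched
```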
```diff
@@ -1715,23 +1723,38 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
         return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
 
     if dimsA == 2 and out is None:
-        out, Sout = get_transform_buffer(
-            (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
-        )
+        if HIP_ENVIRONMENT:
+            out, Sout = get_transform_buffer(
+                (shapeA[0], shapeB[0]), dtype, A.device, "col", "row"
+            )
+        else:
+            out, Sout = get_transform_buffer(
+                (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
+            )
     elif dimsA == 3 and out is None:
-        out, Sout = get_transform_buffer(
-            (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
-        )
+        if HIP_ENVIRONMENT:
+            out, Sout = get_transform_buffer(
+                (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col", "row"
+            )
+        else:
+            out, Sout = get_transform_buffer(
+                (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
+            )
 
     assert dimsB != 3, "len(B.shape)==3 not supported"
     assert A.device.type == "cuda"
     assert B.device.type == "cuda"
     assert A.dtype == torch.int8
     assert B.dtype == torch.int8
     assert out.dtype == dtype
-    assert SA[1] == "col32"
-    assert SB[1] in ["col_turing", "col_ampere"]
-    assert Sout[1] == "col32"
+    if HIP_ENVIRONMENT:
+        assert SA[1] == "col"
+        assert SB[1] == "col"
+        assert Sout[1] == "col"
+    else:
+        assert SA[1] == "col32"
+        assert SB[1] in ["col_turing", "col_ampere"]
+        assert Sout[1] == "col32"
     assert (
         shapeA[-1] == shapeB[-1]
     ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
```
```diff
@@ -1745,25 +1768,29 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
     ptrC = get_ptr(out)
 
     k = shapeA[-1]
-    lda = ct.c_int32(m * 32)
-    if formatB == "col_turing":
-        # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
-        # n = rows
-        ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
+    if HIP_ENVIRONMENT:
+        lda = ct.c_int32(m)
+        ldb = ct.c_int32(shapeB[0])
+        ldc = ct.c_int32(m)
     else:
-        # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
-        # n = rows
-        ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
-
-    ldc = ct.c_int32(m * 32)
+        lda = ct.c_int32(m * 32)
+        if formatB == "col_turing":
+            # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
+            # n = rows
+            ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
+        else:
+            # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
+            # n = rows
+            ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
+        ldc = ct.c_int32(m * 32)
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)
 
     has_error = 0
     ptrRowScale = get_ptr(None)
     is_on_gpu([A, B, out])
-    if formatB == 'col_turing':
+    if formatB == 'col_turing' or HIP_ENVIRONMENT:
         if dtype == torch.int32:
             has_error = lib.cigemmlt_turing_32(
                 ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
```
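The leading-dimension arithmetic is the core of this hunk: on the HIP path the matrices are plain column-major "col" layouts, so lda/ldc are just m and ldb is B's row count, while the CUDA path keeps the 32-column tile padding. A pure-Python rendering of both computations (no ctypes; names mirror the diff):

```python
def leading_dims(m, rows, shape_b0, format_b, hip_environment):
    """Return (lda, ldb, ldc) as plain ints, mirroring the patched igemmlt."""
    if hip_environment:
        # plain column-major: leading dimension == number of rows
        return m, shape_b0, m
    lda = ldc = m * 32                         # col32: 32-column tiles
    if format_b == "col_turing":
        ldb = ((rows + 7) // 8) * 8 * 32       # rows padded to a multiple of 8
    else:
        ldb = ((rows + 31) // 32) * 32 * 32    # ampere: padded to a multiple of 32
    return lda, ldb, ldc

# e.g. m=64 and B with 100 rows: turing pads 100 -> 104 rows
assert leading_dims(64, 100, 100, "col_turing", False) == (2048, 3328, 2048)
assert leading_dims(64, 100, 100, "col_turing", True) == (64, 100, 64)
```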
```diff
@@ -2072,6 +2099,9 @@ def double_quant(
 
 
 def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+    if HIP_ENVIRONMENT:
+        return nvidia_transform(A,to_order,from_order,out,transpose,state,ld)
+
     prev_device = pre_call(A.device)
     if state is None: state = (A.shape, from_order)
     else: from_order = state[1]
```
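On ROCm, transform now simply forwards to nvidia_transform, so existing call sites (including the tests below) keep working unchanged; combined with the order remap above, a "col32" request transparently becomes a "col" layout. A hedged usage sketch, assuming a build where HIP_ENVIRONMENT is True:

```python
import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 127, (32, 64), device="cuda", dtype=torch.int8)
# On HIP this call routes through nvidia_transform and "col32" is remapped to "col".
A2, SA = F.transform(A, "col32")
expected = "col" if F.HIP_ENVIRONMENT else "col32"
assert SA[1] == expected
```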

csrc/ops.hip

Lines changed: 2 additions & 0 deletions
```diff
@@ -431,6 +431,8 @@ template void transform<int8_t, ROW, COL_TURING, false, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
 template void transform<int8_t, ROW, COL_AMPERE, false, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
 template void transform<int8_t, COL32, ROW, false, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
 template void transform<int32_t, COL32, ROW, false, 32>(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
+template void transform<int8_t, COL, ROW, false, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int32_t, COL, ROW, false, 32>(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
 #endif
 static std::string hipError_to_string(const hipError_t ret)
 {
```

csrc/pythonInterface.c

Lines changed: 4 additions & 0 deletions
```diff
@@ -145,6 +145,8 @@ MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
 MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
 MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
 MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);
+MAKE_FUNC_TRANSFORM(8, col, row, n, int8_t, COL, ROW, false, 8);
+MAKE_FUNC_TRANSFORM(32, col, row, n, int32_t, COL, ROW, false, 32);
 #endif
 
 void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 0>(A, out, rows, cols); }
```
```diff
@@ -381,6 +383,8 @@ extern "C"
 MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8)
 MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
 MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
+MAKE_FUNC_CTRANSFORM(8, col, row, n, int8_t, COL, ROW, false, 8)
+MAKE_FUNC_CTRANSFORM(32, col, row, n, int32_t, COL, ROW, false, 32)
 #endif
 void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols)
 { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); }
```

tests/test_functional.py

Lines changed: 13 additions & 13 deletions
```diff
@@ -594,15 +594,14 @@ def test_vector_quant(dim1, dim2, dim3):
 # dim1, dim2 = (256,), (256,)
 dtype = [torch.int8, torch.int32]
 a_order = ["row"]
-out_order = ["col", "row", "col32"]
+out_order = ["col", "row"] if HIP_ENVIRONMENT else ["col", "row", "col32"]
 transpose = [False]
 dims = [2, 3]
 values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
 
 names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]
 
 
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 @pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
 def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
     if dims == 3 and out_order != "col32":
```
```diff
@@ -686,7 +685,6 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
     for vals in values
 ]
 
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 @pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
 def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
     for i in range(k):
```
```diff
@@ -709,16 +707,19 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
         C3, S = F.nvidia_transform(C2, "row", state=SC)
         torch.testing.assert_close(C1, C3.float())
 
-        # transpose
-        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
-            torch.int8
-        )
-        C1 = torch.matmul(A.float(), B.float())
+        # Since ROCm supports row to col transformation only which is same as transpose,
+        # skipping this for HIP environment
+        if not HIP_ENVIRONMENT:
+            ## transpose
+            B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
+                torch.int8
+            )
+            C1 = torch.matmul(A.float(), B.float())
 
-        B2t, SBt = F.transform(B, "col_turing", transpose=True)
-        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
-        C3, S = F.nvidia_transform(C2, "row", state=SC)
-        torch.testing.assert_close(C1, C3.float())
+            B2t, SBt = F.transform(B, "col_turing", transpose=True)
+            C2, SC = F.igemmlt(A2, B2t, SA, SBt)
+            C3, S = F.nvidia_transform(C2, "row", state=SC)
+            torch.testing.assert_close(C1, C3.float())
 
 
 dim1 = [32]
```
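With the skipif markers removed, test_igemmlt_int now exercises its non-transpose path on ROCm end to end. A condensed sketch of that flow under the diff's conventions (sizes chosen arbitrarily; igemmlt computes A @ B^T):

```python
import torch
import bitsandbytes.functional as F

# Condensed non-transpose flow from test_igemmlt_int (arbitrary sizes).
A = torch.randint(-128, 127, (32, 64), device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, (48, 64), device="cuda", dtype=torch.int8)
C1 = torch.matmul(A.float(), B.t().float())      # float reference for A @ B^T

A2, SA = F.transform(A, "col32")                 # "col" on HIP after remapping
B2, SB = F.transform(B, "col_turing")            # likewise remapped to "col"
C2, SC = F.igemmlt(A2, B2, SA, SB)               # int32 accumulation
C3, _ = F.nvidia_transform(C2, "row", state=SC)  # back to row-major
torch.testing.assert_close(C1, C3.float())
```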
```diff
@@ -734,7 +735,6 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
     for vals in values
 ]
 
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 @pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
 def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
     formatB = F.get_special_format_str()
```
