Cublaslt fp8 #450
base: main
Conversation
Code Review

This pull request adds support for fp8 matrix multiplication using `cublaslt`. The changes include updates to `Cargo.toml` for new features and dependencies, and new `DeviceRepr` implementations for fp8/fp4 types. The core logic is in `src/cublaslt/safe.rs`.

I've identified several critical issues in the implementation: problems with how scale pointers are handled in `set_fp8_scale`, incorrect hardcoded data types and compute types within the `Fp8Matmul` trait that will not work for `bf16`, and incorrect dimensions and leading dimensions in the new test case. These issues are likely to cause runtime errors or produce incorrect results and should be addressed.
src/cublaslt/safe.rs
Outdated
```rust
fn set_fp8_scale(
    &self,
    scale_ptr: &impl DevicePtr<f32>,
    scale_mode: ScaleMode,
    matrix: Matrix,
) -> Result<(), CublasError> {
    let scale_ptr_attr = match matrix {
        Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
        Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
        Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER,
        Matrix::D => {
            return Err(CublasError(
                sys::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE,
            ))
        }
    };

    let scale_mode_attr = match matrix {
        Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
        Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
        Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_MODE,
        Matrix::D => {
            return Err(CublasError(
                sys::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE,
            ))
        }
    };

    let scale_mode = match scale_mode {
        ScaleMode::Scalar32f => {
            sys::cublasLtMatmulMatrixScale_t::CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F
        }
        ScaleMode::RowWise32f => {
            sys::cublasLtMatmulMatrixScale_t::CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
        }
    };

    let scale_elems = 1;

    unsafe {
        result::set_matmul_desc_attribute(
            self.handle,
            scale_ptr_attr,
            scale_ptr as *const _ as *const _,
            mem::size_of::<f32>() * scale_elems,
        )?;

        result::set_matmul_desc_attribute(
            self.handle,
            scale_mode_attr,
            (&scale_mode) as *const _ as *const _,
            mem::size_of::<sys::cublasLtMatmulMatrixScale_t>(),
        )?;
    }
    Ok(())
}
```
The implementation of `set_fp8_scale` has a critical bug in how it handles `scale_ptr`. The cast `scale_ptr as *const _ as *const _` is incorrect for an `&impl DevicePtr<f32>` argument: it yields the address of the wrapper, not the device pointer. Additionally, the size passed to `set_matmul_desc_attribute` for the scale pointer is `mem::size_of::<f32>()`, but it should be `mem::size_of::<sys::CUdeviceptr>()` since you are setting a pointer attribute.

A possible fix is to change the function signature to accept a `&sys::CUdeviceptr` and have the caller, `fp8_matmul`, extract the pointer first using `.device_ptr(stream)`.
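To see why the cast is wrong, here is a minimal host-only sketch. `DevPtr` is a hypothetical stand-in for a `DevicePtr<f32>` wrapper (the real cudarc type is assumed); it shows that casting the wrapper reference yields the wrapper's own address rather than the device pointer inside it, and that the attribute expects pointer-sized (8-byte) data:

```rust
// Hypothetical stand-in for a DevicePtr<f32> wrapper; the real cudarc
// type is assumed. This only illustrates the pointer-cast pitfall.
struct DevPtr {
    device_ptr: u64, // the CUdeviceptr value held on the host
    _len: usize,
}

fn main() {
    let p = DevPtr { device_ptr: 0xDEAD_BEEF, _len: 4 };

    // Casting the wrapper reference yields the wrapper's host address,
    // not the device pointer stored inside it:
    let wrong = &p as *const DevPtr as u64;
    assert_ne!(wrong, p.device_ptr);

    // cublasLt expects the address of a CUdeviceptr-sized value that holds
    // the device pointer, written with size_of::<CUdeviceptr>() bytes:
    let right = &p.device_ptr as *const u64;
    assert_eq!(unsafe { *right }, 0xDEAD_BEEF);
    assert_eq!(std::mem::size_of::<u64>(), 8); // not size_of::<f32>() == 4
    println!("ok");
}
```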
src/cublaslt/safe.rs
Outdated
```rust
fn c_matrix_type() -> sys::cudaDataType {
    sys::cudaDataType::CUDA_R_16F
}

fn d_matrix_type() -> sys::cudaDataType {
    sys::cudaDataType::CUDA_R_16F
}

fn compute_type() -> sys::cublasComputeType_t {
    sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
}
```
The `c_matrix_type`, `d_matrix_type`, and `compute_type` functions are not correctly defined for the `Fp8Matmul<T>` trait.

- `c_matrix_type` and `d_matrix_type` are hardcoded to return `CUDA_R_16F`. This is incorrect when `T` is `half::bf16`, which should use `CUDA_R_16BF`. These should be declared as associated functions on the trait without a default implementation, and then implemented for each specific type (`f16`, `bf16`).
- `compute_type` is hardcoded to `CUBLAS_COMPUTE_16F`. According to the cuBLASLt documentation, FP8 GEMMs require the compute type to be `CUBLAS_COMPUTE_32F`. Using `CUBLAS_COMPUTE_16F` may lead to errors or incorrect results.
A correct implementation would look like this:
```rust
pub trait Fp8Matmul<T>: MatmulShared {
    // ...
    fn c_matrix_type() -> sys::cudaDataType;
    fn d_matrix_type() -> sys::cudaDataType;
    fn compute_type() -> sys::cublasComputeType_t {
        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
    }
    // ...
}

impl Fp8Matmul<half::f16> for CudaBlasLT {
    fn c_matrix_type() -> sys::cudaDataType {
        sys::cudaDataType::CUDA_R_16F
    }
    fn d_matrix_type() -> sys::cudaDataType {
        sys::cudaDataType::CUDA_R_16F
    }
}

impl Fp8Matmul<half::bf16> for CudaBlasLT {
    fn c_matrix_type() -> sys::cudaDataType {
        sys::cudaDataType::CUDA_R_16BF
    }
    fn d_matrix_type() -> sys::cudaDataType {
        sys::cudaDataType::CUDA_R_16BF
    }
}
```
src/cublaslt/safe.rs
Outdated
```rust
MatmulConfig {
    transa: false,
    transb: false,
    transc: false,
    m: N as u64,
    n: M as u64,
    k: K as u64,
    alpha: 1.0,
    lda: N as i64,
    ldb: K as i64,
    beta: 0.0,
    ldc: N as i64,
    stride_a: None,
    stride_b: None,
    stride_c: None,
    stride_bias: None,
    batch_size: None,
},
```
The `MatmulConfig` in `test_matmul_fp8_scalar_scale` has incorrect dimensions. The matrix dimensions `m` and `n` are swapped, and the leading dimensions `lda` and `ldb` are incorrect for the given matrix layouts. This will cause out-of-bounds memory access and incorrect test results.

For a matrix multiplication `C(M,N) = A(M,K) @ B(K,N)`, the config should be `m: M, n: N, k: K`. For row-major matrices, `lda` should be `K`, `ldb` should be `N`, and `ldc` should be `N`.
Suggested change:

```rust
MatmulConfig {
    transa: false,
    transb: false,
    transc: false,
    m: M as u64,
    n: N as u64,
    k: K as u64,
    alpha: 1.0,
    lda: K as i64,
    ldb: N as i64,
    beta: 0.0,
    ldc: N as i64,
    stride_a: None,
    stride_b: None,
    stride_c: None,
    stride_bias: None,
    batch_size: None,
},
```
@gemini-code-assist review
Code Review

This pull request introduces support for FP8 matrix multiplication using `cublaslt`. The changes include adding new dependencies and features in `Cargo.toml`, implementing `DeviceRepr` for FP8 types, and adding a new `Fp8Matmul` trait with its implementation and tests. The implementation is a good start, but there are a few critical and high-severity issues that need to be addressed for correctness and robustness. Specifically, handling of matrix `D` scaling attributes is incorrect, the `d_matrix_type` is unused, the `bias` parameter signature is too restrictive, and there are a couple of logic errors in the new tests.
src/cublaslt/safe.rs
Outdated
```rust
let scale_ptr_attr = match matrix {
    Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
    Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
    Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER,
    Matrix::D => {
        return Err(CublasError(
            sys::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE,
        ))
    }
};
```
The `match` statements for `scale_ptr_attr` and `scale_mode_attr` incorrectly return an error for `Matrix::D`. The `cublasLtMatmulDescAttributes_t` enum defines `CUBLASLT_MATMUL_DESC_D_SCALE_POINTER` and `CUBLASLT_MATMUL_DESC_D_SCALE_MODE`, so `Matrix::D` should be handled. This prevents setting scaling for the output matrix D, which is a critical part of FP8 matrix multiplication. A similar fix is needed for `scale_mode_attr`.
```rust
let scale_ptr_attr = match matrix {
    Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
    Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
    Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER,
    Matrix::D => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
};
```
src/cublaslt/safe.rs
Outdated
```rust
b_scale: &impl DevicePtr<f32>,
b_scale_mode: ScaleMode,
c: &mut O,
bias: Option<&O>,
```
The `bias` parameter has type `Option<&O>`, where `O: DevicePtrMut<T>`. This is overly restrictive: `bias` is an input and should not require mutable access, and it forces the bias to be the same type as the output matrix `c`. The signature should allow an immutable input and more flexibility, similar to the existing `matmul` function:

```rust
bias: Option<&impl DevicePtr<T>>,
```
```rust
let c_layout = MatrixLayout::new(Self::c_matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
    c_layout.set_batch(batch_size, stride_c)?;
}
```
The `d_matrix_type` function is defined but never used. The `get_matmul_algo_heuristic` function is called with `c_layout.handle` for both the C and D matrix layouts. While this works currently because `c_matrix_type` and `d_matrix_type` return the same type, it's fragile and incorrect if they differ. A separate `d_layout` should be created using `d_matrix_type` and passed to `get_matmul_algo_heuristic` and `result::matmul` to make the code more robust and correct.
```rust
let c_layout = MatrixLayout::new(Self::c_matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
    c_layout.set_batch(batch_size, stride_c)?;
}
let d_layout = MatrixLayout::new(Self::d_matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
    d_layout.set_batch(batch_size, stride_c)?;
}
```
src/cublaslt/safe.rs
Outdated
```rust
let found = c_host[m * N + n];
let expected = c[m][n];
assert!(
    (half::f16::to_f32(found) - half::f16::to_f32(expected)) <= 1e-2,
```
The assertion is missing an absolute value (`.abs()`) when checking the difference between `found` and `expected`. This could let the test pass even when the result is incorrect but happens to be smaller than the expected value.

```rust
(half::f16::to_f32(found) - half::f16::to_f32(expected)).abs() <= 1e-2,
```
src/cublaslt/safe.rs
Outdated
```rust
for m in 0..M {
    for n in 0..N {
        c[m][n] = beta;

        for k in 0..K {
            let a = a[m][k].to_f32() * half::f16::to_f32(a_scale);
            let b = b[k][n].to_f32() * half::f16::to_f32(b_scale);
            c[m][n] += half::f16::from_f32(alpha.to_f32() * a * b);
        }
    }
}
```
The logic for applying `beta` is incorrect. It sets `c[m][n] = beta;`, which overwrites the input value of `c`. The correct operation for `D = alpha*A*B + beta*C` is to scale the input `C` by `beta`, i.e. `c[m][n] *= beta;`. Although the current test passes because `beta` is 0, this is a latent bug that could cause future tests to pass incorrectly.
```rust
for m in 0..M {
    for n in 0..N {
        c[m][n] *= beta;
        for k in 0..K {
            let a = a[m][k].to_f32() * half::f16::to_f32(a_scale);
            let b = b[k][n].to_f32() * half::f16::to_f32(b_scale);
            c[m][n] += half::f16::from_f32(alpha.to_f32() * a * b);
        }
    }
}
```
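The difference is easy to see with plain scalars. This is a hedged sketch of the GEMM epilogue semantics, not the PR's actual test code: with a nonzero `beta`, `c = beta` and `c *= beta` diverge:

```rust
// Reference GEMM epilogue: D = alpha*A*B + beta*C. Scalars stand in for
// single matrix elements; this is an illustration, not the PR's test code.
fn main() {
    let (alpha, beta) = (1.0f32, 0.5f32);
    let ab = 4.0f32; // one element of the A*B product
    let c_in = 10.0f32; // the pre-existing value in C

    // Correct: scale the existing C by beta, then accumulate.
    let correct = beta * c_in + alpha * ab; // 0.5*10 + 4 = 9
    // Buggy: `c = beta` discards C entirely.
    let buggy = beta + alpha * ab; // 0.5 + 4 = 4.5

    assert_eq!(correct, 9.0);
    assert_eq!(buggy, 4.5);
    // With beta = 0 both collapse to alpha*ab, which is why the test passes today.
    assert_eq!(0.0 * c_in + alpha * ab, alpha * ab);
    println!("correct = {correct}, buggy = {buggy}");
}
```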
@coreylowman had to pull ndarray into dev-dependencies, is that ok? computing reference matmuls was getting a little gnarly.
@zackangelo yep!
@gemini-code-assist review
Code Review

This pull request introduces support for FP8 matrix multiplication using `cublaslt`. It adds a new `Fp8Matmul` trait, along with necessary supporting types and test cases. The changes are well-structured and follow the existing patterns in the codebase.

My review focuses on two main points:

- The versions for `ndarray` and `ndarray-rand` in `Cargo.toml` seem to point to unreleased versions, which will cause build failures.
- The new FP8 functionality is not gated by the required CUDA version, which could lead to compilation errors for users with older CUDA toolkits.

Addressing these issues will ensure the new functionality is robust and integrates well with the existing crate.
```toml
ndarray = "0.16"
ndarray-rand = "0.15"
```
```rust
#[cfg(feature = "f8")]
pub trait Fp8Matmul<T>: MatmulShared {
```
The `f8` feature should be coupled with the appropriate CUDA version feature gates to prevent compilation errors when using older CUDA versions that do not support FP8 operations. FP8 support was introduced in CUDA 11.8. This gating should be applied to all FP8-related code, including the `ScaleMode` enum and the `f8_tests` module.
Suggested change:

```rust
#[cfg(all(feature = "f8", any(feature = "cuda-11080", feature = "cuda-12000", feature = "cuda-12010", feature = "cuda-12020", feature = "cuda-12030", feature = "cuda-12040", feature = "cuda-12050", feature = "cuda-12060", feature = "cuda-12080", feature = "cuda-12090", feature = "cuda-13000")))]
pub trait Fp8Matmul<T>: MatmulShared {
```
WIP