Conversation

zackangelo
Contributor

WIP

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for FP8 matrix multiplication using cuBLASLt. The changes include updates to Cargo.toml for new features and dependencies, and new DeviceRepr implementations for FP8/FP4 types. The core logic is in src/cublaslt/safe.rs.

I've identified several critical issues in the implementation. There are problems with how scale pointers are handled in set_fp8_scale, incorrectly hardcoded data types and compute types in the Fp8Matmul trait that will not work for bf16, and incorrect dimensions and leading dimensions in the new test case. These issues are likely to cause runtime errors or incorrect results and should be addressed.

Comment on lines 250 to 305
    fn set_fp8_scale(
        &self,
        scale_ptr: &impl DevicePtr<f32>,
        scale_mode: ScaleMode,
        matrix: Matrix,
    ) -> Result<(), CublasError> {
        let scale_ptr_attr = match matrix {
            Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
            Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
            Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER,
            Matrix::D => {
                return Err(CublasError(
                    sys::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE,
                ))
            }
        };

        let scale_mode_attr = match matrix {
            Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
            Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
            Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_MODE,
            Matrix::D => {
                return Err(CublasError(
                    sys::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE,
                ))
            }
        };

        let scale_mode = match scale_mode {
            ScaleMode::Scalar32f => {
                sys::cublasLtMatmulMatrixScale_t::CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F
            }
            ScaleMode::RowWise32f => {
                sys::cublasLtMatmulMatrixScale_t::CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
            }
        };

        let scale_elems = 1;

        unsafe {
            result::set_matmul_desc_attribute(
                self.handle,
                scale_ptr_attr,
                scale_ptr as *const _ as *const _,
                mem::size_of::<f32>() * scale_elems,
            )?;

            result::set_matmul_desc_attribute(
                self.handle,
                scale_mode_attr,
                (&scale_mode) as *const _ as *const _,
                mem::size_of::<sys::cublasLtMatmulMatrixScale_t>(),
            )?;
        }
        Ok(())
    }


critical

The implementation of set_fp8_scale has a critical bug in how it handles scale_ptr. The cast scale_ptr as *const _ as *const _ operates on the &impl DevicePtr<f32> reference itself, so it yields the host address of that reference rather than the underlying device pointer. Additionally, the size passed to set_matmul_desc_attribute for the scale pointer is mem::size_of::<f32>(), but since this is a pointer attribute it should be mem::size_of::<sys::CUdeviceptr>().

A possible fix is to change the function signature to accept a &sys::CUdeviceptr and have the caller, fp8_matmul, extract the pointer first using .device_ptr(stream).
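The failure mode can be illustrated without any CUDA at all. In this sketch, DeviceSlice is a hypothetical stand-in for a device-buffer wrapper like the one behind DevicePtr: it stores a device address internally, so casting a reference to the wrapper yields the host address of the wrapper struct, not the device address it holds; and a POINTER attribute needs pointer-sized data, which f32 is not on 64-bit targets.

```rust
use std::mem;

// Hypothetical stand-in for a cudarc-style device buffer wrapper.
struct DeviceSlice {
    device_addr: u64, // stand-in for sys::CUdeviceptr
}

impl DeviceSlice {
    // The real API extracts the device address via .device_ptr(stream).
    fn device_ptr(&self) -> u64 {
        self.device_addr
    }
}

fn main() {
    let buf = DeviceSlice {
        device_addr: 0xDEAD_BEEF,
    };

    // Bug 1: casting the reference gives the host address of `buf` itself,
    // which is not the device address the attribute setter expects.
    let host_addr = &buf as *const DeviceSlice as usize;
    assert_ne!(host_addr as u64, buf.device_ptr());

    // Bug 2: the attribute payload must be pointer-sized. On a 64-bit
    // target a CUdeviceptr is 8 bytes, while f32 is only 4.
    assert_eq!(mem::size_of::<f32>(), 4);
    assert_eq!(mem::size_of::<u64>(), 8);
}
```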

Comment on lines 545 to 555
    fn c_matrix_type() -> sys::cudaDataType {
        sys::cudaDataType::CUDA_R_16F
    }

    fn d_matrix_type() -> sys::cudaDataType {
        sys::cudaDataType::CUDA_R_16F
    }

    fn compute_type() -> sys::cublasComputeType_t {
        sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
    }


critical

The c_matrix_type, d_matrix_type, and compute_type functions are not correctly defined for the Fp8Matmul<T> trait.

  1. c_matrix_type and d_matrix_type are hardcoded to return CUDA_R_16F. This is incorrect when T is half::bf16, which should use CUDA_R_16BF. These should be defined as associated functions on the trait without a default implementation, and then implemented for each specific type (f16, bf16).

  2. compute_type is hardcoded to CUBLAS_COMPUTE_16F. According to the cuBLASLt documentation, FP8 GEMMs require the compute type to be CUBLAS_COMPUTE_32F. Using CUBLAS_COMPUTE_16F may lead to errors or incorrect results.

A correct implementation would look like this:

pub trait Fp8Matmul<T>: MatmulShared {
    // ...
    fn c_matrix_type() -> sys::cudaDataType;
    fn d_matrix_type() -> sys::cudaDataType;

    fn compute_type() -> sys::cublasComputeType_t {
        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
    }
    // ...
}

impl Fp8Matmul<half::f16> for CudaBlasLT {
    fn c_matrix_type() -> sys::cudaDataType {
        sys::cudaDataType::CUDA_R_16F
    }
    fn d_matrix_type() -> sys::cudaDataType {
        sys::cudaDataType::CUDA_R_16F
    }
}
impl Fp8Matmul<half::bf16> for CudaBlasLT {
    fn c_matrix_type() -> sys::cudaDataType {
        sys::cudaDataType::CUDA_R_16BF
    }
    fn d_matrix_type() -> sys::cudaDataType {
        sys::cudaDataType::CUDA_R_16BF
    }
}

Comment on lines 1049 to 1066
MatmulConfig {
    transa: false,
    transb: false,
    transc: false,
    m: N as u64,
    n: M as u64,
    k: K as u64,
    alpha: 1.0,
    lda: N as i64,
    ldb: K as i64,
    beta: 0.0,
    ldc: N as i64,
    stride_a: None,
    stride_b: None,
    stride_c: None,
    stride_bias: None,
    batch_size: None,
},


critical

The MatmulConfig in test_matmul_fp8_scalar_scale has incorrect dimensions. The matrix dimensions m and n are swapped, and the leading dimensions lda and ldb are incorrect for the given matrix layouts. This will cause out-of-bounds memory access and incorrect test results.

For a matrix multiplication C(M,N) = A(M,K) @ B(K,N), the config should be m=M, n=N, k=K. For row-major matrices, lda should be K, ldb should be N, and ldc should be N.

Suggested change
MatmulConfig {
    transa: false,
    transb: false,
    transc: false,
    m: N as u64,
    n: M as u64,
    k: K as u64,
    alpha: 1.0,
    lda: N as i64,
    ldb: K as i64,
    beta: 0.0,
    ldc: N as i64,
    stride_a: None,
    stride_b: None,
    stride_c: None,
    stride_bias: None,
    batch_size: None,
},
MatmulConfig {
    transa: false,
    transb: false,
    transc: false,
    m: M as u64,
    n: N as u64,
    k: K as u64,
    alpha: 1.0,
    lda: K as i64,
    ldb: N as i64,
    beta: 0.0,
    ldc: N as i64,
    stride_a: None,
    stride_b: None,
    stride_c: None,
    stride_bias: None,
    batch_size: None,
},
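Whichever transpose convention the actual cuBLASLt call ends up using (the library itself is column-major, so row-major callers sometimes swap operands), a plain row-major reference GEMM makes the leading-dimension bookkeeping concrete: element (i, j) of a row-major R×C matrix lives at i*ld + j with ld = C, giving lda = K, ldb = N, ldc = N for C(M,N) = A(M,K) @ B(K,N). This sketch is independent of the crate's API:

```rust
// Row-major reference GEMM over flat buffers with explicit leading
// dimensions: c[i][j] = sum_p a[i][p] * b[p][j].
fn gemm_row_major(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    lda: usize,
    b: &[f32],
    ldb: usize,
    c: &mut [f32],
    ldc: usize,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * lda + p] * b[p * ldb + j];
            }
            c[i * ldc + j] = acc;
        }
    }
}

fn main() {
    // 2x3 @ 3x2 = 2x2, so lda = 3, ldb = 2, ldc = 2.
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
    let mut c = [0.0f32; 4];
    gemm_row_major(2, 2, 3, &a, 3, &b, 2, &mut c, 2);
    assert_eq!(c, [58.0, 64.0, 139.0, 154.0]);
}
```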

@zackangelo
Contributor Author

@gemini-code-assist review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for FP8 matrix multiplication using cuBLASLt. The changes include adding new dependencies and features in Cargo.toml, implementing DeviceRepr for FP8 types, and adding a new Fp8Matmul trait with its implementation and tests. The implementation is a good start, but there are a few critical and high-severity issues that need to be addressed for correctness and robustness: the handling of matrix D scaling attributes is incorrect, d_matrix_type is unused, the bias parameter signature is too restrictive, and there are a couple of logic errors in the new tests.

Comment on lines 256 to 265
        let scale_ptr_attr = match matrix {
            Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
            Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
            Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER,
            Matrix::D => {
                return Err(CublasError(
                    sys::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE,
                ))
            }
        };


critical

The match statements for scale_ptr_attr and scale_mode_attr incorrectly return an error for Matrix::D. The cublasLtMatmulDescAttributes_t enum defines CUBLASLT_MATMUL_DESC_D_SCALE_POINTER and CUBLASLT_MATMUL_DESC_D_SCALE_MODE, so Matrix::D should be handled. This prevents setting scaling for the output matrix D, which is a critical part of FP8 matrix multiplication. A similar fix is needed for scale_mode_attr.

        let scale_ptr_attr = match matrix {
            Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
            Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
            Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER,
            Matrix::D => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
        };

        b_scale: &impl DevicePtr<f32>,
        b_scale_mode: ScaleMode,
        c: &mut O,
        bias: Option<&O>,


high

The bias parameter has type Option<&O>, where O: DevicePtrMut<T>. This is overly restrictive as bias is an input and should not require mutable access. It also forces the bias to be of the same type as the output matrix c. The signature should be bias: Option<&impl DevicePtr<T>> to allow for immutable input and more flexibility, similar to the existing matmul function.

        bias: Option<&impl DevicePtr<T>>,

Comment on lines +596 to +599
        let c_layout = MatrixLayout::new(Self::c_matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
        if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
            c_layout.set_batch(batch_size, stride_c)?;
        }


high

The d_matrix_type function is defined but never used. The get_matmul_algo_heuristic function is called with c_layout.handle for both the C and D matrix layouts. While this works currently because c_matrix_type and d_matrix_type return the same type, it's fragile and incorrect if they differ. A separate d_layout should be created using d_matrix_type and passed to get_matmul_algo_heuristic and result::matmul to make the code more robust and correct.

        let c_layout = MatrixLayout::new(Self::c_matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
        if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
            c_layout.set_batch(batch_size, stride_c)?;
        }

        let d_layout = MatrixLayout::new(Self::d_matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
        if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
            d_layout.set_batch(batch_size, stride_c)?;
        }

                    let found = c_host[m * N + n];
                    let expected = c[m][n];
                    assert!(
                        (half::f16::to_f32(found) - half::f16::to_f32(expected)) <= 1e-2,


high

The assertion is missing an absolute value function (.abs()) when checking the difference between found and expected. This could cause tests to pass even if the result is incorrect but happens to be smaller than the expected value.

                    (half::f16::to_f32(found) - half::f16::to_f32(expected)).abs() <= 1e-2,
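A tiny self-contained sketch of the failure mode: the signed-difference check accepts any result that undershoots the expected value, no matter how badly, while the .abs() version rejects it.

```rust
// Buggy tolerance check: a large negative error still satisfies `<= tol`.
fn close_signed(found: f32, expected: f32, tol: f32) -> bool {
    (found - expected) <= tol // missing .abs()
}

// Corrected tolerance check on the magnitude of the error.
fn close_abs(found: f32, expected: f32, tol: f32) -> bool {
    (found - expected).abs() <= tol
}

fn main() {
    // A wildly wrong result (0.0 vs 100.0) slips through the signed check...
    assert!(close_signed(0.0, 100.0, 1e-2));
    // ...but is rejected once .abs() is applied.
    assert!(!close_abs(0.0, 100.0, 1e-2));
}
```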

Comment on lines 1136 to 1146
        for m in 0..M {
            for n in 0..N {
                c[m][n] = beta;

                for k in 0..K {
                    let a = a[m][k].to_f32() * half::f16::to_f32(a_scale);
                    let b = b[k][n].to_f32() * half::f16::to_f32(b_scale);
                    c[m][n] += half::f16::from_f32(alpha.to_f32() * a * b);
                }
            }
        }


high

The logic for applying beta is incorrect. It sets c[m][n] = beta;, which overwrites the input value of c. The correct operation for D = alpha*A*B + beta*C is to scale the input C by beta, i.e., c[m][n] *= beta;. Although the current test passes because beta is 0, this is a latent bug that could cause future tests to pass incorrectly.

        for m in 0..M {
            for n in 0..N {
                c[m][n] *= beta;

                for k in 0..K {
                    let a = a[m][k].to_f32() * half::f16::to_f32(a_scale);
                    let b = b[k][n].to_f32() * half::f16::to_f32(b_scale);
                    c[m][n] += half::f16::from_f32(alpha.to_f32() * a * b);
                }
            }
        }
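The distinction can be shown on a single output element. In this hedged sketch, epilogue_buggy and epilogue_correct are hypothetical helpers contrasting `c = beta` with `c *= beta` for D = alpha*A*B + beta*C; the two agree only when beta is 0, which is why the current test does not catch the bug.

```rust
// Buggy epilogue: discards the input C and seeds the accumulator with beta.
fn epilogue_buggy(_c_in: f32, ab: f32, alpha: f32, beta: f32) -> f32 {
    beta + alpha * ab
}

// Correct epilogue: scales the input C by beta, i.e. D = alpha*A*B + beta*C.
fn epilogue_correct(c_in: f32, ab: f32, alpha: f32, beta: f32) -> f32 {
    c_in * beta + alpha * ab
}

fn main() {
    // With beta = 0 the two agree, so a beta = 0 test passes either way.
    assert_eq!(
        epilogue_buggy(5.0, 2.0, 1.0, 0.0),
        epilogue_correct(5.0, 2.0, 1.0, 0.0)
    );
    // With beta = 1 they diverge: 1.0 + 2.0 = 3.0 vs 5.0 + 2.0 = 7.0.
    assert_eq!(epilogue_buggy(5.0, 2.0, 1.0, 1.0), 3.0);
    assert_eq!(epilogue_correct(5.0, 2.0, 1.0, 1.0), 7.0);
}
```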

@zackangelo
Contributor Author

@coreylowman had to pull ndarray into dev-dependencies, is that ok? computing reference matmuls was getting a little gnarly.

@coreylowman
Owner

@coreylowman had to pull ndarray into dev-dependencies, is that ok? computing reference matmuls was getting a little gnarly.

@zackangelo yep!

@zackangelo
Contributor Author

@gemini-code-assist review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for FP8 matrix multiplication using cuBLASLt. It adds a new Fp8Matmul trait, along with the necessary supporting types and test cases. The changes are well-structured and follow the existing patterns in the codebase.

My review focuses on two main points:

  1. The versions for ndarray and ndarray-rand in Cargo.toml seem to point to unreleased versions, which will cause build failures.
  2. The new FP8 functionality is not gated by the required CUDA version, which could lead to compilation errors for users with older CUDA toolkits.

Addressing these issues will ensure the new functionality is robust and integrates well with the existing crate.

Comment on lines +100 to +101
ndarray = "0.16"
ndarray-rand = "0.15"


critical

The specified versions for ndarray and ndarray-rand do not appear to be published on crates.io. Using non-existent versions will cause the build to fail. Please use the latest stable versions.

ndarray = "0.15"
ndarray-rand = "0.14"

Comment on lines +505 to +506
#[cfg(feature = "f8")]
pub trait Fp8Matmul<T>: MatmulShared {


high

The f8 feature should be coupled with the appropriate CUDA version feature gates to prevent compilation errors when using older CUDA versions that do not support FP8 operations. FP8 support was introduced in CUDA 11.8. This gating should be applied to all FP8-related code, including the ScaleMode enum and the f8_tests module.

Suggested change
#[cfg(feature = "f8")]
pub trait Fp8Matmul<T>: MatmulShared {
#[cfg(all(feature = "f8", any(feature = "cuda-11080", feature = "cuda-12000", feature = "cuda-12010", feature = "cuda-12020", feature = "cuda-12030", feature = "cuda-12040", feature = "cuda-12050", feature = "cuda-12060", feature = "cuda-12080", feature = "cuda-12090", feature = "cuda-13000")))]
pub trait Fp8Matmul<T>: MatmulShared {
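One way to avoid repeating that long any(...) feature list at every FP8 item is to collapse it into a single custom cfg flag from a build script. This is only a sketch under assumptions taken from the suggestion above: cudarc-style cuda-XXXXX feature names (which Cargo exposes to build.rs as CARGO_FEATURE_CUDA_XXXXX), an 11.8 (11080) cutoff for FP8 support, and a hypothetical helper named cuda_feature_supports_fp8.

```rust
// Hypothetical build.rs fragment: derive one `fp8_capable` cfg from the
// enabled cuda-XXXXX feature, so FP8 items can be gated with a single
// #[cfg(fp8_capable)] instead of a long any(...) list.
fn cuda_feature_supports_fp8(feature_env_key: &str) -> bool {
    // Cargo exposes an enabled feature `cuda-12050` to build scripts as
    // the env var CARGO_FEATURE_CUDA_12050; FP8 needs CUDA >= 11.8.
    feature_env_key
        .strip_prefix("CARGO_FEATURE_CUDA_")
        .and_then(|v| v.parse::<u32>().ok())
        .map(|v| v >= 11080)
        .unwrap_or(false)
}

fn main() {
    if std::env::vars().any(|(k, _)| cuda_feature_supports_fp8(&k)) {
        println!("cargo:rustc-cfg=fp8_capable");
    }
    // Sanity checks on the version parsing.
    assert!(cuda_feature_supports_fp8("CARGO_FEATURE_CUDA_12050"));
    assert!(!cuda_feature_supports_fp8("CARGO_FEATURE_CUDA_11040"));
    assert!(!cuda_feature_supports_fp8("CARGO_FEATURE_F8"));
}
```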
