Skip to content

Commit 7f97c39

Browse files
author
Awni Hannun
committed
fix older cuda
1 parent 195acec commit 7f97c39

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

mlx/backend/cuda/quantized/cuda_fp4.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,16 @@ struct __nv_fp4_e2m1 {
8181
}
8282
uint8_t __x{0};
8383
};
84+
85+
// Four FP4 (E2M1) codes packed two per byte into the bytes __high and __low.
// NOTE(review): presumably a stand-in for the toolkit's own __nv_fp4x4_e2m1 on
// CUDA versions that do not ship <cuda_fp4.h> — confirm against the build
// configuration that selects this header.
struct __nv_fp4x4_e2m1 {
  // Expands the four packed codes to floats, in this order:
  //   x = low nibble of __high,  y = high nibble of __high,
  //   z = low nibble of __low,   w = high nibble of __low.
  __device__ operator float4() {
    float4 out;
    out.x = __decode_nibble(static_cast<uint8_t>(__high & 0x0F));
    out.y = __decode_nibble(static_cast<uint8_t>(__high >> 4));
    out.z = __decode_nibble(static_cast<uint8_t>(__low & 0x0F));
    out.w = __decode_nibble(static_cast<uint8_t>(__low >> 4));
    return out;
  }

  // Converts a single 4-bit code (already isolated in the low bits of `code`)
  // to float via __nv_fp4_e2m1's float conversion.
  //
  // The code is copied into a local byte and that byte's ADDRESS is
  // reinterpreted. Casting the byte's value itself to a pointer would
  // dereference an arbitrary low address instead of reading the stored code.
  // NOTE(review): relies on __nv_fp4_e2m1 consisting of exactly one uint8_t
  // member (__x) so that a lone byte can be viewed as that type — confirm.
  static __device__ float __decode_nibble(uint8_t code) {
    return float(*reinterpret_cast<__nv_fp4_e2m1*>(&code));
  }

  uint8_t __high{0};
  uint8_t __low{0};
};

0 commit comments

Comments
 (0)