1111#include < numeric>
1212
1313namespace mlx ::core {
14-
1514namespace {
1615
1716std::tuple<bool , int64_t , array>
@@ -29,20 +28,41 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
2928 }
3029}
3130
32- void gemm_and_bias (
33- cu::CommandEncoder& encoder,
34- int M,
35- int N,
36- int K,
37- bool a_transposed,
38- int64_t lda,
39- bool b_transposed,
40- int64_t ldb,
41- array& out,
42- const array& a,
43- const array& b,
44- void * bias = nullptr ) {
31+ } // namespace
32+
33+ void Matmul::eval_gpu (const std::vector<array>& inputs, array& out) {
34+ nvtx3::scoped_range r (" Matmul::eval_gpu" );
35+ auto & s = stream ();
36+ auto & encoder = cu::get_command_encoder (s);
37+
38+ assert (inputs.size () == 2 );
39+ auto & a_pre = inputs[0 ];
40+ auto & b_pre = inputs[1 ];
41+ // Return 0s if either input is empty.
42+ if (a_pre.size () == 0 || b_pre.size () == 0 ) {
43+ array zero (0 , a_pre.dtype ());
44+ encoder.add_temporary (zero);
45+ fill_gpu (zero, out, s);
46+ return ;
47+ }
48+
49+ out.set_data (allocator::malloc (out.nbytes ()));
50+
51+ // ///////////////////////////////////////////////////////////////////////////
52+ // Init checks and prep
53+
54+ int M = a_pre.shape (-2 );
55+ int N = b_pre.shape (-1 );
56+ int K = a_pre.shape (-1 );
57+
58+ // Keep a vector with copies to be cleared in the completed buffer to release
59+ // the arrays
60+ auto [a_transposed, lda, a] = check_transpose (encoder, s, a_pre);
61+ auto [b_transposed, ldb, b] = check_transpose (encoder, s, b_pre);
62+
63+ // ///////////////////////////////////////////////////////////////////////////
4564 // Check and collapse batch dimensions
65+
4666 auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches (a, b);
4767
4868 auto batch_count = out.size () / (M * N);
@@ -59,8 +79,7 @@ void gemm_and_bias(
5979 batch_shape = {1 };
6080 }
6181
62- // Use gemmv when possible
63- if (!bias && cu::can_use_gemv (M, N, K, a_transposed, b_transposed)) {
82+ if (cu::can_use_gemv (M, N, K, a_transposed, b_transposed)) {
6483 cu::gemv (
6584 a,
6685 b,
@@ -76,9 +95,10 @@ void gemm_and_bias(
7695 return ;
7796 }
7897
98+ // ///////////////////////////////////////////////////////////////////////////
7999 // Invoke cublasLt
80100 CublasGemm gemm (
81- encoder. device (),
101+ cu:: device (s. device ),
82102 a.dtype (),
83103 a_transposed,
84104 M,
@@ -91,45 +111,9 @@ void gemm_and_bias(
91111 batch_shape.back (),
92112 a_batch_strides.back (),
93113 b_batch_strides.back ());
94- if (bias) {
95- gemm.set_bias (bias);
96- }
97114 gemm.run (encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
98115}
99116
100- } // namespace
101-
102- void Matmul::eval_gpu (const std::vector<array>& inputs, array& out) {
103- nvtx3::scoped_range r (" Matmul::eval_gpu" );
104- auto & s = stream ();
105- auto & encoder = cu::get_command_encoder (s);
106-
107- assert (inputs.size () == 2 );
108- auto & a_pre = inputs[0 ];
109- auto & b_pre = inputs[1 ];
110- // Return 0s if either input is empty.
111- if (a_pre.size () == 0 || b_pre.size () == 0 ) {
112- array zero (0 , a_pre.dtype ());
113- encoder.add_temporary (zero);
114- fill_gpu (zero, out, s);
115- return ;
116- }
117-
118- out.set_data (allocator::malloc (out.nbytes ()));
119-
120- int M = a_pre.shape (-2 );
121- int N = b_pre.shape (-1 );
122- int K = a_pre.shape (-1 );
123-
124- // Keep a vector with copies to be cleared in the completed buffer to release
125- // the arrays
126- auto [a_transposed, lda, a] = check_transpose (encoder, s, a_pre);
127- auto [b_transposed, ldb, b] = check_transpose (encoder, s, b_pre);
128-
129- gemm_and_bias (
130- encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
131- }
132-
133117void AddMM::eval_gpu (const std::vector<array>& inputs, array& out) {
134118 nvtx3::scoped_range r (" AddMM::eval_gpu" );
135119 auto & s = stream ();
@@ -152,27 +136,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
152136 auto [a_transposed, lda, a] = check_transpose (encoder, s, a_pre);
153137 auto [b_transposed, ldb, b] = check_transpose (encoder, s, b_pre);
154138
155- // ///////////////////////////////////////////////////////////////////////////
156- // Dispatch to GEMM with epilogue or AddMM
157-
158- if (beta_ == 1 && c.strides (-1 ) == 1 && c.data_size () == out.shape (-1 )) {
159- out.set_data (allocator::malloc (out.nbytes ()));
160- gemm_and_bias (
161- encoder,
162- M,
163- N,
164- K,
165- a_transposed,
166- lda,
167- b_transposed,
168- ldb,
169- out,
170- a,
171- b,
172- c.data <void >());
173- return ;
174- }
175-
176139 int64_t ldc;
177140 {
178141 auto stx = c.strides ()[c.ndim () - 2 ];
@@ -214,7 +177,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
214177 }
215178
216179 // ///////////////////////////////////////////////////////////////////////////
217- // Invoke cublasLt with AddMM settings
180+ // Invoke cublasLt
218181
219182 CublasGemm gemm (
220183 cu::device (s.device ),
0 commit comments