1111#include < numeric>
1212
1313namespace mlx ::core {
14+
1415namespace {
1516
1617std::tuple<bool , int64_t , array>
@@ -28,41 +29,20 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
2829 }
2930}
3031
31- } // namespace
32-
33- void Matmul::eval_gpu (const std::vector<array>& inputs, array& out) {
34- nvtx3::scoped_range r (" Matmul::eval_gpu" );
35- auto & s = stream ();
36- auto & encoder = cu::get_command_encoder (s);
37-
38- assert (inputs.size () == 2 );
39- auto & a_pre = inputs[0 ];
40- auto & b_pre = inputs[1 ];
41- // Return 0s if either input is empty.
42- if (a_pre.size () == 0 || b_pre.size () == 0 ) {
43- array zero (0 , a_pre.dtype ());
44- encoder.add_temporary (zero);
45- fill_gpu (zero, out, s);
46- return ;
47- }
48-
49- out.set_data (allocator::malloc (out.nbytes ()));
50-
51- // ///////////////////////////////////////////////////////////////////////////
52- // Init checks and prep
53-
54- int M = a_pre.shape (-2 );
55- int N = b_pre.shape (-1 );
56- int K = a_pre.shape (-1 );
57-
58- // Keep a vector with copies to be cleared in the completed buffer to release
59- // the arrays
60- auto [a_transposed, lda, a] = check_transpose (encoder, s, a_pre);
61- auto [b_transposed, ldb, b] = check_transpose (encoder, s, b_pre);
62-
63- // ///////////////////////////////////////////////////////////////////////////
32+ void gemm_and_bias (
33+ cu::CommandEncoder& encoder,
34+ int M,
35+ int N,
36+ int K,
37+ bool a_transposed,
38+ int64_t lda,
39+ bool b_transposed,
40+ int64_t ldb,
41+ array& out,
42+ const array& a,
43+ const array& b,
44+ void * bias = nullptr ) {
6445 // Check and collapse batch dimensions
65-
6646 auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches (a, b);
6747
6848 auto batch_count = out.size () / (M * N);
@@ -79,7 +59,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
7959 batch_shape = {1 };
8060 }
8161
82- if (cu::can_use_gemv (M, N, K, a_transposed, b_transposed)) {
62+ // Use gemv when possible
63+ if (!bias && cu::can_use_gemv (M, N, K, a_transposed, b_transposed)) {
8364 cu::gemv (
8465 a,
8566 b,
@@ -95,10 +76,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
9576 return ;
9677 }
9778
98- // ///////////////////////////////////////////////////////////////////////////
9979 // Invoke cublasLt
10080 CublasGemm gemm (
101- cu::device (s .device ),
81+ encoder .device ( ),
10282 a.dtype (),
10383 a_transposed,
10484 M,
@@ -111,9 +91,45 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
11191 batch_shape.back (),
11292 a_batch_strides.back (),
11393 b_batch_strides.back ());
94+ if (bias) {
95+ gemm.set_bias (bias);
96+ }
11497 gemm.run (encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
11598}
11699
100+ } // namespace
101+
102+ void Matmul::eval_gpu (const std::vector<array>& inputs, array& out) { // GPU evaluation of Matmul: hands inputs[0] and inputs[1] to gemm_and_bias with `out` as the destination.
103+ nvtx3::scoped_range r (" Matmul::eval_gpu" );
104+ auto & s = stream ();
105+ auto & encoder = cu::get_command_encoder (s);
106+
107+ assert (inputs.size () == 2 ); // exactly two operands are expected: a and b
108+ auto & a_pre = inputs[0 ];
109+ auto & b_pre = inputs[1 ];
110+ // If either input is empty, fill `out` from a scalar 0 of a's dtype and return early.
111+ if (a_pre.size () == 0 || b_pre.size () == 0 ) {
112+ array zero (0 , a_pre.dtype ());
113+ encoder.add_temporary (zero);
114+ fill_gpu (zero, out, s);
115+ return ;
116+ }
117+
118+ out.set_data (allocator::malloc (out.nbytes ())); // allocate the output buffer before dispatching the GEMM
119+
120+ int M = a_pre.shape (-2 ); // second-to-last axis of a
121+ int N = b_pre.shape (-1 ); // last axis of b
122+ int K = a_pre.shape (-1 ); // last axis of a; presumably equals b's second-to-last axis - TODO confirm
123+
124+ // check_transpose returns (transposed flag, leading dimension, array to use) per operand.
125+ // NOTE(review): presumably any copy it makes is handed to the encoder so it is released later - confirm.
126+ auto [a_transposed, lda, a] = check_transpose (encoder, s, a_pre);
127+ auto [b_transposed, ldb, b] = check_transpose (encoder, s, b_pre);
128+
129+ gemm_and_bias ( // trailing `bias` argument is omitted, so it takes its nullptr default
130+ encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
131+ }
132+
117133void AddMM::eval_gpu (const std::vector<array>& inputs, array& out) {
118134 nvtx3::scoped_range r (" AddMM::eval_gpu" );
119135 auto & s = stream ();
@@ -136,6 +152,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
136152 auto [a_transposed, lda, a] = check_transpose (encoder, s, a_pre);
137153 auto [b_transposed, ldb, b] = check_transpose (encoder, s, b_pre);
138154
155+ // ///////////////////////////////////////////////////////////////////////////
156+ // Dispatch to GEMM with epilogue or AddMM
157+
158+ if (beta_ == 1 && c.strides (-1 ) == 1 && c.data_size () == out.shape (-1 )) {
159+ out.set_data (allocator::malloc (out.nbytes ()));
160+ gemm_and_bias (
161+ encoder,
162+ M,
163+ N,
164+ K,
165+ a_transposed,
166+ lda,
167+ b_transposed,
168+ ldb,
169+ out,
170+ a,
171+ b,
172+ c.data <void >());
173+ return ;
174+ }
175+
139176 int64_t ldc;
140177 {
141178 auto stx = c.strides ()[c.ndim () - 2 ];
@@ -177,7 +214,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
177214 }
178215
179216 // ///////////////////////////////////////////////////////////////////////////
180- // Invoke cublasLt
217+ // Invoke cublasLt with AddMM settings
181218
182219 CublasGemm gemm (
183220 cu::device (s.device ),
0 commit comments