Skip to content

Commit 3d7e6f1

Browse files
committed
apply pr comments
1 parent 68a7ea6 commit 3d7e6f1

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

mlx/backend/cuda/device/hadamard.cuh

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,9 @@ __device__ __forceinline__ void hadamard_radix_m(float* x);
1010

1111
template <int N>
1212
struct Pow2Log2 {
13-
static_assert(N > 0 && (N % 2 == 0), "N must be a power-of-two > 1.");
13+
static_assert(
14+
(N > 0) && ((N & (N - 1)) == 0),
15+
"N must be a positive power of two.");
1416
static constexpr int value = 1 + Pow2Log2<N / 2>::value;
1517
};
1618

mlx/backend/cuda/hadamard.cu

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -159,8 +159,8 @@ void hadamard_mn_contiguous(
159159
args.append(num_transforms);
160160

161161
auto kernel = mod.get_kernel(n1_kernel_name);
162-
encoder.add_kernel_node(
163-
kernel, num_blocks, n1 / max_radix_1, 0, args.args());
162+
encoder.add_kernel_node_raw(
163+
kernel, num_blocks, n1 / max_radix_1, {}, 0, args.args());
164164
}
165165

166166
{
@@ -179,8 +179,8 @@ void hadamard_mn_contiguous(
179179
args.append(num_transforms);
180180

181181
auto kernel = mod.get_kernel(n2_kernel_name);
182-
encoder.add_kernel_node(
183-
kernel, num_blocks, n2 / max_radix_2, 0, args.args());
182+
encoder.add_kernel_node_raw(
183+
kernel, num_blocks, n2 / max_radix_2, {}, 0, args.args());
184184
}
185185

186186
if (m > 1) {
@@ -200,7 +200,8 @@ void hadamard_mn_contiguous(
200200
args.append(num_tasks);
201201

202202
auto kernel = mod.get_kernel(m_kernel_name);
203-
encoder.add_kernel_node(kernel, num_blocks, block_dim, 0, args.args());
203+
encoder.add_kernel_node_raw(
204+
kernel, num_blocks, block_dim, {}, 0, args.args());
204205
}
205206
}
206207

0 commit comments

Comments
 (0)