Support nested tuple unpack in CuPy JIT #5332
Example:

    import cupyx.jit  # makes cupyx.jit.rawkernel available

    @cupyx.jit.rawkernel()
    def f(x, y):
        i = cupyx.jit.threadIdx.x
        x[i], y[i] = y[i], x[i]

Previously, the following kernel was generated:

    extern "C" __global__ void f(CArray<int, 1, true, true> x, CArray<int, 1, true, true> y) {
      unsigned int i;
      i = threadIdx.x;
      thrust::tie(x[i], y[i]) = thrust::make_tuple(y[i], x[i]);
    }

With this PR, the following kernel is generated:

    extern "C" __global__ void f(CArray<int, 1, true, true> x, CArray<int, 1, true, true> y) {
      unsigned int i;
      i = threadIdx.x;
      {
        thrust::tuple<int, int> _temp0 = thrust::make_tuple(y[i], x[i]);
        x[i] = thrust::get<0>(_temp0);
        y[i] = thrust::get<1>(_temp0);
      }
    }
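For readers following along, here is a minimal Python sketch of how a temporary-based lowering like the one above could extend to nested targets. It is not CuPy's actual code generator: `unpack_lines` and `lower_assignment` are hypothetical names, the right-hand side is passed in as a hand-written C++ string, and `auto` stands in for the explicit `thrust::tuple<...>` type that the real kernel spells out. It needs Python 3.9+ for `ast.unparse`.

```python
import ast


def unpack_lines(target, source):
    """Flatten a (possibly nested) tuple target into scalar assignments.

    `target` is an ast node taken from the left-hand side of an assignment;
    `source` is the C++ expression that yields the value for that node.
    """
    if isinstance(target, ast.Tuple):
        lines = []
        for index, element in enumerate(target.elts):
            # Each element reads its value through thrust::get<index>(...),
            # so nesting simply composes the accessor calls.
            lines += unpack_lines(element, f"thrust::get<{index}>({source})")
        return lines
    # Leaf target (name, subscript, ...): emit a plain assignment.
    return [f"{ast.unparse(target)} = {source};"]


def lower_assignment(stmt, rhs_cpp, temp="_temp0"):
    """Lower one Python tuple assignment into a braced C++-like block.

    `rhs_cpp` is supplied by the caller (assumption: this sketch does not
    translate the right-hand side itself).
    """
    target = ast.parse(stmt).body[0].targets[0]
    body = [f"auto {temp} = {rhs_cpp};"] + unpack_lines(target, temp)
    return ["{"] + ["    " + line for line in body] + ["}"]


if __name__ == "__main__":
    print("\n".join(lower_assignment(
        "x[i], y[i] = y[i], x[i]",
        "thrust::make_tuple(y[i], x[i])",
    )))
```

For a nested target such as `a, (b, c) = ...`, the same recursion emits `b = thrust::get<0>(thrust::get<1>(_temp0));`. Because the right-hand side is evaluated once into the temporary, every target is assigned from the saved values rather than from partially updated ones.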
Jenkins, test this please
4f03a94 to 69414b5
Sorry, I forgot to remove the debug logging.
Jenkins CI test (for commit 69414b5, target branch master) failed with status FAILURE.
Jenkins, test this please
Jenkins CI test (for commit da50f65, target branch master) succeeded!
This PR supports nested tuple unpack in CuPy JIT.
Example:
Generated kernel:
Related discussion: #5293 (comment) (cc/ @eternalphane)