asi1024
Member

@asi1024 asi1024 commented Jun 11, 2021

This PR adds support for nested tuple unpacking in CuPy JIT.

Example:

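import cupy
import cupyx.jit

@cupyx.jit.rawkernel()
def f(x, y, z):
    i = cupyx.jit.threadIdx.x
    # Nested tuple unpacking: the right-hand side is evaluated before any store.
    (x[i], y[i]), z[i] = (z[i], x[i]), y[i]

x = cupy.full((5,), 10, dtype=cupy.int32)
y = cupy.full((5,), 20, dtype=cupy.int32)
z = cupy.full((5,), 30, dtype=cupy.int32)
f((1,), (5,), (x, y, z))  # grid=(1,), block=(5,)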
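
For reference, here is a CPU-only sketch of what the arrays should hold after the launch, assuming the JIT kernel follows Python's assignment semantics (the right-hand side is fully evaluated before any target is written). Plain lists stand in for the CuPy arrays, and the loop stands in for the five threads; this is an illustration and was not produced by running the kernel on a GPU.

```python
# Pure-Python model of the kernel body; lists stand in for the CuPy arrays.
x = [10] * 5
y = [20] * 5
z = [30] * 5

for i in range(5):  # one iteration per thread index
    # Same statement as in the kernel: RHS is built first, then unpacked.
    (x[i], y[i]), z[i] = (z[i], x[i]), y[i]

print(x, y, z)
```

Under that assumption, x receives the old values of z, y receives the old values of x, and z receives the old values of y.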

Generated kernel:

extern "C" __global__ void f(CArray<int, 1, true, true> x, CArray<int, 1, true, true> y, CArray<int, 1, true, true> z) {
  unsigned int i;
  i = threadIdx.x;
  {
    thrust::tuple<thrust::tuple<int, int>, int> _temp0 = thrust::make_tuple(thrust::make_tuple(z[i], x[i]), y[i]);
    {
      thrust::tuple<int, int> _temp1 = thrust::get<0>(_temp0);
      x[i] = thrust::get<0>(_temp1);
      y[i] = thrust::get<1>(_temp1);
    }
    z[i] = thrust::get<1>(_temp0);
  }
}
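
The shape of the generated code above can be modeled with a small recursive emitter. The sketch below is hypothetical and is not CuPy's actual code generator: `ctype_of` and `emit_unpack` are illustrative names, and the target pattern and its types are passed in as nested Python tuples of strings. It only shows how one temporary per nesting level plus `thrust::get<k>` yields the block structure seen in the kernel.

```python
from itertools import count


def ctype_of(t):
    """Render a nested tuple of C type names as a thrust::tuple type."""
    if isinstance(t, tuple):
        return "thrust::tuple<" + ", ".join(ctype_of(e) for e in t) + ">"
    return t


def emit_unpack(targets, types, value, counter, indent="  "):
    """Return C++ lines assigning `value` to a (possibly nested) target pattern."""
    if not isinstance(targets, tuple):
        # Leaf target: a plain assignment.
        return [f"{indent}{targets} = {value};"]
    # Tuple target: store the value in a scoped temporary, then unpack it.
    temp = f"_temp{next(counter)}"
    lines = [f"{indent}{{", f"{indent}  {ctype_of(types)} {temp} = {value};"]
    for k, (tgt, ty) in enumerate(zip(targets, types)):
        lines += emit_unpack(tgt, ty, f"thrust::get<{k}>({temp})",
                             counter, indent + "  ")
    lines.append(f"{indent}}}")
    return lines


lines = emit_unpack(
    targets=(("x[i]", "y[i]"), "z[i]"),
    types=(("int", "int"), "int"),
    value="thrust::make_tuple(thrust::make_tuple(z[i], x[i]), y[i])",
    counter=count(),
)
print("\n".join(lines))
```

For this input the emitted lines match the braces, temporaries, and assignments in the kernel body shown above.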

Related discussion: #5293 (comment) (cc/ @eternalphane)

@asi1024
Member Author

asi1024 commented Jun 11, 2021

Previously, thrust::tie was used to unpack tuples, but it cannot be used for nested tuples. So this PR instead stores the right-hand side in a temporary tuple and extracts each element from it with thrust::get<i>.

Example:

@cupyx.jit.rawkernel()
def f(x, y):
    i = cupyx.jit.threadIdx.x
    x[i], y[i] = y[i], x[i]

Previously, the following kernel was generated:

extern "C" __global__ void f(CArray<int, 1, true, true> x, CArray<int, 1, true, true> y) {
  unsigned int i;
  i = threadIdx.x;
  thrust::tie(x[i], y[i]) = thrust::make_tuple(y[i], x[i]);
}

With this PR, the following kernel is generated:

extern "C" __global__ void f(CArray<int, 1, true, true> x, CArray<int, 1, true, true> y) {
  unsigned int i;
  i = threadIdx.x;
  {
    thrust::tuple<int, int> _temp0 = thrust::make_tuple(y[i], x[i]);
    x[i] = thrust::get<0>(_temp0);
    y[i] = thrust::get<1>(_temp0);
  }
}
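
The temporary in the new kernel matters because the right-hand side must be read before either element is overwritten. A minimal CPU-only sketch of that point follows; the function names are hypothetical and lists stand in for the arrays.

```python
def swap_naive(x, y, i):
    """Sequential stores without a temporary: NOT a correct swap."""
    x[i] = y[i]
    y[i] = x[i]  # reads the x[i] that was just overwritten


def swap_with_temp(x, y, i):
    """Mirrors the generated kernel: build the tuple first, then store."""
    _temp0 = (y[i], x[i])
    x[i] = _temp0[0]
    y[i] = _temp0[1]


a, b = [1], [2]
swap_naive(a, b, 0)

c, d = [1], [2]
swap_with_temp(c, d, 0)

print(a, b)  # both elements end up equal
print(c, d)  # values are exchanged
```

The naive version loses the original value of the first element, while the temporary-based version exchanges the two values as the Python source intends.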

@emcastillo emcastillo self-assigned this Jun 12, 2021
@emcastillo emcastillo added the cat:feature New features/APIs label Jun 12, 2021
@emcastillo emcastillo added this to the v10.0.0a2 milestone Jun 12, 2021
@emcastillo
Member

Jenkins, test this please

@asi1024 asi1024 force-pushed the jit-nested-ptrn-match branch from 4f03a94 to 69414b5 Compare June 12, 2021 05:12
@asi1024
Member Author

asi1024 commented Jun 12, 2021

Sorry, I forgot to remove the debug logging.
Jenkins, test this please

@chainer-ci
Member

Jenkins CI test (for commit 69414b5, target branch master) failed with status FAILURE.

@emcastillo
Member

Jenkins, test this please

@chainer-ci
Member

Jenkins CI test (for commit da50f65, target branch master) succeeded!

@emcastillo emcastillo added the st:test-and-merge (deprecated) Ready to merge after test pass. label Jun 14, 2021
@mergify mergify bot merged commit 0c7eb25 into cupy:master Jun 14, 2021
@asi1024 asi1024 deleted the jit-nested-ptrn-match branch June 14, 2021 03:24
Labels
cat:feature New features/APIs prio:medium st:test-and-merge (deprecated) Ready to merge after test pass.