Skip to content

Commit 87ab293

Browse files
pbelevichfacebook-github-bot
authored andcommitted
Custom RNG DispatchKey (pytorch#32325)
Summary: Pull Request resolved: pytorch#32325 The purpose of this PR is to enable PyTorch dispatching on `at::Generator*` parameters and demonstrate how it can be used in cpp extensions to implement custom RNG. 1. `CustomRNGKeyId` value added to DispatchKey enum and `DispatchKeySet key_set_` added to `at::Generator` 2. The overloaded `operator()(at::Generator* gen)` added to MultiDispatchKeySet. 3. The existing CPUGenerator and CUDAGenerator class are supplied with CPUTensorId and CUDATensorId dispatch keys 4. The implementation of CPU's `cauchy_kernel`(as an example, because it's already moved to ATen) was templatized and moved to `ATen/native/cpu/DistributionTemplates.h` to make it available for cpp extensions 5. Minor CMake changes to make native/cpu tensors available for cpp extensions 6. RegisterCustomRNG test that demonstrates how CustomCPUGenerator class can be implemented and how custom_rng_cauchy_ native function can be registered to handle Tensor::cauchy_ calls. Test Plan: Imported from OSS Differential Revision: D19604558 Pulled By: pbelevich fbshipit-source-id: 2619f14076cee5742094a0be832d8530bba72728
1 parent ddd4581 commit 87ab293

File tree

14 files changed

+129
-24
lines changed

14 files changed

+129
-24
lines changed

aten/src/ATen/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ if(USE_ROCM)
3434
endif()
3535

3636
# NB: If you edit these globs, you'll have to update setup.py package_data as well
37-
FILE(GLOB base_h "*.h" "detail/*.h" "cpu/*.h")
37+
FILE(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec256/*.h" "quantized/*.h")
3838
FILE(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp")
3939
add_subdirectory(core)
4040
FILE(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh")
@@ -65,6 +65,7 @@ FILE(GLOB native_quantized_cpp
6565
"native/quantized/cpu/*.cpp")
6666
FILE(GLOB native_h "native/*.h")
6767
FILE(GLOB native_quantized_h "native/quantized/*.h" "native/quantized/cpu/*.h")
68+
FILE(GLOB native_cpu_h "native/cpu/*.h")
6869

6970
FILE(GLOB native_cuda_cu "native/cuda/*.cu")
7071
FILE(GLOB native_cuda_cpp "native/cuda/*.cpp")
@@ -379,7 +380,7 @@ INSTALL(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake"
379380
if(INTERN_BUILD_MOBILE)
380381
set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS})
381382
else()
382-
set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_h} ${native_quantized_h} ${cuda_h} ${cudnn_h} ${hip_h} ${miopen_h})
383+
set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_h} ${native_cpu_h} ${native_quantized_h} ${cuda_h} ${cudnn_h} ${hip_h} ${miopen_h})
383384
endif()
384385

385386
# https://stackoverflow.com/questions/11096471/how-can-i-install-a-hierarchy-of-files-using-cmake

aten/src/ATen/CPUGenerator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ inline uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) {
4848
* CPUGenerator class implementation
4949
*/
5050
CPUGenerator::CPUGenerator(uint64_t seed_in)
51-
: Generator{Device(DeviceType::CPU)},
51+
: Generator{Device(DeviceType::CPU), DispatchKeySet(c10::DispatchKey::CPUTensorId)},
5252
engine_{seed_in},
5353
next_float_normal_sample_{c10::optional<float>()},
5454
next_double_normal_sample_{c10::optional<double>()} { }

aten/src/ATen/CPUGenerator.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
#include <ATen/core/Generator.h>
44
#include <ATen/core/MT19937RNGEngine.h>
5-
#include <ATen/core/PhiloxRNGEngine.h>
65
#include <c10/util/Optional.h>
76

87
namespace at {

aten/src/ATen/core/DistributionsHelper.h

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
#include <math.h>
77
#endif
88

9-
#include <ATen/CPUGenerator.h>
109
#include <ATen/core/Array.h>
10+
#include <c10/util/Half.h>
11+
#include <c10/util/Optional.h>
1112
#include <type_traits>
1213
#include <limits>
1314
#include <cmath>
@@ -85,7 +86,8 @@ struct uniform_real_distribution {
8586
b = b_in;
8687
}
8788

88-
inline dist_acctype<T> operator()(at::CPUGenerator* generator){
89+
template <typename RNG>
90+
inline dist_acctype<T> operator()(RNG* generator){
8991
dist_acctype<T> x;
9092
if(std::is_same<T, double>::value) {
9193
x = (generator->random64() & DOUBLE_MASK) * DOUBLE_DIVISOR;
@@ -115,7 +117,8 @@ struct normal_distribution {
115117
stdv = stdv_in;
116118
}
117119

118-
inline dist_acctype<T> operator()(at::CPUGenerator* generator){
120+
template <typename RNG>
121+
inline dist_acctype<T> operator()(RNG* generator){
119122
dist_acctype<T> ret;
120123
// return cached values if available
121124
if (std::is_same<T, double>::value) {
@@ -166,7 +169,8 @@ struct bernoulli_distribution {
166169
p = p_in;
167170
}
168171

169-
inline int operator()(at::CPUGenerator* generator) {
172+
template <typename RNG>
173+
inline int operator()(RNG* generator) {
170174
uniform_real_distribution<T> uniform(0.0, 1.0);
171175
return uniform(generator) < p;
172176
}
@@ -186,7 +190,8 @@ struct geometric_distribution {
186190
p = p_in;
187191
}
188192

189-
inline int operator()(at::CPUGenerator* generator) {
193+
template <typename RNG>
194+
inline int operator()(RNG* generator) {
190195
uniform_real_distribution<T> uniform(0.0, 1.0);
191196
dist_acctype<T> sample = uniform(generator);
192197
return static_cast<int>(::log(static_cast<T>(1.0)-sample) / ::log(p)) + 1;
@@ -206,7 +211,8 @@ struct exponential_distribution {
206211
lambda = lambda_in;
207212
}
208213

209-
inline T operator()(at::CPUGenerator* generator) {
214+
template <typename RNG>
215+
inline T operator()(RNG* generator) {
210216
// Follows numpy exponential for the case when lambda is zero.
211217
if (lambda == static_cast<T>(0.0)) {
212218
return static_cast<T>(0.0);
@@ -231,7 +237,8 @@ struct cauchy_distribution {
231237
sigma = sigma_in;
232238
}
233239

234-
inline T operator()(at::CPUGenerator* generator) {
240+
template <typename RNG>
241+
inline T operator()(RNG* generator) {
235242
uniform_real_distribution<T> uniform(0.0, 1.0);
236243
return median + sigma * ::tan(static_cast<T>(M_PI) * (uniform(generator)-static_cast<T>(0.5)));
237244
}
@@ -255,7 +262,8 @@ struct lognormal_distribution {
255262
stdv = stdv_in;
256263
}
257264

258-
inline T operator()(at::CPUGenerator* generator){
265+
template<typename RNG>
266+
inline T operator()(RNG* generator){
259267
normal_distribution<T> normal(mean, stdv);
260268
return ::exp(normal(generator));
261269
}

aten/src/ATen/core/Generator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ namespace at {
1212
/**
1313
* Generator class implementation
1414
*/
15-
Generator::Generator(Device device_in) : device_{device_in} {}
15+
Generator::Generator(Device device_in, DispatchKeySet key_set)
16+
: device_{device_in}, key_set_(key_set) {}
1617

1718
/**
1819
* Clone this generator. Note that clone() is the only

aten/src/ATen/core/Generator.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <c10/util/Exception.h>
1212
#include <c10/util/C++17.h>
1313
#include <c10/core/Device.h>
14+
#include <c10/core/DispatchKeySet.h>
1415

1516
/**
1617
* Note [Generator]
@@ -54,7 +55,7 @@ constexpr uint64_t default_rng_seed_val = 67280421310721;
5455

5556
struct CAFFE2_API Generator {
5657
// Constructors
57-
Generator(Device device_in);
58+
Generator(Device device_in, DispatchKeySet key_set);
5859

5960
// Delete all copy and move assignment in favor of clone()
6061
// method
@@ -74,8 +75,11 @@ struct CAFFE2_API Generator {
7475
// See Note [Acquire lock when using random generators]
7576
std::mutex mutex_;
7677

78+
DispatchKeySet key_set() const { return key_set_; }
79+
7780
private:
7881
Device device_;
82+
DispatchKeySet key_set_;
7983
virtual Generator* clone_impl() const = 0;
8084
};
8185

aten/src/ATen/core/dispatch/DispatchKeyExtractor.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ namespace detail {
5252
ts = ts | x.key_set();
5353
}
5454
}
55+
void operator()(at::Generator* gen) {
56+
if (gen != nullptr) {
57+
ts = ts | gen->key_set();
58+
}
59+
}
5560
template <typename T>
5661
void operator()(const T& x) {
5762
// do nothing

aten/src/ATen/cuda/CUDAGenerator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ std::shared_ptr<CUDAGenerator> createCUDAGenerator(DeviceIndex device_index) {
7373
* CUDAGenerator class implementation
7474
*/
7575
CUDAGenerator::CUDAGenerator(DeviceIndex device_index)
76-
: Generator{Device(DeviceType::CUDA, device_index)} { }
76+
: Generator{Device(DeviceType::CUDA, device_index),
77+
DispatchKeySet(c10::DispatchKey::CUDATensorId)} { }
7778

7879
/**
7980
* Sets the seed to be used by curandStatePhilox4_32_10
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#pragma once
2+
3+
#include <ATen/Dispatch.h>
4+
#include <ATen/core/DistributionsHelper.h>
5+
#include <ATen/native/TensorIterator.h>
6+
#include <ATen/native/cpu/Loops.h>
7+
#include <mutex>
8+
9+
namespace at { namespace native { namespace templates {
10+
11+
template<typename RNG>
12+
void cauchy_kernel(TensorIterator& iter, double median, double sigma, RNG* generator) {
13+
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "cauchy_cpu", [&]() {
14+
std::lock_guard<std::mutex> lock(generator->mutex_);
15+
cpu_serial_kernel(iter, [median, sigma, generator]() -> scalar_t {
16+
at::cauchy_distribution<double> cauchy(median, sigma);
17+
return (scalar_t)cauchy(generator);
18+
});
19+
});
20+
}
21+
22+
}}}

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <ATen/native/cpu/zmath.h>
2020
#include <ATen/native/Math.h>
2121
#include <ATen/core/DistributionsHelper.h>
22+
#include <ATen/native/cpu/DistributionTemplates.h>
2223

2324
#if AT_MKL_ENABLED()
2425
#include <mkl.h>
@@ -253,14 +254,8 @@ static void clamp_min_kernel(TensorIterator& iter, Scalar min_scalar) {
253254
}
254255

255256
static void cauchy_kernel(TensorIterator& iter, double median, double sigma, Generator* gen) {
256-
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "cauchy_cpu", [&]() {
257-
CPUGenerator* generator = get_generator_or_default<CPUGenerator>(gen, detail::getDefaultCPUGenerator());
258-
std::lock_guard<std::mutex> lock(generator->mutex_);
259-
cpu_serial_kernel(iter, [median, sigma, generator]() -> scalar_t {
260-
at::cauchy_distribution<double> cauchy(median, sigma);
261-
return (scalar_t)cauchy(generator);
262-
});
263-
});
257+
CPUGenerator* generator = get_generator_or_default<CPUGenerator>(gen, detail::getDefaultCPUGenerator());
258+
templates::cauchy_kernel(iter, median, sigma, generator);
264259
}
265260

266261
#if !AT_MKL_ENABLED()

0 commit comments

Comments
 (0)