Commit 51014fe: "Draft"
1 parent 01b9faa

5 files changed: 223 additions, 0 deletions

CMakeLists.txt

Lines changed: 9 additions & 0 deletions
@@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 17)
 file(STRINGS version.txt TORCHVISION_VERSION)
 
 option(WITH_CUDA "Enable CUDA support" OFF)
+option(WITH_MPS "Enable MPS support" OFF)
 option(WITH_PNG "Enable features requiring LibPNG." ON)
 option(WITH_JPEG "Enable features requiring LibJPEG." ON)
 option(USE_PYTHON "Link to Python when building" OFF)
@@ -15,6 +16,11 @@ if(WITH_CUDA)
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
 endif()
 
+if(WITH_MPS)
+  enable_language(OBJC OBJCXX)
+  add_definitions(-DWITH_MPS)
+endif()
+
 find_package(Torch REQUIRED)
 
 if (WITH_PNG)
@@ -79,6 +85,9 @@ list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCP
 if(WITH_CUDA)
     list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
 endif()
+if(WITH_MPS)
+    list(APPEND ALLOW_LISTED ${TVCPP}/ops/mps)
+endif()
 
 FOREACH(DIR ${ALLOW_LISTED})
     file(GLOB ALL_SOURCES ${ALL_SOURCES} ${DIR}/*.*)
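With the new option, an out-of-tree C++ build is configured with -DWITH_MPS=ON. A minimal sanity check (hypothetical, not part of this commit) for whether the host can use MPS at all before attempting such a build:

    # Hypothetical pre-build check: confirms PyTorch itself has MPS support
    # and that the MPS device is usable on this machine.
    import torch

    print(torch.backends.mps.is_built())      # PyTorch compiled with MPS?
    print(torch.backends.mps.is_available())  # MPS device usable here?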

setup.py

Lines changed: 6 additions & 0 deletions
@@ -137,10 +137,13 @@ def get_extensions():
         + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
         + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
     )
+    source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
 
     print("Compiling extensions with following flags:")
     force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
     print(f"  FORCE_CUDA: {force_cuda}")
+    force_mps = os.getenv("FORCE_MPS", "0") == "1"
+    print(f"  FORCE_MPS: {force_mps}")
     debug_mode = os.getenv("DEBUG", "0") == "1"
     print(f"  DEBUG: {debug_mode}")
     use_png = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
@@ -202,6 +205,9 @@ def get_extensions():
         define_macros += [("WITH_HIP", None)]
         nvcc_flags = []
         extra_compile_args["nvcc"] = nvcc_flags
+    elif torch.backends.mps.is_available() or force_mps:
+        sources += source_mps
+        define_macros += [("WITH_MPS", None)]
 
     if sys.platform == "win32":
         define_macros += [("torchvision_EXPORTS", None)]
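The .mm sources are compiled whenever torch reports an MPS device, or unconditionally via the new FORCE_MPS environment variable, mirroring the existing FORCE_CUDA escape hatch. A hedged sketch of driving a forced build from a Python script (equivalent to running FORCE_MPS=1 python setup.py develop from a shell):

    # Hypothetical driver script: forces the MPS sources into the build even
    # on a machine where torch.backends.mps.is_available() is False.
    import os
    import subprocess
    import sys

    env = dict(os.environ, FORCE_MPS="1")
    subprocess.check_call([sys.executable, "setup.py", "develop"], env=env)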

test/common_utils.py

Lines changed: 5 additions & 0 deletions
@@ -133,6 +133,11 @@ def needs_cuda(test_func):
 
     return pytest.mark.needs_cuda(test_func)
 
+def needs_mps(test_func):
+    import pytest
+
+    return pytest.mark.needs_mps(test_func)
+
 
 def _create_data(height=3, width=3, channels=3, device="cpu"):
     # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
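Tests opt in through the decorator. A usage sketch follows; the skip behavior is an assumption, since (as with needs_cuda) the needs_mps mark would presumably be registered and handled in a conftest.py hook so it skips cleanly on machines without MPS:

    # Sketch of a test using the new decorator. The conftest wiring for the
    # needs_mps mark is assumed, by analogy with the existing needs_cuda mark.
    import torch
    from common_utils import needs_mps

    @needs_mps
    def test_nms_mps():
        boxes = torch.tensor(
            [[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]], device="mps")
        scores = torch.tensor([0.9, 0.8], device="mps")
        # ... call torchvision.ops.nms and compare against the CPU result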
torchvision/csrc/ops/mps/nms_kernel.mm (new file)

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
//#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "vision_kernels.h"

namespace vision {
namespace ops {

namespace {

at::Tensor nms_kernel(
    const at::Tensor& dets,
    const at::Tensor& scores,
    double iou_threshold) {
  using namespace at::native::mps;
  TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor");
  TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor");

  TORCH_CHECK(
      dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
  TORCH_CHECK(
      dets.size(1) == 4,
      "boxes should have 4 elements in dimension 1, got ",
      dets.size(1));
  TORCH_CHECK(
      scores.dim() == 1,
      "scores should be a 1d tensor, got ",
      scores.dim(),
      "D");
  TORCH_CHECK(
      dets.size(0) == scores.size(0),
      "boxes and scores should have same number of elements in ",
      "dimension 0, got ",
      dets.size(0),
      " and ",
      scores.size(0));

  if (dets.numel() == 0) {
    return at::empty({0}, dets.options().dtype(at::kLong));
  }

  // Sort boxes by descending score, as the CPU/CUDA implementations do.
  auto order_t = std::get<1>(
      scores.sort(/*stable=*/true, /*dim=*/0, /*descending=*/true));
  auto dets_sorted = dets.index_select(0, order_t).contiguous();
  int dets_num = dets.size(0);
  float iou_threshold_f = static_cast<float>(iou_threshold);

  // TODO: tile the mask as in the CUDA kernel once ceil_div is available:
  //   const int col_blocks = ceil_div(dets_num, threadsPerBlock);
  //   at::Tensor mask =
  //       at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
  // For now each box gets a single int64 whose low 32 bits encode overlaps
  // with later boxes, which limits this draft to small box counts.
  at::Tensor mask =
      at::empty({dets_num}, dets.options().dtype(at::kLong));

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(dets_sorted);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(mask);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  const uint32_t numThreads = dets_num; // one thread per box
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);

      // Select the dtype-specialized kernel, e.g. "nms_float" or "nms_half".
      const std::string kernel =
          "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type());
      id<MTLComputePipelineState> binaryPSO =
          mps::binaryPipelineState(device, kernel);

      // This call is a no-op if the MPS Profiler is not enabled:
      // getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {dets_sorted, mask});

      [computeEncoder setComputePipelineState:binaryPSO];
      [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0];
      [computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1];
      [computeEncoder setBytes:&dets_num length:sizeof(int) atIndex:2];
      [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3];

      // Clamp the threadgroup size to the grid so small inputs still dispatch.
      NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > numThreads) {
        tgSize = numThreads;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];

      // getMPSProfiler().endProfileKernel(binaryPSO);
    }
  });
  // NOTE: this draft returns the raw per-box overlap mask; the final gather
  // into kept indices (as in the CUDA path) is still to be done.
  return mask;
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

} // namespace ops
} // namespace vision
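Once built, the MPS implementation registers under the existing torchvision::nms schema, so the public Python API is unchanged. A hedged end-to-end sketch of exercising it (in this draft the kernel still returns the raw overlap mask rather than kept indices, so agreement with the CPU path is not yet expected):

    # Dispatch torchvision::nms on MPS tensors and compare against the CPU
    # reference. This is how the op would be exercised once complete.
    import torch
    from torchvision.ops import nms

    boxes_cpu = torch.tensor(
        [[0.0, 0.0, 10.0, 10.0],
         [1.0, 1.0, 11.0, 11.0],
         [20.0, 20.0, 30.0, 30.0]])
    scores_cpu = torch.tensor([0.9, 0.8, 0.7])

    keep_cpu = nms(boxes_cpu, scores_cpu, iou_threshold=0.5)
    keep_mps = nms(boxes_cpu.to("mps"), scores_cpu.to("mps"), iou_threshold=0.5)
    print(keep_cpu, keep_mps.cpu())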
torchvision/csrc/ops/mps/vision_kernels.h (new file)

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
#include <ATen/native/mps/OperationUtils.h>

namespace vision {
namespace ops {

namespace mps {

static const char* METAL_VISION = R"VISION_METAL(

#include <metal_stdlib>
using namespace metal;

// Returns true when the IoU of boxes a and b (xyxy corners packed into a
// 4-component vector) exceeds the threshold.
template <typename T, typename scalar_t>
bool IoU(
  constant T & a,
  constant T & b,
  scalar_t threshold) {
  auto xx1 = max(a.x, b.x);
  auto yy1 = max(a.y, b.y);
  auto xx2 = min(a.z, b.z);
  auto yy2 = min(a.w, b.w);
  auto w = max(static_cast<scalar_t>(0), xx2 - xx1);
  auto h = max(static_cast<scalar_t>(0), yy2 - yy1);
  auto inter = w * h;
  auto area_a = (a.z - a.x) * (a.w - a.y);
  auto area_b = (b.z - b.x) * (b.w - b.y);
  return (inter / (area_a + area_b - inter)) > threshold;
}

// One thread per score-sorted box: bit i of out[tid] is set when box i
// (i > tid) overlaps box tid beyond the threshold. Note the accumulator is
// a 32-bit int, so shifts past bit 31 overflow; tiling is still TODO.
template<typename T, typename scalar_t>
kernel void nms(constant T * input [[buffer(0)]],
                device int64_t * out [[buffer(1)]],
                constant int & dets_num [[buffer(2)]],
                constant float & iou_threshold [[buffer(3)]],
                uint tid [[thread_position_in_grid]]) {
  int t = 0;
  for (int i = tid + 1; i < dets_num; i++){
    if (IoU<T, scalar_t>(input[tid], input[i], iou_threshold)){
      t |= static_cast<int>(1) << i;
    }
  }
  out[tid] = static_cast<int64_t>(t);
}

// Instantiates host-visible specializations "nms_float" and "nms_half",
// matching the names looked up from the C++ side.
#define REGISTER_NMS_OP(DTYPE)                          \
template                                                \
[[host_name("nms_" #DTYPE)]]                            \
kernel void nms<DTYPE ## 4, DTYPE>(                     \
  constant DTYPE ## 4 * input [[buffer(0)]],            \
  device int64_t * out [[buffer(1)]],                   \
  constant int & dets_num [[buffer(2)]],                \
  constant float & iou_threshold [[buffer(3)]],         \
  uint tid [[thread_position_in_grid]]);

REGISTER_NMS_OP(float);
REGISTER_NMS_OP(half);

)VISION_METAL";

// Compiles the embedded Metal source once and caches the resulting library.
static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
  static id<MTLLibrary> binaryLibrary = nil;
  if (binaryLibrary) {
    return binaryLibrary;
  }

  NSError* error = nil;
  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
  [options setLanguageVersion:MTLLanguageVersion2_3];
  binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding]
                                       options:options
                                         error:&error];
  TORCH_CHECK(binaryLibrary, "Failed to create metal binary library, error: ", [[error description] UTF8String]);
  return binaryLibrary;
}

// Returns (and memoizes) the compute pipeline state for a named kernel.
static id<MTLComputePipelineState> binaryPipelineState(id<MTLDevice> device, const std::string& kernel) {
  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
  id<MTLComputePipelineState> pso = psoCache[kernel];
  if (pso) {
    return pso;
  }

  NSError* error = nil;
  id<MTLLibrary> binaryLib = compileBinaryOpsLibrary(device);
  id<MTLFunction> binaryFunc = [binaryLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
  TORCH_CHECK(binaryFunc, "Failed to create function state object for: ", kernel);
  pso = [device newComputePipelineStateWithFunction:binaryFunc error:&error];
  TORCH_CHECK(pso, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);

  psoCache[kernel] = pso;
  return pso;
}

} // namespace mps
} // namespace ops
} // namespace vision
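For reference, the per-box bitmask the kernel writes can be reduced to kept indices with a greedy CPU pass, exactly as classic NMS does. A Python sketch of that semantics (mask_to_keep is a hypothetical helper, not part of the commit):

    # Hypothetical CPU-side reduction for the overlap bitmask computed by the
    # Metal kernel above: boxes are score-sorted, and bit i of mask[tid] says
    # box i (i > tid) overlaps box tid beyond the IoU threshold.
    def mask_to_keep(mask):
        n = len(mask)
        suppressed = [False] * n
        keep = []
        for tid in range(n):
            if suppressed[tid]:
                continue
            keep.append(tid)  # highest-scored surviving box wins
            for i in range(tid + 1, n):
                if (mask[tid] >> i) & 1:
                    suppressed[i] = True
        return keep

    # Example: box 1 overlaps box 0, box 2 overlaps nothing.
    print(mask_to_keep([0b010, 0b000, 0b000]))  # -> [0, 2]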
