Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions include/matx/core/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,19 @@ struct MemTracker {
return allocationMap.size();
}

auto deallocate(void *ptr) {
void update_stream(void *ptr, cudaStream_t stream) {
std::unique_lock lck(memory_mtx);
auto iter = allocationMap.find(ptr);
if (iter == allocationMap.end()) {
MATX_THROW(matxInvalidParameter, "Couldn't find pointer in allocation cache");
return;
}

iter->second.stream = stream;
}

template <typename StreamType>
auto deallocate_internal(void *ptr, StreamType st) {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

std::unique_lock lck(memory_mtx);
Expand Down Expand Up @@ -122,7 +134,12 @@ struct MemTracker {
free(ptr);
break;
case MATX_ASYNC_DEVICE_MEMORY:
cudaFreeAsync(ptr, iter->second.stream);
if constexpr (std::is_same_v<no_stream_t, StreamType>) {
cudaFreeAsync(ptr, iter->second.stream);
}
else {
cudaFreeAsync(ptr, st.stream);
}
break;
default:
MATX_THROW(matxInvalidType, "Invalid memory type");
Expand All @@ -131,6 +148,17 @@ struct MemTracker {
allocationMap.erase(ptr);
}

struct no_stream_t{};
struct valid_stream_t { cudaStream_t stream; };

auto deallocate(void *ptr) {
deallocate_internal(ptr, no_stream_t{});
}

auto deallocate(void *ptr, cudaStream_t stream) {
deallocate_internal(ptr, valid_stream_t{stream});
}

void allocate(void **ptr, size_t bytes,
matxMemorySpace_t space = MATX_MANAGED_MEMORY,
cudaStream_t stream = 0) {
Expand Down Expand Up @@ -335,6 +363,22 @@ inline void matxFree(void *ptr)
}


inline void matxFree(void *ptr, cudaStream_t stream)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
return GetAllocMap().deallocate(ptr, stream);
}

/**
Update the stream a pointer in the cache is using. This should be used when the call wants to use
memory that was allocated in stream A inside of stream B. The caller must ensure that the pointer
and stream being used are valid.
*/
inline void update_stream(void *ptr, cudaStream_t stream)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
GetAllocMap().update_stream(ptr, stream);
}

/**
* @brief Allocator following the PMR interface using the internal MatX allocator/deallocator
Expand Down
2 changes: 1 addition & 1 deletion include/matx/transforms/cub.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ class matxCubPlan_t {
*/
~matxCubPlan_t()
{
matxFree(d_temp);
matxFree(d_temp, cudaStreamDefault);
}

template <typename Func>
Expand Down
2 changes: 1 addition & 1 deletion include/matx/transforms/einsum.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class matxEinsumHandle_t {

~matxEinsumHandle_t()
{
matxFree(workspace_);
matxFree(workspace_, cudaStreamDefault);
}

inline void Exec(OutputTensor &out, cudaStream_t stream, const InT... tensors)
Expand Down
4 changes: 2 additions & 2 deletions include/matx/transforms/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ template <typename OutTensorType, typename InTensorType> class matxFFTPlan_t {

virtual ~matxFFTPlan_t() {
if (this->workspace_ != nullptr) {

matxFree(workspace_);
// Pass the default stream until we allow user-deletable caches
matxFree(workspace_, cudaStreamDefault);
this->workspace_ = nullptr;
}

Expand Down
8 changes: 4 additions & 4 deletions include/matx/transforms/inverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,10 @@ class matxInversePlan_t {
*/
~matxInversePlan_t()
{
matxFree(d_A_array);
matxFree(d_A_inv_array);
matxFree(d_pivot);
matxFree(d_info);
matxFree(d_A_array, cudaStreamDefault);
matxFree(d_A_inv_array, cudaStreamDefault);
matxFree(d_pivot, cudaStreamDefault);
matxFree(d_info, cudaStreamDefault);

cublasDestroy(handle);
}
Expand Down
6 changes: 3 additions & 3 deletions include/matx/transforms/solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ class matxDnSolver_t {

virtual ~matxDnSolver_t()
{
matxFree(d_workspace);
matxFree(h_workspace);
matxFree(d_info);
matxFree(d_workspace, cudaStreamDefault);
matxFree(h_workspace, cudaStreamDefault);
matxFree(d_info, cudaStreamDefault);
cusolverDnDestroy(handle);
}

Expand Down