Commit c4501b3

Only check for deadlocks in deadlock busting thread
Signed-off-by: Alessandro Bellina <[email protected]>
1 parent 0f60936 commit c4501b3
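For context on the calling convention this commit introduces: the native entry point now receives the set of Java-side blocked thread IDs from its caller, instead of calling back into ThreadStateRegistry.isThreadBlocked for each thread while holding the state lock. A minimal Java sketch of how a deadlock-busting thread might drive the new entry point is below; only the checkAndBreakDeadlocks(long, long[]) native signature is taken from this commit, while bustDeadlocksOnce, isJavaBlocked, and the registeredThreads map are hypothetical placeholders, not APIs from this repository.

import java.util.Map;

// Illustrative sketch only. The JNI signature checkAndBreakDeadlocks(long, long[])
// comes from this commit; every other name here is a made-up placeholder.
final class DeadlockBusterSketch {

  // Mirrors the updated JNI export in SparkResourceAdaptorJni.cpp.
  private static native void checkAndBreakDeadlocks(long ptr, long[] blockedThreadIds);

  // One plausible mapping of Thread.getState() to "blocked"; the real policy may differ.
  private static boolean isJavaBlocked(Thread t) {
    switch (t.getState()) {
      case BLOCKED:
      case WAITING:
      case TIMED_WAITING: return true;
      default:            return false;
    }
  }

  // Gather the native ids of registered threads that look blocked on the Java side
  // and hand them to the native deadlock check in a single call.
  static void bustDeadlocksOnce(long nativeHandle, Map<Long, Thread> registeredThreads) {
    long[] blocked = registeredThreads.entrySet().stream()
        .filter(e -> isJavaBlocked(e.getValue()))
        .mapToLong(e -> e.getKey())
        .toArray();
    checkAndBreakDeadlocks(nativeHandle, blocked);
  }
}

On the native side, the array becomes an std::unordered_set<long> and flows through check_and_break_deadlocks into check_and_update_for_bufn, as shown in the diff below.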

File tree

4 files changed: +104 -136 lines changed

src/main/cpp/src/SparkResourceAdaptorJni.cpp

Lines changed: 81 additions & 124 deletions
@@ -44,8 +44,6 @@ constexpr char const* CPU_RETRY_OOM_CLASS = "com/nvidia/spark/rapids/jni/CpuRetr
 constexpr char const* CPU_SPLIT_AND_RETRY_OOM_CLASS =
   "com/nvidia/spark/rapids/jni/CpuSplitAndRetryOOM";
 constexpr char const* THREAD_REG_CLASS = "com/nvidia/spark/rapids/jni/ThreadStateRegistry";
-constexpr char const* IS_THREAD_BLOCKED = "isThreadBlocked";
-constexpr char const* IS_THREAD_BLOCKED_SIG = "(J)Z";
 constexpr char const* REMOVE_THREAD = "removeThread";
 constexpr char const* REMOVE_THREAD_SIG = "(J)V";

@@ -55,7 +53,6 @@ std::mutex jni_mutex;
 bool is_jni_loaded = false;
 jclass ThreadStateRegistry_jclass;
 jmethodID removeThread_method;
-jmethodID isThreadBlocked_method;

 void cache_thread_reg_jni(JNIEnv* env)
 {
@@ -67,9 +64,6 @@ void cache_thread_reg_jni(JNIEnv* env)
   removeThread_method = env->GetStaticMethodID(cls, REMOVE_THREAD, REMOVE_THREAD_SIG);
   if (removeThread_method == nullptr) { return; }

-  isThreadBlocked_method = env->GetStaticMethodID(cls, IS_THREAD_BLOCKED, IS_THREAD_BLOCKED_SIG);
-  if (isThreadBlocked_method == nullptr) { return; }
-
   // Convert local reference to global so it cannot be garbage collected.
   ThreadStateRegistry_jclass = static_cast<jclass>(env->NewGlobalRef(cls));
   if (ThreadStateRegistry_jclass == nullptr) { return; }
@@ -1120,10 +1114,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
     return ret;
   }

-  void check_and_break_deadlocks()
+  void check_and_break_deadlocks(std::unordered_set<long> const& java_blocked_thread_ids)
   {
     std::unique_lock<std::mutex> lock(state_mutex);
-    check_and_update_for_bufn(lock);
+    check_and_update_for_bufn(lock, java_blocked_thread_ids);
   }

   bool cpu_prealloc(size_t const amount, bool const blocking)
@@ -1373,7 +1367,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
       // Before we can wait it is possible that the throw didn't release anything
       // and the other threads didn't get unblocked by this, so we need to
       // check again to see if this was fixed or not.
-      check_and_update_for_bufn(lock);
+      check_and_update_for_bufn_state_machine_only(lock);
       // If that caused us to transition to a new state, then we need to adjust to it
       // appropriately...
       if (is_blocked(thread->second.state)) {
@@ -1696,127 +1690,55 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
           break;
         default: break;
       }
-      wake_next_highest_priority_blocked(lock, false, is_for_cpu);
+      wake_next_highest_priority_blocked(lock, is_for_cpu);
     }
   }

   /**
-   * Wake the highest priority blocked (not BUFN) thread so it can make progress,
-   * or the highest priority BUFN thread if all of the tasks are in some form of BUFN
-   * and this was triggered by a free.
+   * Wake the highest priority blocked (not BUFN) thread so it can make progress
    *
    * This is typically called when a free happens, or an alloc succeeds.
-   * @param is_from_free true if a free happen.
    * @param is_for_cpu true if it was a CPU operation (free or alloc)
    */
   void wake_next_highest_priority_blocked(std::unique_lock<std::mutex> const& lock,
-                                          bool const is_from_free,
                                           bool const is_for_cpu)
   {
     // 1. Find the highest priority blocked thread, for the alloc that matches
     thread_priority to_wake(-1, -1);
-    bool is_to_wake_set = false;
-    for (auto const& [thread_d, t_state] : threads) {
+    full_thread_state* thread_to_wake = nullptr;
+    for (auto& [thread_d, t_state] : threads) {
       thread_state const& state = t_state.state;
       if (state == thread_state::THREAD_BLOCKED && is_for_cpu == t_state.is_cpu_alloc) {
         thread_priority current = t_state.priority();
-        if (!is_to_wake_set || to_wake < current) {
+        if (thread_to_wake == nullptr || to_wake < current) {
           to_wake = current;
-          is_to_wake_set = true;
+          thread_to_wake = &t_state;
         }
       }
     }
     // 2. wake up that thread
-    long const thread_id_to_wake = to_wake.get_thread_id();
-    if (thread_id_to_wake > 0) {
-      auto const thread = threads.find(thread_id_to_wake);
-      if (thread != threads.end()) {
-        switch (thread->second.state) {
-          case thread_state::THREAD_BLOCKED:
-            transition(thread->second, thread_state::THREAD_RUNNING);
-            thread->second.wake_condition->notify_all();
-            break;
-          default: {
-            std::stringstream ss;
-            ss << "internal error expected to only wake up blocked threads " << thread_id_to_wake
-               << " " << as_str(thread->second.state);
-            throw std::runtime_error(ss.str());
-          }
-        }
-      }
-    } else if (is_from_free) {
-      // 3. Otherwise look to see if we are in a BUFN deadlock state.
-      //
-      // Memory was freed and if all of the tasks are in a BUFN state,
-      // then we want to wake up the highest priority one so it can make progress
-      // instead of trying to split its input. But we only do this if it
-      // is a different thread that is freeing memory from the one we want to wake up.
-      // This is because if the threads are the same no new memory is being added
-      // to what that task has access to and the task may never throw a retry and split.
-      // Instead it would just keep retrying and freeing the same memory each time.
-      std::map<long, long> pool_bufn_task_thread_count;
-      std::map<long, long> pool_task_thread_count;
-      std::unordered_set<long> bufn_task_ids;
-      std::unordered_set<long> all_task_ids;
-      is_in_deadlock(
-        pool_bufn_task_thread_count, pool_task_thread_count, bufn_task_ids, all_task_ids, lock);
-      bool const all_bufn = all_task_ids.size() == bufn_task_ids.size();
-      if (all_bufn) {
-        thread_priority to_wake(-1, -1);
-        bool is_to_wake_set = false;
-        for (auto const& [thread_id, t_state] : threads) {
-          switch (t_state.state) {
-            case thread_state::THREAD_BUFN: {
-              if (is_for_cpu == t_state.is_cpu_alloc) {
-                thread_priority current = t_state.priority();
-                if (!is_to_wake_set || to_wake < current) {
-                  to_wake = current;
-                  is_to_wake_set = true;
-                }
-              }
-            } break;
-            default: break;
-          }
-        }
-        // 4. Wake up the BUFN thread if we should
-        if (is_to_wake_set) {
-          long const thread_id_to_wake = to_wake.get_thread_id();
-          if (thread_id_to_wake > 0) {
-            // Don't wake up yourself on a free. It is not adding more memory for this thread
-            // to use on a retry and we might need a split instead to break a deadlock
-            auto const this_id = static_cast<long>(pthread_self());
-            auto const thread = threads.find(thread_id_to_wake);
-            if (thread != threads.end() && thread->first != this_id) {
-              switch (thread->second.state) {
-                case thread_state::THREAD_BUFN:
-                  transition(thread->second, thread_state::THREAD_RUNNING);
-                  thread->second.wake_condition->notify_all();
-                  break;
-                case thread_state::THREAD_BUFN_WAIT:
-                  transition(thread->second, thread_state::THREAD_RUNNING);
-                  // no need to notify anyone, we will just retry without blocking...
-                  break;
-                case thread_state::THREAD_BUFN_THROW:
-                  // This should really never happen, this is a temporary state that is here only
-                  // while the lock is held, but just in case we don't want to mess it up, or throw
-                  // an exception.
-                  break;
-                default: {
-                  std::stringstream ss;
-                  ss << "internal error expected to only wake up blocked threads "
-                     << thread_id_to_wake << " " << as_str(thread->second.state);
-                  throw std::runtime_error(ss.str());
-                }
-              }
-            }
-          }
-        }
-      }
+    if (thread_to_wake != nullptr) {
+      transition(*thread_to_wake, thread_state::THREAD_RUNNING);
+      thread_to_wake->wake_condition->notify_all();
     }
   }

-  bool is_thread_bufn_or_above(JNIEnv* env, full_thread_state const& state)
+  /**
+   * Returns a boolean indicating if the thread is in BUFN (Blocked Until
+   * Further Notice) state or above. A thread is considered BUFN or above if:
+   * - It has pool_blocked set to true, OR
+   * - Its state is THREAD_BUFN, OR
+   * - Its thread ID is in the java_blocked_thread_ids set if it is provided.
+   *
+   * Threads in THREAD_BLOCKED state are NOT considered BUFN or above.
+   */
+  bool is_thread_bufn_or_above(
+    full_thread_state const& state,
+    std::optional<std::unordered_set<long>> const& java_blocked_thread_ids)
   {
+    LOG_INFO("is_thread_bufn_or_above: state: {}, java_blocked_thread_ids: {}",
+             state.thread_id,
+             to_string(java_blocked_thread_ids));
     bool ret = false;
     if (state.pool_blocked) {
       ret = true;
@@ -1828,8 +1750,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
         ret = true;
         break;
       default:
-        ret = env->CallStaticBooleanMethod(
-          ThreadStateRegistry_jclass, isThreadBlocked_method, state.thread_id);
+        if (java_blocked_thread_ids.has_value()) {
+          ret = java_blocked_thread_ids.value().contains(state.thread_id);
+        }
         break;
     }
   }
@@ -1853,18 +1776,23 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
     return oss.str();
   }

-  bool is_in_deadlock(std::map<long, long>& pool_bufn_task_thread_count,
+  template <typename SetType>
+  std::string to_string(std::optional<SetType> const& set, std::string const& separator = ",")
+  {
+    if (set.has_value()) {
+      return to_string(set.value(), separator);
+    } else {
+      return "{}";
+    }
+  }
+
+  bool is_in_deadlock(std::unique_lock<std::mutex> const& lock,
+                      std::map<long, long>& pool_bufn_task_thread_count,
                       std::map<long, long>& pool_task_thread_count,
                       std::unordered_set<long>& bufn_task_ids,
                       std::unordered_set<long>& all_task_ids,
-                      std::unique_lock<std::mutex> const& lock)
+                      std::optional<std::unordered_set<long>> const& java_blocked_thread_ids)
   {
-    JNIEnv* env = nullptr;
-    if (jvm->GetEnv(reinterpret_cast<void**>(&env), cudf::jni::MINIMUM_JNI_VERSION) != JNI_OK) {
-      throw std::runtime_error("Cloud not init JNI callbacks");
-    }
-    cache_thread_reg_jni(env);
-
     // If all of the tasks are blocked, then we are in a deadlock situation
     // and we need to wake something up. In theory if any one thread is still
     // doing something, then we are not deadlocked. But the problem is detecting
@@ -1897,7 +1825,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
       long const task_id = t_state.task_id;
       if (task_id >= 0) {
         all_task_ids.insert(task_id);
-        bool const is_bufn_plus = is_thread_bufn_or_above(env, t_state);
+        bool const is_bufn_plus = is_thread_bufn_or_above(t_state, java_blocked_thread_ids);
         if (is_bufn_plus) { bufn_task_ids.insert(task_id); }
         if (is_bufn_plus || t_state.state == thread_state::THREAD_BLOCKED) {
           blocked_task_ids.insert(task_id);
@@ -1918,7 +1846,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
         }
       }

-      bool const is_bufn_plus = is_thread_bufn_or_above(env, t_state);
+      bool const is_bufn_plus = is_thread_bufn_or_above(t_state, java_blocked_thread_ids);
       if (is_bufn_plus) {
         for (auto const& task_id : t_state.pool_task_ids) {
           auto const it = pool_bufn_task_thread_count.find(task_id);
@@ -1974,19 +1902,44 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
       "DETAIL", -1, -1, thread_state::UNKNOWN, "States of all threads: {}", get_threads_string());
   }

+  /**
+   * This method is only called from code that only cares about the state machine state,
+   * and is happening very often. Our deadlock busting thread will invoke
+   * `check_and_update_for_bufn` directly, and pass along a set of java native thread ids
+   * that are blocked in the java state (Thread.getState).
+   *
+   * This split in invocation is done to make the critical sections faster, and leave
+   * deadlock busting to the deadlock thread.
+   */
+  void check_and_update_for_bufn_state_machine_only(const std::unique_lock<std::mutex>& lock)
+  {
+    // we pass nullopt because we are calling this method from a place in the code that has
+    // no java knowledge of blocked threads.
+    check_and_update_for_bufn(lock, /*java_blocked_thread_ids*/ std::nullopt);
+  }
+
   /**
    * Check to see if any threads need to move to BUFN. This should be
    * called when a task or shuffle thread becomes blocked so that we can
    * check to see if one of them needs to become BUFN or do a split and rollback.
+   *
+   * If this method is being called from the deadlock busting thread, we will pass
+   * along a set of java native thread ids that are blocked in the java state (Thread.getState).
    */
-  void check_and_update_for_bufn(const std::unique_lock<std::mutex>& lock)
+  void check_and_update_for_bufn(
+    const std::unique_lock<std::mutex>& lock,
+    std::optional<std::unordered_set<long>> const& java_blocked_thread_ids)
   {
     std::map<long, long> pool_bufn_task_thread_count;
     std::map<long, long> pool_task_thread_count;
     std::unordered_set<long> bufn_task_ids;
     std::unordered_set<long> all_task_ids;
-    bool const need_to_break_deadlock = is_in_deadlock(
-      pool_bufn_task_thread_count, pool_task_thread_count, bufn_task_ids, all_task_ids, lock);
+    bool const need_to_break_deadlock = is_in_deadlock(lock,
+                                                       pool_bufn_task_thread_count,
+                                                       pool_task_thread_count,
+                                                       bufn_task_ids,
+                                                       all_task_ids,
+                                                       java_blocked_thread_ids);
     if (need_to_break_deadlock) {
       // Find the task thread with the lowest priority that is not already BUFN
       thread_priority to_bufn(-1, -1);
@@ -2031,6 +1984,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
           thread->second.wake_condition->notify_all();
         }
       }
+
       // We now need a way to detect if we need to split the input and retry.
       // This happens when all of the tasks are also blocked until
       // further notice. So we are going to treat a task as blocked until
@@ -2154,7 +2108,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
      // do not retry if the thread is not registered...
      ret = false;
    }
-    check_and_update_for_bufn(lock);
+    check_and_update_for_bufn_state_machine_only(lock);
    return ret;
  }

@@ -2222,7 +2176,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
        }
      }
    }
-    wake_next_highest_priority_blocked(lock, true, is_for_cpu);
+    wake_next_highest_priority_blocked(lock, is_for_cpu);
  }

  void do_deallocate(void* p, std::size_t size, rmm::cuda_stream_view stream) noexcept override
@@ -2585,13 +2539,16 @@ JNIEXPORT void JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_end
 }

 JNIEXPORT void JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_checkAndBreakDeadlocks(
-  JNIEnv* env, jclass, jlong ptr)
+  JNIEnv* env, jclass, jlong ptr, jlongArray jblocked_thread_ids)
 {
   JNI_NULL_CHECK(env, ptr, "resource_adaptor is null", );
   JNI_TRY
   {
     auto mr = reinterpret_cast<spark_resource_adaptor*>(ptr);
-    mr->check_and_break_deadlocks();
+    cudf::jni::native_jlongArray blocked_thread_ids(env, jblocked_thread_ids);
+    std::unordered_set<long> blocked_thread_ids_set(blocked_thread_ids.begin(),
+                                                    blocked_thread_ids.end());
+    mr->check_and_break_deadlocks(blocked_thread_ids_set);
   }
   JNI_CATCH(env, );
 }
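
Outside the diff, the classification that the new is_thread_bufn_or_above doc comment describes can be restated as a small rule. Below is a minimal sketch under the assumptions stated in that comment (pool_blocked or THREAD_BUFN always count, THREAD_BLOCKED never does, and the Java-supplied set only matters when it is provided); ThreadSnapshot and BufnRuleSketch are made-up names for illustration, not part of the repository.

import java.util.Optional;
import java.util.Set;

// Illustrative mirror of the is_thread_bufn_or_above() rule described in the new doc
// comment; ThreadSnapshot is a made-up stand-in for the native full_thread_state fields.
final class BufnRuleSketch {
  record ThreadSnapshot(long threadId, boolean poolBlocked, boolean isBufn, boolean isBlocked) {}

  static boolean isBufnOrAbove(ThreadSnapshot t, Optional<Set<Long>> javaBlockedThreadIds) {
    if (t.poolBlocked() || t.isBufn()) { return true; }   // always BUFN or above
    if (t.isBlocked()) { return false; }                  // THREAD_BLOCKED never counts
    // Only the deadlock-busting thread supplies this set; the fast state-machine-only
    // path passes nothing, so it can never promote a thread based on Java state.
    return javaBlockedThreadIds.map(ids -> ids.contains(t.threadId())).orElse(false);
  }
}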
