@@ -44,8 +44,6 @@ constexpr char const* CPU_RETRY_OOM_CLASS = "com/nvidia/spark/rapids/jni/CpuRetr
// Fully-qualified JNI class names and method names/signatures this file looks up
// at runtime. These must stay in sync with the Java side of spark-rapids-jni.
constexpr char const* CPU_SPLIT_AND_RETRY_OOM_CLASS =
  "com/nvidia/spark/rapids/jni/CpuSplitAndRetryOOM";
constexpr char const* THREAD_REG_CLASS = "com/nvidia/spark/rapids/jni/ThreadStateRegistry";
// ThreadStateRegistry.removeThread(long) -> void : drops a native thread id from the registry.
constexpr char const* REMOVE_THREAD = "removeThread";
constexpr char const* REMOVE_THREAD_SIG = "(J)V";

@@ -55,7 +53,6 @@ std::mutex jni_mutex;
// True once cache_thread_reg_jni has successfully cached the handles below.
bool is_jni_loaded = false;
// Global reference to the ThreadStateRegistry class (promoted from a local ref
// so it cannot be garbage collected while we hold it).
jclass ThreadStateRegistry_jclass;
// Static method id for ThreadStateRegistry.removeThread(long).
jmethodID removeThread_method;

6057void cache_thread_reg_jni (JNIEnv* env)
6158{
@@ -67,9 +64,6 @@ void cache_thread_reg_jni(JNIEnv* env)
6764 removeThread_method = env->GetStaticMethodID (cls, REMOVE_THREAD, REMOVE_THREAD_SIG);
6865 if (removeThread_method == nullptr ) { return ; }
6966
70- isThreadBlocked_method = env->GetStaticMethodID (cls, IS_THREAD_BLOCKED, IS_THREAD_BLOCKED_SIG);
71- if (isThreadBlocked_method == nullptr ) { return ; }
72-
7367 // Convert local reference to global so it cannot be garbage collected.
7468 ThreadStateRegistry_jclass = static_cast <jclass>(env->NewGlobalRef (cls));
7569 if (ThreadStateRegistry_jclass == nullptr ) { return ; }
@@ -1120,10 +1114,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
11201114 return ret;
11211115 }
11221116
  /**
   * Entry point for the external deadlock-busting thread. Takes the global
   * state lock and runs the full BUFN check/update.
   *
   * @param java_blocked_thread_ids native ids of the threads that the caller
   *        observed as blocked on the Java side (Thread.getState).
   */
  void check_and_break_deadlocks(std::unordered_set<long> const& java_blocked_thread_ids)
  {
    std::unique_lock<std::mutex> lock(state_mutex);
    check_and_update_for_bufn(lock, java_blocked_thread_ids);
  }
11281122
11291123 bool cpu_prealloc (size_t const amount, bool const blocking)
@@ -1373,7 +1367,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
13731367 // Before we can wait it is possible that the throw didn't release anything
13741368 // and the other threads didn't get unblocked by this, so we need to
13751369 // check again to see if this was fixed or not.
1376- check_and_update_for_bufn (lock);
1370+ check_and_update_for_bufn_state_machine_only (lock);
13771371 // If that caused us to transition to a new state, then we need to adjust to it
13781372 // appropriately...
13791373 if (is_blocked (thread->second .state )) {
@@ -1696,127 +1690,55 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
16961690 break ;
16971691 default : break ;
16981692 }
1699- wake_next_highest_priority_blocked (lock, false , is_for_cpu);
1693+ wake_next_highest_priority_blocked (lock, is_for_cpu);
17001694 }
17011695 }
17021696
  /**
   * Wake the highest priority blocked (not BUFN) thread so it can make progress.
   *
   * This is typically called when a free happens, or an alloc succeeds.
   * Threads in BUFN states are NOT woken here; they are handled by the
   * separate deadlock-busting path (check_and_update_for_bufn).
   *
   * @param is_for_cpu true if it was a CPU operation (free or alloc)
   */
  void wake_next_highest_priority_blocked(std::unique_lock<std::mutex> const& lock,
                                          bool const is_for_cpu)
  {
    // 1. Find the highest priority blocked thread, for the alloc that matches.
    //    Only threads in THREAD_BLOCKED whose alloc type (CPU vs GPU) matches
    //    the operation that triggered this call are candidates.
    thread_priority to_wake(-1, -1);
    full_thread_state* thread_to_wake = nullptr;
    for (auto& [thread_d, t_state] : threads) {
      thread_state const& state = t_state.state;
      if (state == thread_state::THREAD_BLOCKED && is_for_cpu == t_state.is_cpu_alloc) {
        thread_priority current = t_state.priority();
        if (thread_to_wake == nullptr || to_wake < current) {
          to_wake = current;
          thread_to_wake = &t_state;
        }
      }
    }
    // 2. Wake up that thread (no-op if nothing matching was blocked).
    if (thread_to_wake != nullptr) {
      transition(*thread_to_wake, thread_state::THREAD_RUNNING);
      thread_to_wake->wake_condition->notify_all();
    }
  }
18171725
1818- bool is_thread_bufn_or_above (JNIEnv* env, full_thread_state const & state)
1726+ /* *
1727+ * Returns a boolean indicating if the thread is in BUFN (Blocked Until
1728+ * Further Notice) state or above. A thread is considered BUFN or above if:
1729+ * - It has pool_blocked set to true, OR
1730+ * - Its state is THREAD_BUFN, OR
1731+ * - Its thread ID is in the java_blocked_thread_ids set if it is provided.
1732+ *
1733+ * Threads in THREAD_BLOCKED state are NOT considered BUFN or above.
1734+ */
1735+ bool is_thread_bufn_or_above (
1736+ full_thread_state const & state,
1737+ std::optional<std::unordered_set<long >> const & java_blocked_thread_ids)
18191738 {
1739+ LOG_INFO (" is_thread_bufn_or_above: state: {}, java_blocked_thread_ids: {}" ,
1740+ state.thread_id ,
1741+ to_string (java_blocked_thread_ids));
18201742 bool ret = false ;
18211743 if (state.pool_blocked ) {
18221744 ret = true ;
@@ -1828,8 +1750,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
18281750 ret = true ;
18291751 break ;
18301752 default :
1831- ret = env->CallStaticBooleanMethod (
1832- ThreadStateRegistry_jclass, isThreadBlocked_method, state.thread_id );
1753+ if (java_blocked_thread_ids.has_value ()) {
1754+ ret = java_blocked_thread_ids.value ().contains (state.thread_id );
1755+ }
18331756 break ;
18341757 }
18351758 }
@@ -1853,18 +1776,23 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
18531776 return oss.str ();
18541777 }
18551778
1856- bool is_in_deadlock (std::map<long , long >& pool_bufn_task_thread_count,
1779+ template <typename SetType>
1780+ std::string to_string (std::optional<SetType> const & set, std::string const & separator = " ," )
1781+ {
1782+ if (set.has_value ()) {
1783+ return to_string (set.value (), separator);
1784+ } else {
1785+ return " {}" ;
1786+ }
1787+ }
1788+
1789+ bool is_in_deadlock (std::unique_lock<std::mutex> const & lock,
1790+ std::map<long , long >& pool_bufn_task_thread_count,
18571791 std::map<long , long >& pool_task_thread_count,
18581792 std::unordered_set<long >& bufn_task_ids,
18591793 std::unordered_set<long >& all_task_ids,
1860- std::unique_lock <std::mutex> const & lock )
1794+ std::optional <std::unordered_set< long >> const & java_blocked_thread_ids )
18611795 {
1862- JNIEnv* env = nullptr ;
1863- if (jvm->GetEnv (reinterpret_cast <void **>(&env), cudf::jni::MINIMUM_JNI_VERSION) != JNI_OK) {
1864- throw std::runtime_error (" Cloud not init JNI callbacks" );
1865- }
1866- cache_thread_reg_jni (env);
1867-
18681796 // If all of the tasks are blocked, then we are in a deadlock situation
18691797 // and we need to wake something up. In theory if any one thread is still
18701798 // doing something, then we are not deadlocked. But the problem is detecting
@@ -1897,7 +1825,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
18971825 long const task_id = t_state.task_id ;
18981826 if (task_id >= 0 ) {
18991827 all_task_ids.insert (task_id);
1900- bool const is_bufn_plus = is_thread_bufn_or_above (env, t_state );
1828+ bool const is_bufn_plus = is_thread_bufn_or_above (t_state, java_blocked_thread_ids );
19011829 if (is_bufn_plus) { bufn_task_ids.insert (task_id); }
19021830 if (is_bufn_plus || t_state.state == thread_state::THREAD_BLOCKED) {
19031831 blocked_task_ids.insert (task_id);
@@ -1918,7 +1846,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
19181846 }
19191847 }
19201848
1921- bool const is_bufn_plus = is_thread_bufn_or_above (env, t_state );
1849+ bool const is_bufn_plus = is_thread_bufn_or_above (t_state, java_blocked_thread_ids );
19221850 if (is_bufn_plus) {
19231851 for (auto const & task_id : t_state.pool_task_ids ) {
19241852 auto const it = pool_bufn_task_thread_count.find (task_id);
@@ -1974,19 +1902,44 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
19741902 " DETAIL" , -1 , -1 , thread_state::UNKNOWN, " States of all threads: {}" , get_threads_string ());
19751903 }
19761904
  /**
   * This method is only called from code that only cares about the state machine state,
   * and is happening very often. Our deadlock busting thread will invoke
   * `check_and_update_for_bufn` directly, and pass along a set of java native thread ids
   * that are blocked in the java state (Thread.getState).
   *
   * This split in invocation is done to make the critical sections faster, and leave
   * deadlock busting to the deadlock thread.
   *
   * @param lock proof that the caller already holds state_mutex.
   */
  void check_and_update_for_bufn_state_machine_only(const std::unique_lock<std::mutex>& lock)
  {
    // we pass nullopt because we are calling this method from a place in the code that has
    // no java knowledge of blocked threads.
    check_and_update_for_bufn(lock, /*java_blocked_thread_ids*/ std::nullopt);
  }
1920+
19771921 /* *
19781922 * Check to see if any threads need to move to BUFN. This should be
19791923 * called when a task or shuffle thread becomes blocked so that we can
19801924 * check to see if one of them needs to become BUFN or do a split and rollback.
1925+ *
1926+ * If this method is being called from the deadlock busting thread, we will pass
1927+ * along a set of java native thread ids that are blocked in the java state (Thread.getState).
19811928 */
1982- void check_and_update_for_bufn (const std::unique_lock<std::mutex>& lock)
1929+ void check_and_update_for_bufn (
1930+ const std::unique_lock<std::mutex>& lock,
1931+ std::optional<std::unordered_set<long >> const & java_blocked_thread_ids)
19831932 {
19841933 std::map<long , long > pool_bufn_task_thread_count;
19851934 std::map<long , long > pool_task_thread_count;
19861935 std::unordered_set<long > bufn_task_ids;
19871936 std::unordered_set<long > all_task_ids;
1988- bool const need_to_break_deadlock = is_in_deadlock (
1989- pool_bufn_task_thread_count, pool_task_thread_count, bufn_task_ids, all_task_ids, lock);
1937+ bool const need_to_break_deadlock = is_in_deadlock (lock,
1938+ pool_bufn_task_thread_count,
1939+ pool_task_thread_count,
1940+ bufn_task_ids,
1941+ all_task_ids,
1942+ java_blocked_thread_ids);
19901943 if (need_to_break_deadlock) {
19911944 // Find the task thread with the lowest priority that is not already BUFN
19921945 thread_priority to_bufn (-1 , -1 );
@@ -2031,6 +1984,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
20311984 thread->second .wake_condition ->notify_all ();
20321985 }
20331986 }
1987+
20341988 // We now need a way to detect if we need to split the input and retry.
20351989 // This happens when all of the tasks are also blocked until
20361990 // further notice. So we are going to treat a task as blocked until
@@ -2154,7 +2108,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
21542108 // do not retry if the thread is not registered...
21552109 ret = false ;
21562110 }
2157- check_and_update_for_bufn (lock);
2111+ check_and_update_for_bufn_state_machine_only (lock);
21582112 return ret;
21592113 }
21602114
@@ -2222,7 +2176,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
22222176 }
22232177 }
22242178 }
2225- wake_next_highest_priority_blocked (lock, true , is_for_cpu);
2179+ wake_next_highest_priority_blocked (lock, is_for_cpu);
22262180 }
22272181
22282182 void do_deallocate (void * p, std::size_t size, rmm::cuda_stream_view stream) noexcept override
@@ -2585,13 +2539,16 @@ JNIEXPORT void JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_end
25852539}
25862540
/**
 * JNI entry point for SparkResourceAdaptor.checkAndBreakDeadlocks.
 *
 * Copies the caller-supplied array of Java-side blocked native thread ids into
 * a hash set (for O(1) membership checks) and forwards it to the resource
 * adaptor's deadlock check.
 */
JNIEXPORT void JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_checkAndBreakDeadlocks(
  JNIEnv* env, jclass, jlong ptr, jlongArray jblocked_thread_ids)
{
  JNI_NULL_CHECK(env, ptr, "resource_adaptor is null", );
  JNI_TRY
  {
    auto mr = reinterpret_cast<spark_resource_adaptor*>(ptr);
    cudf::jni::native_jlongArray blocked_thread_ids(env, jblocked_thread_ids);
    std::unordered_set<long> blocked_thread_ids_set(blocked_thread_ids.begin(),
                                                    blocked_thread_ids.end());
    mr->check_and_break_deadlocks(blocked_thread_ids_set);
  }
  JNI_CATCH(env, );
}
0 commit comments