Skip to content

Commit cb67f2f

Browse files
pschuh authored and Google-ML-Automation committed
[socket-transfer]: IsLastSemaphore requires poisoning as well or else it will write to the atm after an error has been set.
PiperOrigin-RevId: 763624413
1 parent 9dcce05 commit cb67f2f

File tree

4 files changed

+42
-18
lines changed

4 files changed

+42
-18
lines changed

xla/python/transfer/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -153,6 +153,7 @@ xla_cc_test(
153153
"//xla/tsl/platform:env",
154154
"//xla/tsl/platform:statusor",
155155
"@com_google_absl//absl/log:check",
156+
"@com_google_absl//absl/status",
156157
"@com_google_absl//absl/status:statusor",
157158
"@com_google_absl//absl/synchronization",
158159
"@com_google_googletest//:gtest_main",

xla/python/transfer/streaming_ifrt.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -221,6 +221,7 @@ class DmaDestination : public ChunkDestination {
221221
}
222222

223223
void Poison(absl::Status s) override {
224+
semaphore_.Poison();
224225
atm_->SetBufferError(buffer_index_, std::move(s));
225226
}
226227

xla/python/transfer/streaming_ifrt.h

Lines changed: 25 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -183,27 +183,40 @@ class IsLastSemaphore {
183183
: guard_counter_(value), counter_(value) {}
184184

185185
template <typename T>
186-
auto DoWork(size_t value, T&& cb) -> decltype(cb(false)) {
187-
bool is_last = guard_counter_.fetch_sub(value) - value == 0;
188-
if (is_last && counter_.fetch_sub(value) - value != 0) {
189-
// Wait if we happen to slip in between guard_counter and counter.
186+
auto DoWork(size_t value, T&& cb) -> absl::Status {
187+
bool is_last;
188+
{
190189
absl::MutexLock l(&mu_);
191-
auto cond = [this]() { return counter_.load() == 0; };
192-
mu_.Await(absl::Condition(&cond));
190+
if (is_poisoned_) {
191+
return absl::OkStatus();
192+
}
193+
guard_counter_ -= value;
194+
is_last = guard_counter_ == 0;
195+
if (is_last) {
196+
// Wait if we happen to slip in between guard_counter and counter.
197+
auto cond = [this, value]() { return counter_ == value; };
198+
mu_.Await(absl::Condition(&cond));
199+
}
193200
}
194201
auto cleanup = absl::MakeCleanup([&]() {
195-
if (!is_last && (counter_.fetch_sub(value) - value) == 0) {
196-
// Wake any waiters.
197-
absl::MutexLock l(&mu_);
198-
}
202+
absl::MutexLock l(&mu_);
203+
counter_ -= value;
199204
});
200205
return cb(is_last);
201206
}
202207

208+
void Poison() {
209+
absl::MutexLock l(&mu_);
210+
is_poisoned_ = true;
211+
auto cond = [this]() { return counter_ == guard_counter_; };
212+
mu_.Await(absl::Condition(&cond));
213+
}
214+
203215
private:
204216
absl::Mutex mu_;
205-
std::atomic<ssize_t> guard_counter_;
206-
std::atomic<ssize_t> counter_;
217+
bool is_poisoned_ = false;
218+
ssize_t guard_counter_;
219+
ssize_t counter_;
207220
};
208221

209222
} // namespace internal

xla/python/transfer/streaming_ifrt_test.cc

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525

2626
#include <gtest/gtest.h>
2727
#include "absl/log/check.h"
28+
#include "absl/status/status.h"
2829
#include "absl/status/statusor.h"
2930
#include "absl/synchronization/mutex.h"
3031
#include "absl/synchronization/notification.h"
@@ -151,9 +152,15 @@ TEST(PremappedCopierState, RoundTrip) {
151152
TEST(Semaphore, Basic) {
152153
internal::IsLastSemaphore semaphore(15);
153154
for (size_t i = 0; i < 10; ++i) {
154-
semaphore.DoWork(1, [&](bool is_last) { EXPECT_FALSE(is_last); });
155+
CHECK_OK(semaphore.DoWork(1, [&](bool is_last) -> absl::Status {
156+
EXPECT_FALSE(is_last);
157+
return absl::OkStatus();
158+
}));
155159
}
156-
semaphore.DoWork(5, [&](bool is_last) { EXPECT_TRUE(is_last); });
160+
CHECK_OK(semaphore.DoWork(5, [&](bool is_last) -> absl::Status {
161+
EXPECT_TRUE(is_last);
162+
return absl::OkStatus();
163+
}));
157164
}
158165

159166
TEST(Semaphore, Async) {
@@ -177,24 +184,26 @@ TEST(Semaphore, Async) {
177184
tsl::Env::Default()->StartThread({}, "t1", [&]() {
178185
for (size_t i = 0; i < 8; ++i) {
179186
thread_wait_flip(0);
180-
o_semaphore.DoWork(1, [&](bool is_last) {
187+
CHECK_OK(o_semaphore.DoWork(1, [&](bool is_last) -> absl::Status {
181188
thread_flip(0);
182189
EXPECT_FALSE(is_last);
183-
});
190+
return absl::OkStatus();
191+
}));
184192
}
185193
}));
186194
std::unique_ptr<tsl::Thread> t2(
187195
tsl::Env::Default()->StartThread({}, "t2", [&]() {
188196
for (size_t i = 0; i < 8; ++i) {
189197
thread_wait_flip(1);
190-
o_semaphore.DoWork(1, [&](bool is_last) {
198+
CHECK_OK(o_semaphore.DoWork(1, [&](bool is_last) -> absl::Status {
191199
thread_flip(1);
192200
if (i == 7) {
193201
EXPECT_TRUE(is_last);
194202
} else {
195203
EXPECT_FALSE(is_last);
196204
}
197-
});
205+
return absl::OkStatus();
206+
}));
198207
}
199208
}));
200209
}

0 commit comments

Comments
 (0)