Skip to content

Remove uses of tsl::Mutex inside tsl. #3296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ cc_library(
hdrs = ["blocking_counter.h"],
compatible_with = get_compatible_with_portable(),
deps = [
":mutex",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@xla//xla/tsl/platform:logging",
],
)
Expand Down Expand Up @@ -182,12 +183,13 @@ cc_library(
srcs = ["path.cc"],
hdrs = ["path.h"],
deps = [
":mutex",
":scanner",
":str_util",
":strcat",
":stringpiece",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
"@xla//xla/tsl/platform:logging",
"@xla//xla/tsl/platform:types",
],
Expand Down Expand Up @@ -708,7 +710,8 @@ cc_library(
srcs = ["random.cc"],
hdrs = ["random.h"],
deps = [
":mutex",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
"@xla//xla/tsl/platform:types",
],
)
Expand Down Expand Up @@ -1165,10 +1168,11 @@ tsl_cc_test(
":blocking_counter",
":random",
":unbounded_work_queue",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest_main",
"@xla//xla/tsl/platform:env",
"@xla//xla/tsl/platform:env_impl",
"@xla//xla/tsl/platform:test",
],
)
Expand Down Expand Up @@ -1201,8 +1205,8 @@ cc_library(
name = "refcount",
hdrs = ["refcount.h"],
deps = [
":mutex",
":thread_annotations",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
"@xla//xla/tsl/platform:logging",
],
)
Expand Down Expand Up @@ -1322,11 +1326,13 @@ tsl_cc_test(
"notap", #TODO(b/245510532) : disabled due to flakiness.
],
deps = [
":mutex",
":platform_port",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest_main",
"@xla//xla/tsl/platform:env",
"@xla//xla/tsl/platform:env_impl",
"@xla//xla/tsl/platform:env_time",
"@xla//xla/tsl/platform:test",
],
)
Expand Down
21 changes: 11 additions & 10 deletions tsl/platform/blocking_counter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ limitations under the License.

#include <atomic>

#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "xla/tsl/platform/logging.h"
#include "tsl/platform/mutex.h"

namespace tsl {

Expand All @@ -39,38 +40,38 @@ class BlockingCounter {
DCHECK_NE(((v + 2) & ~1), 0);
return; // either count has not dropped to 0, or waiter is not waiting
}
mutex_lock l(mu_);
absl::MutexLock l(&mu_);
DCHECK(!notified_);
notified_ = true;
cond_var_.notify_all();
cond_var_.SignalAll();
}

inline void Wait() {
unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
if ((v >> 1) == 0) return;
mutex_lock l(mu_);
absl::MutexLock l(&mu_);
while (!notified_) {
cond_var_.wait(l);
cond_var_.Wait(&mu_);
}
}
// Wait for the specified time, return false iff the count has not dropped to
// zero before the timeout expired.
inline bool WaitFor(std::chrono::milliseconds ms) {
unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
if ((v >> 1) == 0) return true;
mutex_lock l(mu_);
absl::MutexLock l(&mu_);
while (!notified_) {
const std::cv_status status = cond_var_.wait_for(l, ms);
if (status == std::cv_status::timeout) {
bool timeout = cond_var_.WaitWithTimeout(&mu_, absl::FromChrono(ms));
if (timeout) {
return false;
}
}
return true;
}

private:
mutex mu_;
condition_variable cond_var_;
absl::Mutex mu_;
absl::CondVar cond_var_;
std::atomic<int> state_; // low bit is waiter flag
bool notified_;
};
Expand Down
7 changes: 4 additions & 3 deletions tsl/platform/path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/const_init.h"
#include "absl/synchronization/mutex.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/types.h"
#include "tsl/platform/mutex.h"
#include "tsl/platform/scanner.h"
#include "tsl/platform/str_util.h"
#include "tsl/platform/strcat.h"
Expand Down Expand Up @@ -259,9 +260,9 @@ string CreateURI(absl::string_view scheme, absl::string_view host,

// Returns a unique number every time it is called.
int64_t UniqueId() {
static mutex mu(LINKER_INITIALIZED);
static absl::Mutex mu(absl::kConstInit);
static int64_t id = 0;
mutex_lock l(mu);
absl::MutexLock l(&mu);
return ++id;
}

Expand Down
93 changes: 46 additions & 47 deletions tsl/platform/port_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <condition_variable>

#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "xla/tsl/platform/env_time.h"
#include "xla/tsl/platform/test.h"
#include "xla/tsl/platform/threadpool.h"
#include "tsl/platform/cpu_info.h"
#include "tsl/platform/mem.h"
#include "tsl/platform/mutex.h"

namespace tsl {
namespace port {
Expand All @@ -45,72 +44,72 @@ TEST(Port, GetCurrentCPU) {
}

TEST(ConditionVariable, WaitForMilliseconds_Timeout) {
mutex m;
mutex_lock l(m);
condition_variable cv;
ConditionResult result = tsl::kCond_MaybeNotified;
absl::Mutex m;
absl::MutexLock l(&m);
absl::CondVar cv;
bool result = false;
time_t start = time(nullptr);
// Condition variables are subject to spurious wakeups on some platforms,
// so need to check for a timeout within a loop.
while (result == tsl::kCond_MaybeNotified) {
result = WaitForMilliseconds(&l, &cv, 3000);
while (!result) {
result = cv.WaitWithTimeout(&m, absl::Milliseconds(3000));
}
EXPECT_EQ(result, tsl::kCond_Timeout);
time_t finish = time(nullptr);
EXPECT_GE(finish - start, 3);
}

TEST(ConditionVariable, WaitForMilliseconds_Signalled) {
thread::ThreadPool pool(Env::Default(), "test", 1);
mutex m;
mutex_lock l(m);
condition_variable cv;
absl::Mutex m;
absl::MutexLock l(&m);
absl::CondVar cv;
time_t start = time(nullptr);
// Sleep for just 1 second then notify. We have a timeout of 3 secs,
// so the condition variable will notice the cv signal before the timeout.
pool.Schedule([&m, &cv]() {
Env::Default()->SleepForMicroseconds(1 * 1000 * 1000);
mutex_lock l(m);
cv.notify_all();
absl::MutexLock l(&m);
cv.SignalAll();
});
EXPECT_EQ(WaitForMilliseconds(&l, &cv, 3000), tsl::kCond_MaybeNotified);
EXPECT_FALSE(cv.WaitWithTimeout(&m, absl::Milliseconds(3000)));
time_t finish = time(nullptr);
EXPECT_LT(finish - start, 3);
}

TEST(ConditionalCriticalSections, AwaitWithDeadline_Timeout) {
bool always_false = false;
mutex m;
m.lock();
absl::Mutex m;
m.Lock();
time_t start = time(nullptr);
bool result =
m.AwaitWithDeadline(Condition(&always_false),
EnvTime::NowNanos() + 3 * EnvTime::kSecondsToNanos);
bool result = m.AwaitWithDeadline(
absl::Condition(&always_false),
absl::FromUnixNanos(EnvTime::NowNanos() + 3 * EnvTime::kSecondsToNanos));
time_t finish = time(nullptr);
m.unlock();
EXPECT_EQ(result, false);
m.Unlock();
EXPECT_FALSE(result);
EXPECT_GE(finish - start, 3);
}

TEST(ConditionalCriticalSections, AwaitWithDeadline_Woken) {
thread::ThreadPool pool(Env::Default(), "test", 1);
bool woken = false;
mutex m;
m.lock();
absl::Mutex m;
m.Lock();
time_t start = time(nullptr);
// Sleep for just 1 second then set the boolean. We have a timeout of 3
// secs, so the mutex implementation will notice the boolean state change
// before the timeout.
// secs, so the absl::Mutex implementation will notice the boolean state
// change before the timeout.
pool.Schedule([&m, &woken]() {
Env::Default()->SleepForMicroseconds(1 * 1000 * 1000);
m.lock();
m.Lock();
woken = true;
m.unlock();
m.Unlock();
});
bool result = m.AwaitWithDeadline(
Condition(&woken), EnvTime::NowNanos() + 3 * EnvTime::kSecondsToNanos);
absl::Condition(&woken),
absl::FromUnixNanos(EnvTime::NowNanos() + 3 * EnvTime::kSecondsToNanos));
time_t finish = time(nullptr);
m.unlock();
m.Unlock();
EXPECT_EQ(result, true);
EXPECT_LT(finish - start, 3);
}
Expand All @@ -134,48 +133,48 @@ TEST(ConditionalCriticalSections, Await_PingPong) {
thread::ThreadPool pool(Env::Default(), "test", 1);
bool ping_pong = false;
bool done = false;
mutex m;
absl::Mutex m;
pool.Schedule([&m, &ping_pong, &done]() {
m.lock();
m.Lock();
for (int i = 0; i != 1000; i++) {
m.Await(Condition(&ping_pong));
m.Await(absl::Condition(&ping_pong));
ping_pong = false;
}
done = true;
m.unlock();
m.Unlock();
});
m.lock();
m.Lock();
InvertClass invert(&ping_pong);
for (int i = 0; i != 1000; i++) {
m.Await(Condition(&Invert, &ping_pong));
m.Await(absl::Condition(&Invert, &ping_pong));
ping_pong = true;
}
m.Await(Condition(&done));
m.unlock();
m.Await(absl::Condition(&done));
m.Unlock();
}

TEST(ConditionalCriticalSections, Await_PingPongMethod) {
thread::ThreadPool pool(Env::Default(), "test", 1);
bool ping_pong = false;
bool done = false;
mutex m;
absl::Mutex m;
pool.Schedule([&m, &ping_pong, &done]() {
m.lock();
m.Lock();
for (int i = 0; i != 1000; i++) {
m.Await(Condition(&ping_pong));
m.Await(absl::Condition(&ping_pong));
ping_pong = false;
}
done = true;
m.unlock();
m.Unlock();
});
m.lock();
m.Lock();
InvertClass invert(&ping_pong);
for (int i = 0; i != 1000; i++) {
m.Await(Condition(&invert, &InvertClass::Value));
m.Await(absl::Condition(&invert, &InvertClass::Value));
ping_pong = true;
}
m.Await(Condition(&done));
m.unlock();
m.Await(absl::Condition(&done));
m.Unlock();
}

TEST(TestCPUFeature, TestFeature) {
Expand Down
11 changes: 6 additions & 5 deletions tsl/platform/random.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ limitations under the License.
#include <memory>
#include <random>

#include "absl/base/const_init.h"
#include "absl/synchronization/mutex.h"
#include "xla/tsl/platform/types.h"
#include "tsl/platform/mutex.h"

namespace tsl {
namespace random {
Expand All @@ -35,8 +36,8 @@ std::mt19937_64 InitRngWithDefaultSeed() { return std::mt19937_64(); }

uint64 New64() {
static std::mt19937_64* rng = InitRngWithRandomSeed();
static mutex mu(LINKER_INITIALIZED);
mutex_lock l(mu);
static absl::Mutex mu(absl::kConstInit);
absl::MutexLock l(&mu);
return (*rng)();
}

Expand All @@ -48,8 +49,8 @@ uint64 ThreadLocalNew64() {

uint64 New64DefaultSeed() {
static std::mt19937_64 rng = InitRngWithDefaultSeed();
static mutex mu(LINKER_INITIALIZED);
mutex_lock l(mu);
static absl::Mutex mu(absl::kConstInit);
absl::MutexLock l(&mu);
return rng();
}

Expand Down
Loading