1 // Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
2 // This source code is licensed under both the GPLv2 (found in the
3 // COPYING file in the root directory) and Apache 2.0 License
4 // (found in the LICENSE.Apache file in the root directory).
10 #include "port/port.h"
11 #include "rocksdb/env.h"
12 #include "test_util/sync_point.h"
13 #include "test_util/testharness.h"
14 #include "test_util/testutil.h"
15 #include "util/autovector.h"
16 #include "util/thread_local.h"
18 namespace ROCKSDB_NAMESPACE
{
20 class ThreadLocalTest
: public testing::Test
{
22 ThreadLocalTest() : env_(Env::Default()) {}
30 Params(port::Mutex
* m
, port::CondVar
* c
, int* u
, int n
,
31 UnrefHandler handler
= nullptr)
53 class IDChecker
: public ThreadLocalPtr
{
55 static uint32_t PeekId() {
60 } // anonymous namespace
62 // Suppress false positive clang analyzer warnings.
63 #ifndef __clang_analyzer__
64 TEST_F(ThreadLocalTest
, UniqueIdTest
) {
66 port::CondVar
cv(&mu
);
68 uint32_t base_id
= IDChecker::PeekId();
69 // New ThreadLocal instance bumps id by 1
72 Params
p1(&mu
, &cv
, nullptr, 1u);
73 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 1u);
75 Params
p2(&mu
, &cv
, nullptr, 1u);
76 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 2u);
78 Params
p3(&mu
, &cv
, nullptr, 1u);
79 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 3u);
81 Params
p4(&mu
, &cv
, nullptr, 1u);
82 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 4u);
84 // id 3, 2, 1, 0 are in the free queue in order
85 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 0u);
88 Params
p1(&mu
, &cv
, nullptr, 1u);
89 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 1u);
91 Params
* p2
= new Params(&mu
, &cv
, nullptr, 1u);
92 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 2u);
94 Params
p3(&mu
, &cv
, nullptr, 1u);
95 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 3u);
98 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 1u);
99 // Now we have 3, 1 in queue
101 Params
p4(&mu
, &cv
, nullptr, 1u);
102 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 3u);
104 Params
p5(&mu
, &cv
, nullptr, 1u);
106 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 4u);
107 // After exit, id sequence in queue:
110 #endif // __clang_analyzer__
112 TEST_F(ThreadLocalTest
, SequentialReadWriteTest
) {
113 // global id list carries over 3, 1, 2, 0
114 uint32_t base_id
= IDChecker::PeekId();
117 port::CondVar
cv(&mu
);
118 Params
p(&mu
, &cv
, nullptr, 1);
122 ASSERT_GT(IDChecker::PeekId(), base_id
);
123 base_id
= IDChecker::PeekId();
125 auto func
= [](void* ptr
) {
126 auto& params
= *static_cast<Params
*>(ptr
);
128 ASSERT_TRUE(params
.tls1
.Get() == nullptr);
129 params
.tls1
.Reset(reinterpret_cast<int*>(1));
130 ASSERT_TRUE(params
.tls1
.Get() == reinterpret_cast<int*>(1));
131 params
.tls1
.Reset(reinterpret_cast<int*>(2));
132 ASSERT_TRUE(params
.tls1
.Get() == reinterpret_cast<int*>(2));
134 ASSERT_TRUE(params
.tls2
->Get() == nullptr);
135 params
.tls2
->Reset(reinterpret_cast<int*>(1));
136 ASSERT_TRUE(params
.tls2
->Get() == reinterpret_cast<int*>(1));
137 params
.tls2
->Reset(reinterpret_cast<int*>(2));
138 ASSERT_TRUE(params
.tls2
->Get() == reinterpret_cast<int*>(2));
141 ++(params
.completed
);
142 params
.cv
->SignalAll();
146 for (int iter
= 0; iter
< 1024; ++iter
) {
147 ASSERT_EQ(IDChecker::PeekId(), base_id
);
148 // Another new thread, read/write should not see value from previous thread
149 env_
->StartThread(func
, static_cast<void*>(&p
));
151 while (p
.completed
!= iter
+ 1) {
155 ASSERT_EQ(IDChecker::PeekId(), base_id
);
159 TEST_F(ThreadLocalTest
, ConcurrentReadWriteTest
) {
160 // global id list carries over 3, 1, 2, 0
161 uint32_t base_id
= IDChecker::PeekId();
165 port::CondVar
cv1(&mu1
);
166 Params
p1(&mu1
, &cv1
, nullptr, 16);
170 port::CondVar
cv2(&mu2
);
171 Params
p2(&mu2
, &cv2
, nullptr, 16);
175 auto func
= [](void* ptr
) {
176 auto& p
= *static_cast<Params
*>(ptr
);
179 // Size_T switches size along with the ptr size
180 // we want to cast to.
181 size_t own
= ++(p
.started
);
183 while (p
.started
!= p
.total
) {
188 // Let write threads write a different value from the read threads
193 ASSERT_TRUE(p
.tls1
.Get() == nullptr);
194 ASSERT_TRUE(p
.tls2
->Get() == nullptr);
196 auto* env
= Env::Default();
197 auto start
= env
->NowMicros();
199 p
.tls1
.Reset(reinterpret_cast<size_t*>(own
));
200 p
.tls2
->Reset(reinterpret_cast<size_t*>(own
+ 1));
202 while (env
->NowMicros() - start
< 1000 * 1000) {
203 for (int iter
= 0; iter
< 100000; ++iter
) {
204 ASSERT_TRUE(p
.tls1
.Get() == reinterpret_cast<size_t*>(own
));
205 ASSERT_TRUE(p
.tls2
->Get() == reinterpret_cast<size_t*>(own
+ 1));
207 p
.tls1
.Reset(reinterpret_cast<size_t*>(own
));
208 p
.tls2
->Reset(reinterpret_cast<size_t*>(own
+ 1));
219 // Initiate 2 instances: one keeps writing and one keeps reading.
220 // The read instance should not see data from the write instance.
221 // Each thread local copy of the value is also different from each
223 for (int th
= 0; th
< p1
.total
; ++th
) {
224 env_
->StartThread(func
, static_cast<void*>(&p1
));
226 for (int th
= 0; th
< p2
.total
; ++th
) {
227 env_
->StartThread(func
, static_cast<void*>(&p2
));
231 while (p1
.completed
!= p1
.total
) {
237 while (p2
.completed
!= p2
.total
) {
242 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 3u);
245 TEST_F(ThreadLocalTest
, Unref
) {
246 auto unref
= [](void* ptr
) {
247 auto& p
= *static_cast<Params
*>(ptr
);
253 // Case 0: no unref triggered if ThreadLocalPtr is never accessed
254 auto func0
= [](void* ptr
) {
255 auto& p
= *static_cast<Params
*>(ptr
);
260 while (p
.started
!= p
.total
) {
266 for (int th
= 1; th
<= 128; th
+= th
) {
268 port::CondVar
cv(&mu
);
270 Params
p(&mu
, &cv
, &unref_count
, th
, unref
);
272 for (int i
= 0; i
< p
.total
; ++i
) {
273 env_
->StartThread(func0
, static_cast<void*>(&p
));
276 ASSERT_EQ(unref_count
, 0);
279 // Case 1: unref triggered by thread exit
280 auto func1
= [](void* ptr
) {
281 auto& p
= *static_cast<Params
*>(ptr
);
286 while (p
.started
!= p
.total
) {
291 ASSERT_TRUE(p
.tls1
.Get() == nullptr);
292 ASSERT_TRUE(p
.tls2
->Get() == nullptr);
301 for (int th
= 1; th
<= 128; th
+= th
) {
303 port::CondVar
cv(&mu
);
305 ThreadLocalPtr
tls2(unref
);
306 Params
p(&mu
, &cv
, &unref_count
, th
, unref
);
309 for (int i
= 0; i
< p
.total
; ++i
) {
310 env_
->StartThread(func1
, static_cast<void*>(&p
));
315 // N threads x 2 ThreadLocal instance cleanup on thread exit
316 ASSERT_EQ(unref_count
, 2 * p
.total
);
319 // Case 2: unref triggered by ThreadLocal instance destruction
320 auto func2
= [](void* ptr
) {
321 auto& p
= *static_cast<Params
*>(ptr
);
326 while (p
.started
!= p
.total
) {
331 ASSERT_TRUE(p
.tls1
.Get() == nullptr);
332 ASSERT_TRUE(p
.tls2
->Get() == nullptr);
344 // Waiting for instruction to exit thread
345 while (p
.completed
!= 0) {
351 for (int th
= 1; th
<= 128; th
+= th
) {
353 port::CondVar
cv(&mu
);
355 Params
p(&mu
, &cv
, &unref_count
, th
, unref
);
356 p
.tls2
= new ThreadLocalPtr(unref
);
358 for (int i
= 0; i
< p
.total
; ++i
) {
359 env_
->StartThread(func2
, static_cast<void*>(&p
));
362 // Wait for all threads to finish using Params
364 while (p
.completed
!= p
.total
) {
369 // Now destroy one ThreadLocal instance
372 // instance destroy for N threads
373 ASSERT_EQ(unref_count
, p
.total
);
381 // additional N threads exit unref for the left instance
382 ASSERT_EQ(unref_count
, 2 * p
.total
);
386 TEST_F(ThreadLocalTest
, Swap
) {
388 tls
.Reset(reinterpret_cast<void*>(1));
389 ASSERT_EQ(reinterpret_cast<int64_t>(tls
.Swap(nullptr)), 1);
390 ASSERT_TRUE(tls
.Swap(reinterpret_cast<void*>(2)) == nullptr);
391 ASSERT_EQ(reinterpret_cast<int64_t>(tls
.Get()), 2);
392 ASSERT_EQ(reinterpret_cast<int64_t>(tls
.Swap(reinterpret_cast<void*>(3))), 2);
395 TEST_F(ThreadLocalTest
, Scrape
) {
396 auto unref
= [](void* ptr
) {
397 auto& p
= *static_cast<Params
*>(ptr
);
403 auto func
= [](void* ptr
) {
404 auto& p
= *static_cast<Params
*>(ptr
);
406 ASSERT_TRUE(p
.tls1
.Get() == nullptr);
407 ASSERT_TRUE(p
.tls2
->Get() == nullptr);
419 // Waiting for instruction to exit thread
420 while (p
.completed
!= 0) {
426 for (int th
= 1; th
<= 128; th
+= th
) {
428 port::CondVar
cv(&mu
);
430 Params
p(&mu
, &cv
, &unref_count
, th
, unref
);
431 p
.tls2
= new ThreadLocalPtr(unref
);
433 for (int i
= 0; i
< p
.total
; ++i
) {
434 env_
->StartThread(func
, static_cast<void*>(&p
));
437 // Wait for all threads to finish using Params
439 while (p
.completed
!= p
.total
) {
444 ASSERT_EQ(unref_count
, 0);
446 // Scrape all thread local data. No unref at thread
447 // exit or ThreadLocalPtr destruction
448 autovector
<void*> ptrs
;
449 p
.tls1
.Scrape(&ptrs
, nullptr);
450 p
.tls2
->Scrape(&ptrs
, nullptr);
459 ASSERT_EQ(unref_count
, 0);
463 TEST_F(ThreadLocalTest
, Fold
) {
464 auto unref
= [](void* ptr
) {
465 delete static_cast<std::atomic
<int64_t>*>(ptr
);
467 static const int kNumThreads
= 16;
468 static const int kItersPerThread
= 10;
470 port::CondVar
cv(&mu
);
471 Params
params(&mu
, &cv
, nullptr, kNumThreads
, unref
);
472 auto func
= [](void* ptr
) {
473 auto& p
= *static_cast<Params
*>(ptr
);
474 ASSERT_TRUE(p
.tls1
.Get() == nullptr);
475 p
.tls1
.Reset(new std::atomic
<int64_t>(0));
477 for (int i
= 0; i
< kItersPerThread
; ++i
) {
478 static_cast<std::atomic
<int64_t>*>(p
.tls1
.Get())->fetch_add(1);
485 // Waiting for instruction to exit thread
486 while (p
.completed
!= 0) {
492 for (int th
= 0; th
< params
.total
; ++th
) {
493 env_
->StartThread(func
, static_cast<void*>(¶ms
));
496 // Wait for all threads to finish using Params
498 while (params
.completed
!= params
.total
) {
503 // Verify Fold() behavior
506 [](void* ptr
, void* res
) {
507 auto sum_ptr
= static_cast<int64_t*>(res
);
508 *sum_ptr
+= static_cast<std::atomic
<int64_t>*>(ptr
)->load();
511 ASSERT_EQ(sum
, kNumThreads
* kItersPerThread
);
515 params
.completed
= 0;
521 TEST_F(ThreadLocalTest
, CompareAndSwap
) {
523 ASSERT_TRUE(tls
.Swap(reinterpret_cast<void*>(1)) == nullptr);
524 void* expected
= reinterpret_cast<void*>(1);
526 ASSERT_TRUE(tls
.CompareAndSwap(reinterpret_cast<void*>(2), expected
));
527 expected
= reinterpret_cast<void*>(100);
528 // Fail Swap, still 2
529 ASSERT_TRUE(!tls
.CompareAndSwap(reinterpret_cast<void*>(2), expected
));
530 ASSERT_EQ(expected
, reinterpret_cast<void*>(2));
532 expected
= reinterpret_cast<void*>(2);
533 ASSERT_TRUE(tls
.CompareAndSwap(reinterpret_cast<void*>(3), expected
));
534 ASSERT_EQ(tls
.Get(), reinterpret_cast<void*>(3));
539 void* AccessThreadLocal(void* /*arg*/) {
540 TEST_SYNC_POINT("AccessThreadLocal:Start");
542 tlp
.Reset(new std::string("hello RocksDB"));
543 TEST_SYNC_POINT("AccessThreadLocal:End");
549 // The following test is disabled as it requires manual steps to run it
552 // Currently we have no way to access SyncPoint w/o ASAN error when the
553 // child thread dies after the main thread dies. So if you manually enable
554 // this test and only see an ASAN error on SyncPoint, it means you pass the
556 TEST_F(ThreadLocalTest
, DISABLED_MainThreadDiesFirst
) {
557 ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
558 {{"AccessThreadLocal:Start", "MainThreadDiesFirst:End"},
559 {"PosixEnv::~PosixEnv():End", "AccessThreadLocal:End"}});
561 // Triggers the initialization of singletons.
566 #endif // ROCKSDB_LITE
567 ROCKSDB_NAMESPACE::port::Thread
th(&AccessThreadLocal
, nullptr);
569 TEST_SYNC_POINT("MainThreadDiesFirst:End");
571 } catch (const std::system_error
& ex
) {
572 std::cerr
<< "Start thread: " << ex
.code() << std::endl
;
575 #endif // ROCKSDB_LITE
578 } // namespace ROCKSDB_NAMESPACE
580 int main(int argc
, char** argv
) {
581 ::testing::InitGoogleTest(&argc
, argv
);
582 return RUN_ALL_TESTS();