1 // Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
2 // This source code is licensed under both the GPLv2 (found in the
3 // COPYING file in the root directory) and Apache 2.0 License
4 // (found in the LICENSE.Apache file in the root directory).
10 #include "port/port.h"
11 #include "rocksdb/env.h"
12 #include "test_util/sync_point.h"
13 #include "test_util/testharness.h"
14 #include "test_util/testutil.h"
15 #include "util/autovector.h"
16 #include "util/thread_local.h"
// NOTE(review): this copy of the file is truncated/mangled — stray original
// line numbers are embedded in the text, statements are split across lines,
// and many lines (class bodies, members) are missing. Restore from upstream
// before attempting to build.
18 namespace ROCKSDB_NAMESPACE
{
// Test fixture: owns the Env used to spawn worker threads in the tests below.
20 class ThreadLocalTest
: public testing::Test
{
22 ThreadLocalTest() : env_(Env::Default()) {}
// Params: shared state handed to each worker thread — a mutex/condvar pair
// for rendezvous, an optional unref counter `u`, thread count `n`, and an
// optional UnrefHandler installed on the ThreadLocalPtr instances.
30 Params(port::Mutex
* m
, port::CondVar
* c
, int* u
, int n
,
31 UnrefHandler handler
= nullptr)
// IDChecker subclasses ThreadLocalPtr to expose its internal instance-id
// allocator; tests use PeekId() to observe id assignment and recycling.
// (PeekId body missing in this copy — verify against upstream.)
53 class IDChecker
: public ThreadLocalPtr
{
55 static uint32_t PeekId() {
60 } // anonymous namespace
62 // Suppress false positive clang analyzer warnings.
63 #ifndef __clang_analyzer__
// Verifies ThreadLocalPtr instance-id allocation: each new instance bumps
// the next id by one, and ids freed by destruction are recycled.
// NOTE(review): the scope braces that destroy p1..p4 between the assertion
// groups are missing from this copy — restore from upstream.
64 TEST_F(ThreadLocalTest
, UniqueIdTest
) {
66 port::CondVar
cv(&mu
);
68 uint32_t base_id
= IDChecker::PeekId();
69 // New ThreadLocal instance bumps id by 1
72 Params
p1(&mu
, &cv
, nullptr, 1u);
73 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 1u);
75 Params
p2(&mu
, &cv
, nullptr, 1u);
76 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 2u);
78 Params
p3(&mu
, &cv
, nullptr, 1u);
79 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 3u);
81 Params
p4(&mu
, &cv
, nullptr, 1u);
82 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 4u);
// After the instances above are destroyed, their ids return to the free
// queue and the lowest recycled id is handed out first.
84 // id 3, 2, 1, 0 are in the free queue in order
85 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 0u);
88 Params
p1(&mu
, &cv
, nullptr, 1u);
89 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 1u);
// p2 is heap-allocated so its id can be released mid-scope (delete is in a
// line missing from this copy).
91 Params
* p2
= new Params(&mu
, &cv
, nullptr, 1u);
92 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 2u);
94 Params
p3(&mu
, &cv
, nullptr, 1u);
95 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 3u);
98 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 1u);
99 // Now we have 3, 1 in queue
101 Params
p4(&mu
, &cv
, nullptr, 1u);
102 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 3u);
104 Params
p5(&mu
, &cv
, nullptr, 1u);
106 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 4u);
107 // After exit, id sequence in queue:
110 #endif // __clang_analyzer__
// Spawns 1024 threads one after another; each asserts it starts with a null
// thread-local value (i.e. never sees the previous thread's data), then
// writes and re-reads its own values through tls1 and tls2.
// NOTE(review): mutex lock/unlock and cv.Wait() lines are missing from this
// copy — restore from upstream.
112 TEST_F(ThreadLocalTest
, SequentialReadWriteTest
) {
113 // global id list carries over 3, 1, 2, 0
114 uint32_t base_id
= IDChecker::PeekId();
117 port::CondVar
cv(&mu
);
118 Params
p(&mu
, &cv
, nullptr, 1);
// Worker body: verify isolation, then exercise Reset/Get on both instances.
122 auto func
= [](void* ptr
) {
123 auto& params
= *static_cast<Params
*>(ptr
);
125 ASSERT_TRUE(params
.tls1
.Get() == nullptr);
126 params
.tls1
.Reset(reinterpret_cast<int*>(1));
127 ASSERT_TRUE(params
.tls1
.Get() == reinterpret_cast<int*>(1));
128 params
.tls1
.Reset(reinterpret_cast<int*>(2));
129 ASSERT_TRUE(params
.tls1
.Get() == reinterpret_cast<int*>(2));
131 ASSERT_TRUE(params
.tls2
->Get() == nullptr);
132 params
.tls2
->Reset(reinterpret_cast<int*>(1));
133 ASSERT_TRUE(params
.tls2
->Get() == reinterpret_cast<int*>(1));
134 params
.tls2
->Reset(reinterpret_cast<int*>(2));
135 ASSERT_TRUE(params
.tls2
->Get() == reinterpret_cast<int*>(2));
// Signal the main thread that this worker finished.
138 ++(params
.completed
);
139 params
.cv
->SignalAll();
143 for (int iter
= 0; iter
< 1024; ++iter
) {
144 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 1u);
145 // Another new thread, read/write should not see value from previous thread
146 env_
->StartThread(func
, static_cast<void*>(&p
));
// Wait until the worker spawned this iteration reports completion.
148 while (p
.completed
!= iter
+ 1) {
152 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 1u);
// Runs two independent Params instances with 16 threads each, all hammering
// tls1/tls2 concurrently for ~1 second; every thread must only ever observe
// its own value (`own` / `own + 1`), proving per-thread isolation under
// concurrency. NOTE(review): lock/wait/signal lines are missing from this
// copy — restore from upstream.
156 TEST_F(ThreadLocalTest
, ConcurrentReadWriteTest
) {
157 // global id list carries over 3, 1, 2, 0
158 uint32_t base_id
= IDChecker::PeekId();
162 port::CondVar
cv1(&mu1
);
163 Params
p1(&mu1
, &cv1
, nullptr, 16);
167 port::CondVar
cv2(&mu2
);
168 Params
p2(&mu2
, &cv2
, nullptr, 16);
172 auto func
= [](void* ptr
) {
173 auto& p
= *static_cast<Params
*>(ptr
);
176 // Size_T switches size along with the ptr size
177 // we want to cast to.
// `own` is this thread's unique 1-based rank, used as its sentinel value.
178 size_t own
= ++(p
.started
);
// Barrier: wait until every sibling thread has started.
180 while (p
.started
!= p
.total
) {
185 // Let write threads write a different value from the read threads
190 ASSERT_TRUE(p
.tls1
.Get() == nullptr);
191 ASSERT_TRUE(p
.tls2
->Get() == nullptr);
193 auto* env
= Env::Default();
194 auto start
= env
->NowMicros();
196 p
.tls1
.Reset(reinterpret_cast<size_t*>(own
));
197 p
.tls2
->Reset(reinterpret_cast<size_t*>(own
+ 1));
// Spin for ~1 second re-checking that no other thread's value bleeds in.
199 while (env
->NowMicros() - start
< 1000 * 1000) {
200 for (int iter
= 0; iter
< 100000; ++iter
) {
201 ASSERT_TRUE(p
.tls1
.Get() == reinterpret_cast<size_t*>(own
));
202 ASSERT_TRUE(p
.tls2
->Get() == reinterpret_cast<size_t*>(own
+ 1));
204 p
.tls1
.Reset(reinterpret_cast<size_t*>(own
));
205 p
.tls2
->Reset(reinterpret_cast<size_t*>(own
+ 1));
216 // Initiate 2 instances: one keeps writing and one keeps reading.
217 // The read instance should not see data from the write instance.
218 // Each thread local copy of the value are also different from each
220 for (int th
= 0; th
< p1
.total
; ++th
) {
221 env_
->StartThread(func
, static_cast<void*>(&p1
));
223 for (int th
= 0; th
< p2
.total
; ++th
) {
224 env_
->StartThread(func
, static_cast<void*>(&p2
));
// Wait for both groups of workers to finish.
228 while (p1
.completed
!= p1
.total
) {
234 while (p2
.completed
!= p2
.total
) {
239 ASSERT_EQ(IDChecker::PeekId(), base_id
+ 3u);
// Verifies when the UnrefHandler fires, in three cases:
//   0: never, if the ThreadLocalPtr is never accessed by a thread;
//   1: on thread exit, once per accessed instance per thread;
//   2: on ThreadLocalPtr destruction, for every thread that stored a value.
// NOTE(review): lock/wait/signal lines and the unref lambda body (which
// increments *p.unref via the counter pointer) are missing from this copy —
// restore from upstream.
242 TEST_F(ThreadLocalTest
, Unref
) {
243 auto unref
= [](void* ptr
) {
244 auto& p
= *static_cast<Params
*>(ptr
);
250 // Case 0: no unref triggered if ThreadLocalPtr is never accessed
251 auto func0
= [](void* ptr
) {
252 auto& p
= *static_cast<Params
*>(ptr
);
// Barrier: wait for all sibling threads to start, then exit immediately.
257 while (p
.started
!= p
.total
) {
// Thread counts 1, 2, 4, ..., 128 (th += th doubles each iteration).
263 for (int th
= 1; th
<= 128; th
+= th
) {
265 port::CondVar
cv(&mu
);
267 Params
p(&mu
, &cv
, &unref_count
, th
, unref
);
269 for (int i
= 0; i
< p
.total
; ++i
) {
270 env_
->StartThread(func0
, static_cast<void*>(&p
));
273 ASSERT_EQ(unref_count
, 0);
276 // Case 1: unref triggered by thread exit
277 auto func1
= [](void* ptr
) {
278 auto& p
= *static_cast<Params
*>(ptr
);
283 while (p
.started
!= p
.total
) {
// Touch both instances so thread-exit cleanup has something to unref.
288 ASSERT_TRUE(p
.tls1
.Get() == nullptr);
289 ASSERT_TRUE(p
.tls2
->Get() == nullptr);
298 for (int th
= 1; th
<= 128; th
+= th
) {
300 port::CondVar
cv(&mu
);
302 ThreadLocalPtr
tls2(unref
);
303 Params
p(&mu
, &cv
, &unref_count
, th
, unref
);
306 for (int i
= 0; i
< p
.total
; ++i
) {
307 env_
->StartThread(func1
, static_cast<void*>(&p
));
312 // N threads x 2 ThreadLocal instance cleanup on thread exit
313 ASSERT_EQ(unref_count
, 2 * p
.total
);
316 // Case 2: unref triggered by ThreadLocal instance destruction
317 auto func2
= [](void* ptr
) {
318 auto& p
= *static_cast<Params
*>(ptr
);
323 while (p
.started
!= p
.total
) {
328 ASSERT_TRUE(p
.tls1
.Get() == nullptr);
329 ASSERT_TRUE(p
.tls2
->Get() == nullptr);
341 // Waiting for instruction to exit thread
342 while (p
.completed
!= 0) {
348 for (int th
= 1; th
<= 128; th
+= th
) {
350 port::CondVar
cv(&mu
);
352 Params
p(&mu
, &cv
, &unref_count
, th
, unref
);
// tls2 is heap-allocated so the test can destroy it while threads still
// hold values (the delete is in a line missing from this copy).
353 p
.tls2
= new ThreadLocalPtr(unref
);
355 for (int i
= 0; i
< p
.total
; ++i
) {
356 env_
->StartThread(func2
, static_cast<void*>(&p
));
359 // Wait for all threads to finish using Params
361 while (p
.completed
!= p
.total
) {
366 // Now destroy one ThreadLocal instance
369 // instance destroy for N threads
370 ASSERT_EQ(unref_count
, p
.total
);
378 // additional N threads exit unref for the left instance
379 ASSERT_EQ(unref_count
, 2 * p
.total
);
383 TEST_F(ThreadLocalTest
, Swap
) {
385 tls
.Reset(reinterpret_cast<void*>(1));
386 ASSERT_EQ(reinterpret_cast<int64_t>(tls
.Swap(nullptr)), 1);
387 ASSERT_TRUE(tls
.Swap(reinterpret_cast<void*>(2)) == nullptr);
388 ASSERT_EQ(reinterpret_cast<int64_t>(tls
.Get()), 2);
389 ASSERT_EQ(reinterpret_cast<int64_t>(tls
.Swap(reinterpret_cast<void*>(3))), 2);
// Verifies ThreadLocalPtr::Scrape(): collecting all threads' values into a
// vector (and resetting the slots to the given replacement, nullptr here)
// means the UnrefHandler never fires afterwards — neither at thread exit nor
// at instance destruction. NOTE(review): lock/wait lines, the worker's value
// writes, and the post-scrape teardown are missing from this copy — restore
// from upstream.
392 TEST_F(ThreadLocalTest
, Scrape
) {
393 auto unref
= [](void* ptr
) {
394 auto& p
= *static_cast<Params
*>(ptr
);
400 auto func
= [](void* ptr
) {
401 auto& p
= *static_cast<Params
*>(ptr
);
403 ASSERT_TRUE(p
.tls1
.Get() == nullptr);
404 ASSERT_TRUE(p
.tls2
->Get() == nullptr);
416 // Waiting for instruction to exit thread
417 while (p
.completed
!= 0) {
// Thread counts 1, 2, 4, ..., 128.
423 for (int th
= 1; th
<= 128; th
+= th
) {
425 port::CondVar
cv(&mu
);
427 Params
p(&mu
, &cv
, &unref_count
, th
, unref
);
428 p
.tls2
= new ThreadLocalPtr(unref
);
430 for (int i
= 0; i
< p
.total
; ++i
) {
431 env_
->StartThread(func
, static_cast<void*>(&p
));
434 // Wait for all threads to finish using Params
436 while (p
.completed
!= p
.total
) {
441 ASSERT_EQ(unref_count
, 0);
443 // Scrape all thread local data. No unref at thread
444 // exit or ThreadLocalPtr destruction
445 autovector
<void*> ptrs
;
446 p
.tls1
.Scrape(&ptrs
, nullptr);
447 p
.tls2
->Scrape(&ptrs
, nullptr);
456 ASSERT_EQ(unref_count
, 0);
// Verifies ThreadLocalPtr::Fold(): each of 16 threads stores a heap-allocated
// atomic counter and increments it 10 times; folding over all threads' values
// must sum to kNumThreads * kItersPerThread. The unref handler deletes the
// per-thread atomics on cleanup. NOTE(review): the Fold() call site, `sum`
// declaration, lock/wait lines and final teardown are missing from this
// copy — restore from upstream.
460 TEST_F(ThreadLocalTest
, Fold
) {
461 auto unref
= [](void* ptr
) {
462 delete static_cast<std::atomic
<int64_t>*>(ptr
);
464 static const int kNumThreads
= 16;
465 static const int kItersPerThread
= 10;
467 port::CondVar
cv(&mu
);
468 Params
params(&mu
, &cv
, nullptr, kNumThreads
, unref
);
469 auto func
= [](void* ptr
) {
470 auto& p
= *static_cast<Params
*>(ptr
);
471 ASSERT_TRUE(p
.tls1
.Get() == nullptr);
// Each thread owns a fresh atomic counter; unref (above) frees it later.
472 p
.tls1
.Reset(new std::atomic
<int64_t>(0));
474 for (int i
= 0; i
< kItersPerThread
; ++i
) {
475 static_cast<std::atomic
<int64_t>*>(p
.tls1
.Get())->fetch_add(1);
482 // Waiting for instruction to exit thread
483 while (p
.completed
!= 0) {
489 for (int th
= 0; th
< params
.total
; ++th
) {
490 env_
->StartThread(func
, static_cast<void*>(&params
));
493 // Wait for all threads to finish using Params
495 while (params
.completed
!= params
.total
) {
500 // Verify Fold() behavior
// Folding callback: accumulate each thread's atomic counter into *res.
503 [](void* ptr
, void* res
) {
504 auto sum_ptr
= static_cast<int64_t*>(res
);
505 *sum_ptr
+= static_cast<std::atomic
<int64_t>*>(ptr
)->load();
508 ASSERT_EQ(sum
, kNumThreads
* kItersPerThread
);
// Release the workers so they can exit and trigger cleanup.
512 params
.completed
= 0;
518 TEST_F(ThreadLocalTest
, CompareAndSwap
) {
520 ASSERT_TRUE(tls
.Swap(reinterpret_cast<void*>(1)) == nullptr);
521 void* expected
= reinterpret_cast<void*>(1);
523 ASSERT_TRUE(tls
.CompareAndSwap(reinterpret_cast<void*>(2), expected
));
524 expected
= reinterpret_cast<void*>(100);
525 // Fail Swap, still 2
526 ASSERT_TRUE(!tls
.CompareAndSwap(reinterpret_cast<void*>(2), expected
));
527 ASSERT_EQ(expected
, reinterpret_cast<void*>(2));
529 expected
= reinterpret_cast<void*>(2);
530 ASSERT_TRUE(tls
.CompareAndSwap(reinterpret_cast<void*>(3), expected
));
531 ASSERT_EQ(tls
.Get(), reinterpret_cast<void*>(3));
// Thread body used by DISABLED_MainThreadDiesFirst: stores a value into a
// ThreadLocalPtr while sync points coordinate it against main-thread/env
// teardown. NOTE(review): the `tlp` declaration and several surrounding
// lines (ifdef ROCKSDB_LITE guards, try block, thread detach) are missing
// from this copy — restore from upstream.
536 void* AccessThreadLocal(void* /*arg*/) {
537 TEST_SYNC_POINT("AccessThreadLocal:Start");
539 tlp
.Reset(new std::string("hello RocksDB"));
540 TEST_SYNC_POINT("AccessThreadLocal:End");
546 // The following test is disabled as it requires manual steps to run it
549 // Currently we have no way to access SyncPoint w/o ASAN error when the
550 // child thread dies after the main thread dies. So if you manually enable
551 // this test and only see an ASAN error on SyncPoint, it means you pass the
553 TEST_F(ThreadLocalTest
, DISABLED_MainThreadDiesFirst
) {
// Force the child thread's TLS access to happen only after the main thread
// (and PosixEnv) have begun dying, exercising shutdown-order safety.
554 ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
555 {{"AccessThreadLocal:Start", "MainThreadDiesFirst:End"},
556 {"PosixEnv::~PosixEnv():End", "AccessThreadLocal:End"}});
558 // Triggers the initialization of singletons.
563 #endif // ROCKSDB_LITE
564 ROCKSDB_NAMESPACE::port::Thread
th(&AccessThreadLocal
, nullptr);
566 TEST_SYNC_POINT("MainThreadDiesFirst:End");
// Thread creation can fail with std::system_error; report and continue.
568 } catch (const std::system_error
& ex
) {
569 std::cerr
<< "Start thread: " << ex
.code() << std::endl
;
572 #endif // ROCKSDB_LITE
575 } // namespace ROCKSDB_NAMESPACE
577 int main(int argc
, char** argv
) {
578 ::testing::InitGoogleTest(&argc
, argv
);
579 return RUN_ALL_TESTS();