// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).
#ifndef ROCKSDB_LITE

#include <atomic>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "db/db_impl.h"
#include "db/write_callback.h"
#include "port/port.h"
#include "rocksdb/db.h"
#include "rocksdb/write_batch.h"
#include "util/random.h"
#include "util/sync_point.h"
#include "util/testharness.h"
27 class WriteCallbackTest
: public testing::Test
{
32 dbname
= test::PerThreadDBPath("write_callback_testdb");
36 class WriteCallbackTestWriteCallback1
: public WriteCallback
{
38 bool was_called
= false;
40 Status
Callback(DB
*db
) override
{
43 // Make sure db is a DBImpl
44 DBImpl
* db_impl
= dynamic_cast<DBImpl
*> (db
);
45 if (db_impl
== nullptr) {
46 return Status::InvalidArgument("");
52 bool AllowWriteBatching() override
{ return true; }
55 class WriteCallbackTestWriteCallback2
: public WriteCallback
{
57 Status
Callback(DB
* /*db*/) override
{ return Status::Busy(); }
58 bool AllowWriteBatching() override
{ return true; }
61 class MockWriteCallback
: public WriteCallback
{
63 bool should_fail_
= false;
64 bool allow_batching_
= false;
65 std::atomic
<bool> was_called_
{false};
67 MockWriteCallback() {}
69 MockWriteCallback(const MockWriteCallback
& other
) {
70 should_fail_
= other
.should_fail_
;
71 allow_batching_
= other
.allow_batching_
;
72 was_called_
.store(other
.was_called_
.load());
75 Status
Callback(DB
* /*db*/) override
{
76 was_called_
.store(true);
78 return Status::Busy();
84 bool AllowWriteBatching() override
{ return allow_batching_
; }
87 TEST_F(WriteCallbackTest
, WriteWithCallbackTest
) {
89 WriteOP(bool should_fail
= false) { callback_
.should_fail_
= should_fail
; }
91 void Put(const string
& key
, const string
& val
) {
92 kvs_
.push_back(std::make_pair(key
, val
));
93 write_batch_
.Put(key
, val
);
99 callback_
.was_called_
.store(false);
102 MockWriteCallback callback_
;
103 WriteBatch write_batch_
;
104 std::vector
<std::pair
<string
, string
>> kvs_
;
107 // In each scenario we'll launch multiple threads to write.
108 // The size of each array equals to number of threads, and
109 // each boolean in it denote whether callback of corresponding
110 // thread should succeed or fail.
111 std::vector
<std::vector
<WriteOP
>> write_scenarios
= {
118 {false, false, false},
120 {false, true, false},
122 {true, false, false, false, false},
123 {false, false, false, false, true},
124 {false, false, true, false, true},
127 for (auto& seq_per_batch
: {true, false}) {
128 for (auto& two_queues
: {true, false}) {
129 for (auto& allow_parallel
: {true, false}) {
130 for (auto& allow_batching
: {true, false}) {
131 for (auto& enable_WAL
: {true, false}) {
132 for (auto& enable_pipelined_write
: {true, false}) {
133 for (auto& write_group
: write_scenarios
) {
135 options
.create_if_missing
= true;
136 options
.allow_concurrent_memtable_write
= allow_parallel
;
137 options
.enable_pipelined_write
= enable_pipelined_write
;
138 options
.two_write_queues
= two_queues
;
139 if (options
.enable_pipelined_write
&& seq_per_batch
) {
140 // This combination is not supported
143 if (options
.enable_pipelined_write
&& options
.two_write_queues
) {
144 // This combination is not supported
148 ReadOptions read_options
;
152 DestroyDB(dbname
, options
);
154 DBOptions
db_options(options
);
155 ColumnFamilyOptions
cf_options(options
);
156 std::vector
<ColumnFamilyDescriptor
> column_families
;
157 column_families
.push_back(
158 ColumnFamilyDescriptor(kDefaultColumnFamilyName
, cf_options
));
159 std::vector
<ColumnFamilyHandle
*> handles
;
161 DBImpl::Open(db_options
, dbname
, column_families
, &handles
,
162 &db
, seq_per_batch
, true /* batch_per_txn */);
164 assert(handles
.size() == 1);
167 db_impl
= dynamic_cast<DBImpl
*>(db
);
168 ASSERT_TRUE(db_impl
);
170 // Writers that have called JoinBatchGroup.
171 std::atomic
<uint64_t> threads_joining(0);
172 // Writers that have linked to the queue
173 std::atomic
<uint64_t> threads_linked(0);
174 // Writers that pass WriteThread::JoinBatchGroup:Wait sync-point.
175 std::atomic
<uint64_t> threads_verified(0);
177 std::atomic
<uint64_t> seq(db_impl
->GetLatestSequenceNumber());
178 ASSERT_EQ(db_impl
->GetLatestSequenceNumber(), 0);
180 rocksdb::SyncPoint::GetInstance()->SetCallBack(
181 "WriteThread::JoinBatchGroup:Start", [&](void*) {
182 uint64_t cur_threads_joining
= threads_joining
.fetch_add(1);
183 // Wait for the last joined writer to link to the queue.
184 // In this way the writers link to the queue one by one.
185 // This allows us to confidently detect the first writer
186 // who increases threads_linked as the leader.
187 while (threads_linked
.load() < cur_threads_joining
) {
191 // Verification once writers call JoinBatchGroup.
192 rocksdb::SyncPoint::GetInstance()->SetCallBack(
193 "WriteThread::JoinBatchGroup:Wait", [&](void* arg
) {
194 uint64_t cur_threads_linked
= threads_linked
.fetch_add(1);
195 bool is_leader
= false;
196 bool is_last
= false;
199 is_leader
= (cur_threads_linked
== 0);
200 is_last
= (cur_threads_linked
== write_group
.size() - 1);
203 auto* writer
= reinterpret_cast<WriteThread::Writer
*>(arg
);
206 ASSERT_TRUE(writer
->state
==
207 WriteThread::State::STATE_GROUP_LEADER
);
209 ASSERT_TRUE(writer
->state
==
210 WriteThread::State::STATE_INIT
);
213 // (meta test) the first WriteOP should indeed be the first
214 // and the last should be the last (all others can be out of
217 ASSERT_TRUE(writer
->callback
->Callback(nullptr).ok() ==
218 !write_group
.front().callback_
.should_fail_
);
219 } else if (is_last
) {
220 ASSERT_TRUE(writer
->callback
->Callback(nullptr).ok() ==
221 !write_group
.back().callback_
.should_fail_
);
224 threads_verified
.fetch_add(1);
225 // Wait here until all verification in this sync-point
226 // callback finish for all writers.
227 while (threads_verified
.load() < write_group
.size()) {
231 rocksdb::SyncPoint::GetInstance()->SetCallBack(
232 "WriteThread::JoinBatchGroup:DoneWaiting", [&](void* arg
) {
234 auto* writer
= reinterpret_cast<WriteThread::Writer
*>(arg
);
236 if (!allow_batching
) {
237 // no batching so everyone should be a leader
238 ASSERT_TRUE(writer
->state
==
239 WriteThread::State::STATE_GROUP_LEADER
);
240 } else if (!allow_parallel
) {
241 ASSERT_TRUE(writer
->state
==
242 WriteThread::State::STATE_COMPLETED
||
243 (enable_pipelined_write
&&
246 STATE_MEMTABLE_WRITER_LEADER
));
250 std::atomic
<uint32_t> thread_num(0);
251 std::atomic
<char> dummy_key(0);
253 // Each write thread create a random write batch and write to DB
254 // with a write callback.
255 std::function
<void()> write_with_callback_func
= [&]() {
256 uint32_t i
= thread_num
.fetch_add(1);
259 // leaders gotta lead
260 while (i
> 0 && threads_verified
.load() < 1) {
264 while (i
== write_group
.size() - 1 &&
265 threads_verified
.load() < write_group
.size() - 1) {
268 auto& write_op
= write_group
.at(i
);
270 write_op
.callback_
.allow_batching_
= allow_batching
;
273 for (uint32_t j
= 0; j
< rnd
.Next() % 50; j
++) {
275 char my_key
= dummy_key
.fetch_add(1);
277 string
skey(5, my_key
);
278 string
sval(10, my_key
);
279 write_op
.Put(skey
, sval
);
281 if (!write_op
.callback_
.should_fail_
&& !seq_per_batch
) {
285 if (!write_op
.callback_
.should_fail_
&& seq_per_batch
) {
289 WriteOptions woptions
;
290 woptions
.disableWAL
= !enable_WAL
;
291 woptions
.sync
= enable_WAL
;
294 class PublishSeqCallback
: public PreReleaseCallback
{
296 PublishSeqCallback(DBImpl
* db_impl_in
)
297 : db_impl_(db_impl_in
) {}
298 virtual Status
Callback(SequenceNumber last_seq
,
299 bool /*not used*/) override
{
300 db_impl_
->SetLastPublishedSequence(last_seq
);
304 } publish_seq_callback(db_impl
);
305 // seq_per_batch requires a natural batch separator or Noop
306 WriteBatchInternal::InsertNoop(&write_op
.write_batch_
);
307 const size_t ONE_BATCH
= 1;
308 s
= db_impl
->WriteImpl(
309 woptions
, &write_op
.write_batch_
, &write_op
.callback_
,
310 nullptr, 0, false, nullptr, ONE_BATCH
,
311 two_queues
? &publish_seq_callback
: nullptr);
313 s
= db_impl
->WriteWithCallback(
314 woptions
, &write_op
.write_batch_
, &write_op
.callback_
);
317 if (write_op
.callback_
.should_fail_
) {
318 ASSERT_TRUE(s
.IsBusy());
324 rocksdb::SyncPoint::GetInstance()->EnableProcessing();
327 std::vector
<port::Thread
> threads
;
328 for (uint32_t i
= 0; i
< write_group
.size(); i
++) {
329 threads
.emplace_back(write_with_callback_func
);
331 for (auto& t
: threads
) {
335 rocksdb::SyncPoint::GetInstance()->DisableProcessing();
339 for (auto& w
: write_group
) {
340 ASSERT_TRUE(w
.callback_
.was_called_
.load());
341 for (auto& kvp
: w
.kvs_
) {
342 if (w
.callback_
.should_fail_
) {
344 db
->Get(read_options
, kvp
.first
, &value
).IsNotFound());
346 ASSERT_OK(db
->Get(read_options
, kvp
.first
, &value
));
347 ASSERT_EQ(value
, kvp
.second
);
352 ASSERT_EQ(seq
.load(), db_impl
->TEST_GetLastVisibleSequence());
355 DestroyDB(dbname
, options
);
365 TEST_F(WriteCallbackTest
, WriteCallBackTest
) {
367 WriteOptions write_options
;
368 ReadOptions read_options
;
373 DestroyDB(dbname
, options
);
375 options
.create_if_missing
= true;
376 Status s
= DB::Open(options
, dbname
, &db
);
379 db_impl
= dynamic_cast<DBImpl
*> (db
);
380 ASSERT_TRUE(db_impl
);
384 wb
.Put("a", "value.a");
387 // Test a simple Write
388 s
= db
->Write(write_options
, &wb
);
391 s
= db
->Get(read_options
, "a", &value
);
393 ASSERT_EQ("value.a", value
);
395 // Test WriteWithCallback
396 WriteCallbackTestWriteCallback1 callback1
;
399 wb2
.Put("a", "value.a2");
401 s
= db_impl
->WriteWithCallback(write_options
, &wb2
, &callback1
);
403 ASSERT_TRUE(callback1
.was_called
);
405 s
= db
->Get(read_options
, "a", &value
);
407 ASSERT_EQ("value.a2", value
);
409 // Test WriteWithCallback for a callback that fails
410 WriteCallbackTestWriteCallback2 callback2
;
413 wb3
.Put("a", "value.a3");
415 s
= db_impl
->WriteWithCallback(write_options
, &wb3
, &callback2
);
418 s
= db
->Get(read_options
, "a", &value
);
420 ASSERT_EQ("value.a2", value
);
423 DestroyDB(dbname
, options
);
426 } // namespace rocksdb
428 int main(int argc
, char** argv
) {
429 ::testing::InitGoogleTest(&argc
, argv
);
430 return RUN_ALL_TESTS();
436 int main(int /*argc*/, char** /*argv*/) {
438 "SKIPPED as WriteWithCallback is not supported in ROCKSDB_LITE\n");
442 #endif // !ROCKSDB_LITE