1 // Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
2 // This source code is licensed under both the GPLv2 (found in the
3 // COPYING file in the root directory) and Apache 2.0 License
4 // (found in the LICENSE.Apache file in the root directory).
6 #include "utilities/transactions/lock/point/point_lock_tracker.h"
8 namespace ROCKSDB_NAMESPACE
{
12 class TrackedKeysColumnFamilyIterator
13 : public LockTracker::ColumnFamilyIterator
{
15 explicit TrackedKeysColumnFamilyIterator(const TrackedKeys
& keys
)
16 : tracked_keys_(keys
), it_(keys
.begin()) {}
18 bool HasNext() const override
{ return it_
!= tracked_keys_
.end(); }
20 ColumnFamilyId
Next() override
{ return (it_
++)->first
; }
23 const TrackedKeys
& tracked_keys_
;
24 TrackedKeys::const_iterator it_
;
27 class TrackedKeysIterator
: public LockTracker::KeyIterator
{
29 TrackedKeysIterator(const TrackedKeys
& keys
, ColumnFamilyId id
)
30 : key_infos_(keys
.at(id
)), it_(key_infos_
.begin()) {}
32 bool HasNext() const override
{ return it_
!= key_infos_
.end(); }
34 const std::string
& Next() override
{ return (it_
++)->first
; }
37 const TrackedKeyInfos
& key_infos_
;
38 TrackedKeyInfos::const_iterator it_
;
43 void PointLockTracker::Track(const PointLockRequest
& r
) {
44 auto& keys
= tracked_keys_
[r
.column_family_id
];
45 #ifdef __cpp_lib_unordered_map_try_emplace
46 // use c++17's try_emplace if available, to avoid rehashing the key
47 // in case it is not already in the map
48 auto result
= keys
.try_emplace(r
.key
, r
.seq
);
49 auto it
= result
.first
;
50 if (!result
.second
&& r
.seq
< it
->second
.seq
) {
51 // Now tracking this key with an earlier sequence number
52 it
->second
.seq
= r
.seq
;
55 auto it
= keys
.find(r
.key
);
56 if (it
== keys
.end()) {
57 auto result
= keys
.emplace(r
.key
, TrackedKeyInfo(r
.seq
));
59 } else if (r
.seq
< it
->second
.seq
) {
60 // Now tracking this key with an earlier sequence number
61 it
->second
.seq
= r
.seq
;
64 // else we do not update the seq. The smaller the tracked seq, the stronger it
65 // the guarantee since it implies from the seq onward there has not been a
66 // concurrent update to the key. So we update the seq if it implies stronger
67 // guarantees, i.e., if it is smaller than the existing tracked seq.
70 it
->second
.num_reads
++;
72 it
->second
.num_writes
++;
75 it
->second
.exclusive
= it
->second
.exclusive
|| r
.exclusive
;
78 UntrackStatus
PointLockTracker::Untrack(const PointLockRequest
& r
) {
79 auto cf_keys
= tracked_keys_
.find(r
.column_family_id
);
80 if (cf_keys
== tracked_keys_
.end()) {
81 return UntrackStatus::NOT_TRACKED
;
84 auto& keys
= cf_keys
->second
;
85 auto it
= keys
.find(r
.key
);
86 if (it
== keys
.end()) {
87 return UntrackStatus::NOT_TRACKED
;
90 bool untracked
= false;
91 auto& info
= it
->second
;
93 if (info
.num_reads
> 0) {
98 if (info
.num_writes
> 0) {
104 bool removed
= false;
105 if (info
.num_reads
== 0 && info
.num_writes
== 0) {
108 tracked_keys_
.erase(cf_keys
);
114 return UntrackStatus::REMOVED
;
117 return UntrackStatus::UNTRACKED
;
119 return UntrackStatus::NOT_TRACKED
;
122 void PointLockTracker::Merge(const LockTracker
& tracker
) {
123 const PointLockTracker
& t
= static_cast<const PointLockTracker
&>(tracker
);
124 for (const auto& cf_keys
: t
.tracked_keys_
) {
125 ColumnFamilyId cf
= cf_keys
.first
;
126 const auto& keys
= cf_keys
.second
;
128 auto current_cf_keys
= tracked_keys_
.find(cf
);
129 if (current_cf_keys
== tracked_keys_
.end()) {
130 tracked_keys_
.emplace(cf_keys
);
132 auto& current_keys
= current_cf_keys
->second
;
133 for (const auto& key_info
: keys
) {
134 const std::string
& key
= key_info
.first
;
135 const TrackedKeyInfo
& info
= key_info
.second
;
136 // If key was not previously tracked, just copy the whole struct over.
137 // Otherwise, some merging needs to occur.
138 auto current_info
= current_keys
.find(key
);
139 if (current_info
== current_keys
.end()) {
140 current_keys
.emplace(key_info
);
142 current_info
->second
.Merge(info
);
149 void PointLockTracker::Subtract(const LockTracker
& tracker
) {
150 const PointLockTracker
& t
= static_cast<const PointLockTracker
&>(tracker
);
151 for (const auto& cf_keys
: t
.tracked_keys_
) {
152 ColumnFamilyId cf
= cf_keys
.first
;
153 const auto& keys
= cf_keys
.second
;
155 auto& current_keys
= tracked_keys_
.at(cf
);
156 for (const auto& key_info
: keys
) {
157 const std::string
& key
= key_info
.first
;
158 const TrackedKeyInfo
& info
= key_info
.second
;
159 uint32_t num_reads
= info
.num_reads
;
160 uint32_t num_writes
= info
.num_writes
;
162 auto current_key_info
= current_keys
.find(key
);
163 assert(current_key_info
!= current_keys
.end());
165 // Decrement the total reads/writes of this key by the number of
166 // reads/writes done since the last SavePoint.
168 assert(current_key_info
->second
.num_reads
>= num_reads
);
169 current_key_info
->second
.num_reads
-= num_reads
;
171 if (num_writes
> 0) {
172 assert(current_key_info
->second
.num_writes
>= num_writes
);
173 current_key_info
->second
.num_writes
-= num_writes
;
175 if (current_key_info
->second
.num_reads
== 0 &&
176 current_key_info
->second
.num_writes
== 0) {
177 current_keys
.erase(current_key_info
);
183 LockTracker
* PointLockTracker::GetTrackedLocksSinceSavePoint(
184 const LockTracker
& save_point_tracker
) const {
185 // Examine the number of reads/writes performed on all keys written
186 // since the last SavePoint and compare to the total number of reads/writes
188 LockTracker
* t
= new PointLockTracker();
189 const PointLockTracker
& save_point_t
=
190 static_cast<const PointLockTracker
&>(save_point_tracker
);
191 for (const auto& cf_keys
: save_point_t
.tracked_keys_
) {
192 ColumnFamilyId cf
= cf_keys
.first
;
193 const auto& keys
= cf_keys
.second
;
195 auto& current_keys
= tracked_keys_
.at(cf
);
196 for (const auto& key_info
: keys
) {
197 const std::string
& key
= key_info
.first
;
198 const TrackedKeyInfo
& info
= key_info
.second
;
199 uint32_t num_reads
= info
.num_reads
;
200 uint32_t num_writes
= info
.num_writes
;
202 auto current_key_info
= current_keys
.find(key
);
203 assert(current_key_info
!= current_keys
.end());
204 assert(current_key_info
->second
.num_reads
>= num_reads
);
205 assert(current_key_info
->second
.num_writes
>= num_writes
);
207 if (current_key_info
->second
.num_reads
== num_reads
&&
208 current_key_info
->second
.num_writes
== num_writes
) {
209 // All the reads/writes to this key were done in the last savepoint.
211 r
.column_family_id
= cf
;
214 r
.read_only
= (num_writes
== 0);
215 r
.exclusive
= info
.exclusive
;
223 PointLockStatus
PointLockTracker::GetPointLockStatus(
224 ColumnFamilyId column_family_id
, const std::string
& key
) const {
225 assert(IsPointLockSupported());
226 PointLockStatus status
;
227 auto it
= tracked_keys_
.find(column_family_id
);
228 if (it
== tracked_keys_
.end()) {
232 const auto& keys
= it
->second
;
233 auto key_it
= keys
.find(key
);
234 if (key_it
== keys
.end()) {
238 const TrackedKeyInfo
& key_info
= key_it
->second
;
239 status
.locked
= true;
240 status
.exclusive
= key_info
.exclusive
;
241 status
.seq
= key_info
.seq
;
245 uint64_t PointLockTracker::GetNumPointLocks() const {
246 uint64_t num_keys
= 0;
247 for (const auto& cf_keys
: tracked_keys_
) {
248 num_keys
+= cf_keys
.second
.size();
253 LockTracker::ColumnFamilyIterator
* PointLockTracker::GetColumnFamilyIterator()
255 return new TrackedKeysColumnFamilyIterator(tracked_keys_
);
258 LockTracker::KeyIterator
* PointLockTracker::GetKeyIterator(
259 ColumnFamilyId column_family_id
) const {
260 assert(tracked_keys_
.find(column_family_id
) != tracked_keys_
.end());
261 return new TrackedKeysIterator(tracked_keys_
, column_family_id
);
264 void PointLockTracker::Clear() { tracked_keys_
.clear(); }
266 } // namespace ROCKSDB_NAMESPACE