]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / java / vector / src / main / java / org / apache / arrow / vector / BitVectorHelper.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.arrow.vector;
19
20 import static io.netty.util.internal.PlatformDependent.getByte;
21 import static io.netty.util.internal.PlatformDependent.getInt;
22 import static io.netty.util.internal.PlatformDependent.getLong;
23 import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt;
24
25 import org.apache.arrow.memory.ArrowBuf;
26 import org.apache.arrow.memory.BoundsChecking;
27 import org.apache.arrow.memory.BufferAllocator;
28 import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
29 import org.apache.arrow.vector.util.DataSizeRoundingUtil;
30
31 import io.netty.util.internal.PlatformDependent;
32
33 /**
34 * Helper class for performing generic operations on a bit vector buffer.
35 * External use of this class is not recommended.
36 */
37 public class BitVectorHelper {
38
/** Prevents instantiation; every member of this class is static. */
private BitVectorHelper() {
}
40
41 /**
42 * Get the index of byte corresponding to bit index in validity buffer.
43 */
44 public static long byteIndex(long absoluteBitIndex) {
45 return absoluteBitIndex >> 3;
46 }
47
48 /**
49 * Get the relative index of bit within the byte in validity buffer.
50 */
51 public static int bitIndex(long absoluteBitIndex) {
52 return checkedCastToInt(absoluteBitIndex & 7);
53 }
54
55 /**
56 * Get the index of byte corresponding to bit index in validity buffer.
57 */
58 public static int byteIndex(int absoluteBitIndex) {
59 return absoluteBitIndex >> 3;
60 }
61
62 /**
63 * Get the relative index of bit within the byte in validity buffer.
64 */
65 public static int bitIndex(int absoluteBitIndex) {
66 return absoluteBitIndex & 7;
67 }
68
69 /**
70 * Set the bit at provided index to 1.
71 *
72 * @param validityBuffer validity buffer of the vector
73 * @param index index to be set
74 */
75 public static void setBit(ArrowBuf validityBuffer, long index) {
76 // it can be observed that some logic is duplicate of the logic in setValidityBit.
77 // this is because JIT cannot always remove the if branch in setValidityBit,
78 // so we give a dedicated implementation for setting bits.
79 final long byteIndex = byteIndex(index);
80 final int bitIndex = bitIndex(index);
81
82 // the byte is promoted to an int, because according to Java specification,
83 // bytes will be promoted to ints automatically, upon expression evaluation.
84 // by promoting it manually, we avoid the unnecessary conversions.
85 int currentByte = validityBuffer.getByte(byteIndex);
86 final int bitMask = 1 << bitIndex;
87 currentByte |= bitMask;
88 validityBuffer.setByte(byteIndex, currentByte);
89 }
90
91 /**
92 * Set the bit at provided index to 0.
93 *
94 * @param validityBuffer validity buffer of the vector
95 * @param index index to be set
96 */
97 public static void unsetBit(ArrowBuf validityBuffer, int index) {
98 // it can be observed that some logic is duplicate of the logic in setValidityBit.
99 // this is because JIT cannot always remove the if branch in setValidityBit,
100 // so we give a dedicated implementation for unsetting bits.
101 final int byteIndex = byteIndex(index);
102 final int bitIndex = bitIndex(index);
103
104 // the byte is promoted to an int, because according to Java specification,
105 // bytes will be promoted to ints automatically, upon expression evaluation.
106 // by promoting it manually, we avoid the unnecessary conversions.
107 int currentByte = validityBuffer.getByte(byteIndex);
108 final int bitMask = 1 << bitIndex;
109 currentByte &= ~bitMask;
110 validityBuffer.setByte(byteIndex, currentByte);
111 }
112
113 /**
114 * Set the bit at a given index to provided value (1 or 0).
115 *
116 * @param validityBuffer validity buffer of the vector
117 * @param index index to be set
118 * @param value value to set
119 */
120 public static void setValidityBit(ArrowBuf validityBuffer, int index, int value) {
121 final int byteIndex = byteIndex(index);
122 final int bitIndex = bitIndex(index);
123
124 // the byte is promoted to an int, because according to Java specification,
125 // bytes will be promoted to ints automatically, upon expression evaluation.
126 // by promoting it manually, we avoid the unnecessary conversions.
127 int currentByte = validityBuffer.getByte(byteIndex);
128 final int bitMask = 1 << bitIndex;
129 if (value != 0) {
130 currentByte |= bitMask;
131 } else {
132 currentByte &= ~bitMask;
133 }
134 validityBuffer.setByte(byteIndex, currentByte);
135 }
136
137 /**
138 * Set the bit at a given index to provided value (1 or 0). Internally
139 * takes care of allocating the buffer if the caller didn't do so.
140 *
141 * @param validityBuffer validity buffer of the vector
142 * @param allocator allocator for the buffer
143 * @param valueCount number of values to allocate/set
144 * @param index index to be set
145 * @param value value to set
146 * @return ArrowBuf
147 */
148 public static ArrowBuf setValidityBit(ArrowBuf validityBuffer, BufferAllocator allocator,
149 int valueCount, int index, int value) {
150 if (validityBuffer == null) {
151 validityBuffer = allocator.buffer(getValidityBufferSize(valueCount));
152 }
153 setValidityBit(validityBuffer, index, value);
154 if (index == (valueCount - 1)) {
155 validityBuffer.writerIndex(getValidityBufferSize(valueCount));
156 }
157
158 return validityBuffer;
159 }
160
161 /**
162 * Check if a bit at a given index is set or not.
163 *
164 * @param buffer buffer to check
165 * @param index index of the buffer
166 * @return 1 if bit is set, 0 otherwise.
167 */
168 public static int get(final ArrowBuf buffer, int index) {
169 final int byteIndex = index >> 3;
170 final byte b = buffer.getByte(byteIndex);
171 final int bitIndex = index & 7;
172 return (b >> bitIndex) & 0x01;
173 }
174
175 /**
176 * Compute the size of validity buffer required to manage a given number
177 * of elements in a vector.
178 *
179 * @param valueCount number of elements in the vector
180 * @return buffer size
181 */
182 public static int getValidityBufferSize(int valueCount) {
183 return DataSizeRoundingUtil.divideBy8Ceil(valueCount);
184 }
185
186 /**
187 * Given a validity buffer, find the number of bits that are not set.
188 * This is used to compute the number of null elements in a nullable vector.
189 *
190 * @param validityBuffer validity buffer of the vector
191 * @param valueCount number of values in the vector
192 * @return number of bits not set.
193 */
194 public static int getNullCount(final ArrowBuf validityBuffer, final int valueCount) {
195 if (valueCount == 0) {
196 return 0;
197 }
198 int count = 0;
199 final int sizeInBytes = getValidityBufferSize(valueCount);
200 // If value count is not a multiple of 8, then calculate number of used bits in the last byte
201 final int remainder = valueCount % 8;
202 final int fullBytesCount = remainder == 0 ? sizeInBytes : sizeInBytes - 1;
203
204 int index = 0;
205 while (index + 8 <= fullBytesCount) {
206 long longValue = validityBuffer.getLong(index);
207 count += Long.bitCount(longValue);
208 index += 8;
209 }
210
211 if (index + 4 <= fullBytesCount) {
212 int intValue = validityBuffer.getInt(index);
213 count += Integer.bitCount(intValue);
214 index += 4;
215 }
216
217 while (index < fullBytesCount) {
218 byte byteValue = validityBuffer.getByte(index);
219 count += Integer.bitCount(byteValue & 0xFF);
220 index += 1;
221 }
222
223 // handling with the last bits
224 if (remainder != 0) {
225 byte byteValue = validityBuffer.getByte(sizeInBytes - 1);
226
227 // making the remaining bits all 1s if it is not fully filled
228 byte mask = (byte) (0xFF << remainder);
229 byteValue = (byte) (byteValue | mask);
230 count += Integer.bitCount(byteValue & 0xFF);
231 }
232
233 return 8 * sizeInBytes - count;
234 }
235
236 /**
237 * Tests if all bits in a validity buffer are equal 0 or 1, according to the specified parameter.
238 * @param validityBuffer the validity buffer.
239 * @param valueCount the bit count.
240 * @param checkOneBits if set to true, the method checks if all bits are equal to 1;
241 * otherwise, it checks if all bits are equal to 0.
242 * @return true if all bits are 0 or 1 according to the parameter, and false otherwise.
243 */
244 public static boolean checkAllBitsEqualTo(
245 final ArrowBuf validityBuffer, final int valueCount, final boolean checkOneBits) {
246 if (valueCount == 0) {
247 return true;
248 }
249 final int sizeInBytes = getValidityBufferSize(valueCount);
250
251 // boundary check
252 validityBuffer.checkBytes(0, sizeInBytes);
253
254 // If value count is not a multiple of 8, then calculate number of used bits in the last byte
255 final int remainder = valueCount % 8;
256 final int fullBytesCount = remainder == 0 ? sizeInBytes : sizeInBytes - 1;
257
258 // the integer number to compare against
259 final int intToCompare = checkOneBits ? -1 : 0;
260
261 int index = 0;
262 while (index + 8 <= fullBytesCount) {
263 long longValue = getLong(validityBuffer.memoryAddress() + index);
264 if (longValue != (long) intToCompare) {
265 return false;
266 }
267 index += 8;
268 }
269
270 if (index + 4 <= fullBytesCount) {
271 int intValue = getInt(validityBuffer.memoryAddress() + index);
272 if (intValue != intToCompare) {
273 return false;
274 }
275 index += 4;
276 }
277
278 while (index < fullBytesCount) {
279 byte byteValue = getByte(validityBuffer.memoryAddress() + index);
280 if (byteValue != (byte) intToCompare) {
281 return false;
282 }
283 index += 1;
284 }
285
286 // handling with the last bits
287 if (remainder != 0) {
288 byte byteValue = getByte(validityBuffer.memoryAddress() + sizeInBytes - 1);
289 byte mask = (byte) ((1 << remainder) - 1);
290 byteValue = (byte) (byteValue & mask);
291 if (checkOneBits) {
292 if ((mask & byteValue) != mask) {
293 return false;
294 }
295 } else {
296 if (byteValue != (byte) 0) {
297 return false;
298 }
299 }
300 }
301 return true;
302 }
303
304 /** Returns the byte at index from data right-shifted by offset. */
305 public static byte getBitsFromCurrentByte(final ArrowBuf data, final int index, final int offset) {
306 return (byte) ((data.getByte(index) & 0xFF) >>> offset);
307 }
308
309 /**
310 * Returns the byte at <code>index</code> from left-shifted by (8 - <code>offset</code>).
311 */
312 public static byte getBitsFromNextByte(ArrowBuf data, int index, int offset) {
313 return (byte) ((data.getByte(index) << (8 - offset)));
314 }
315
316 /**
317 * Returns a new buffer if the source validity buffer is either all null or all
318 * not-null, otherwise returns a buffer pointing to the same memory as source.
319 *
320 * @param fieldNode The fieldNode containing the null count
321 * @param sourceValidityBuffer The source validity buffer that will have its
322 * position copied if there is a mix of null and non-null values
323 * @param allocator The allocator to use for creating a new buffer if necessary.
324 * @return A new buffer that is either allocated or points to the same memory as sourceValidityBuffer.
325 */
326 public static ArrowBuf loadValidityBuffer(final ArrowFieldNode fieldNode,
327 final ArrowBuf sourceValidityBuffer,
328 final BufferAllocator allocator) {
329 final int valueCount = fieldNode.getLength();
330 ArrowBuf newBuffer = null;
331 /* either all NULLs or all non-NULLs */
332 if (fieldNode.getNullCount() == 0 || fieldNode.getNullCount() == valueCount) {
333 newBuffer = allocator.buffer(getValidityBufferSize(valueCount));
334 newBuffer.setZero(0, newBuffer.capacity());
335 if (fieldNode.getNullCount() != 0) {
336 /* all NULLs */
337 return newBuffer;
338 }
339 /* all non-NULLs */
340 int fullBytesCount = valueCount / 8;
341 newBuffer.setOne(0, fullBytesCount);
342 int remainder = valueCount % 8;
343 if (remainder > 0) {
344 byte bitMask = (byte) (0xFFL >>> ((8 - remainder) & 7));
345 newBuffer.setByte(fullBytesCount, bitMask);
346 }
347 } else {
348 /* mixed byte pattern -- create another ArrowBuf associated with the
349 * target allocator
350 */
351 newBuffer = sourceValidityBuffer.getReferenceManager().retain(sourceValidityBuffer, allocator);
352 }
353
354 return newBuffer;
355 }
356
357 /**
358 * Set the byte of the given index in the data buffer by applying a bit mask to
359 * the current byte at that index.
360 *
361 * @param data buffer to set
362 * @param byteIndex byteIndex within the buffer
363 * @param bitMask bit mask to be set
364 */
365 static void setBitMaskedByte(ArrowBuf data, int byteIndex, byte bitMask) {
366 byte currentByte = data.getByte(byteIndex);
367 currentByte |= bitMask;
368 data.setByte(byteIndex, currentByte);
369 }
370
371 /**
372 * Concat two validity buffers.
373 * @param input1 the first validity buffer.
374 * @param numBits1 the number of bits in the first validity buffer.
375 * @param input2 the second validity buffer.
376 * @param numBits2 the number of bits in the second validity buffer.
377 * @param output the output validity buffer. It can be the same one as the first input.
378 * The caller must make sure the output buffer has enough capacity.
379 */
380 public static void concatBits(ArrowBuf input1, int numBits1, ArrowBuf input2, int numBits2, ArrowBuf output) {
381 int numBytes1 = DataSizeRoundingUtil.divideBy8Ceil(numBits1);
382 int numBytes2 = DataSizeRoundingUtil.divideBy8Ceil(numBits2);
383 int numBytesOut = DataSizeRoundingUtil.divideBy8Ceil(numBits1 + numBits2);
384
385 if (BoundsChecking.BOUNDS_CHECKING_ENABLED) {
386 output.checkBytes(0, numBytesOut);
387 }
388
389 // copy the first bit set
390 if (input1 != output) {
391 PlatformDependent.copyMemory(input1.memoryAddress(), output.memoryAddress(), numBytes1);
392 }
393
394 if (bitIndex(numBits1) == 0) {
395 // The number of bits for the first bit set is a multiple of 8, so the boundary is at byte boundary.
396 // For this case, we have a shortcut to copy all bytes from the second set after the byte boundary.
397 PlatformDependent.copyMemory(input2.memoryAddress(), output.memoryAddress() + numBytes1, numBytes2);
398 return;
399 }
400
401 // the number of bits to fill a full byte after the first input is processed
402 int numBitsToFill = 8 - bitIndex(numBits1);
403
404 // mask to clear high bits
405 int mask = (1 << (8 - numBitsToFill)) - 1;
406
407 int numFullBytes = numBits2 / 8;
408
409 int prevByte = output.getByte(numBytes1 - 1) & mask;
410 for (int i = 0; i < numFullBytes; i++) {
411 int curByte = input2.getByte(i) & 0xff;
412
413 // first fill the bits to a full byte
414 int byteToFill = (curByte << (8 - numBitsToFill)) & 0xff;
415 output.setByte(numBytes1 + i - 1, byteToFill | prevByte);
416
417 // fill remaining bits in the current byte
418 // note that it is also the previous byte for the next iteration
419 prevByte = curByte >>> numBitsToFill;
420 }
421
422 int lastOutputByte = prevByte;
423
424 // the number of extra bits for the second input, relative to full bytes
425 int numTrailingBits = bitIndex(numBits2);
426
427 if (numTrailingBits == 0) {
428 output.setByte(numBytes1 + numFullBytes - 1, lastOutputByte);
429 return;
430 }
431
432 // process remaining bits from input2
433 int remByte = input2.getByte(numBytes2 - 1) & 0xff;
434
435 int byteToFill = remByte << (8 - numBitsToFill);
436 lastOutputByte |= byteToFill;
437
438 output.setByte(numBytes1 + numFullBytes - 1, lastOutputByte);
439
440 if (numTrailingBits > numBitsToFill) {
441 // clear all bits for the last byte before writing
442 output.setByte(numBytes1 + numFullBytes, 0);
443
444 // some remaining bits cannot be filled in the previous byte
445 int leftByte = remByte >>> numBitsToFill;
446 output.setByte(numBytes1 + numFullBytes, leftByte);
447 }
448 }
449 }