]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one or more | |
3 | * contributor license agreements. See the NOTICE file distributed with | |
4 | * this work for additional information regarding copyright ownership. | |
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 | |
6 | * (the "License"); you may not use this file except in compliance with | |
7 | * the License. You may obtain a copy of the License at | |
8 | * | |
9 | * http://www.apache.org/licenses/LICENSE-2.0 | |
10 | * | |
11 | * Unless required by applicable law or agreed to in writing, software | |
12 | * distributed under the License is distributed on an "AS IS" BASIS, | |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
14 | * See the License for the specific language governing permissions and | |
15 | * limitations under the License. | |
16 | */ | |
17 | ||
18 | package org.apache.arrow.algorithm.sort; | |
19 | ||
20 | import static org.apache.arrow.vector.complex.BaseRepeatedValueVector.OFFSET_WIDTH; | |
21 | ||
22 | import org.apache.arrow.memory.util.ArrowBufPointer; | |
23 | import org.apache.arrow.memory.util.ByteFunctionHelpers; | |
24 | import org.apache.arrow.vector.BaseFixedWidthVector; | |
25 | import org.apache.arrow.vector.BaseVariableWidthVector; | |
26 | import org.apache.arrow.vector.BigIntVector; | |
27 | import org.apache.arrow.vector.Float4Vector; | |
28 | import org.apache.arrow.vector.Float8Vector; | |
29 | import org.apache.arrow.vector.IntVector; | |
30 | import org.apache.arrow.vector.SmallIntVector; | |
31 | import org.apache.arrow.vector.TinyIntVector; | |
32 | import org.apache.arrow.vector.UInt1Vector; | |
33 | import org.apache.arrow.vector.UInt2Vector; | |
34 | import org.apache.arrow.vector.UInt4Vector; | |
35 | import org.apache.arrow.vector.UInt8Vector; | |
36 | import org.apache.arrow.vector.ValueVector; | |
37 | import org.apache.arrow.vector.complex.BaseRepeatedValueVector; | |
38 | ||
39 | /** | |
40 | * Default comparator implementations for different types of vectors. | |
41 | */ | |
42 | public class DefaultVectorComparators { | |
43 | ||
44 | /** | |
45 | * Create the default comparator for the vector. | |
46 | * @param vector the vector. | |
47 | * @param <T> the vector type. | |
48 | * @return the default comparator. | |
49 | */ | |
50 | public static <T extends ValueVector> VectorValueComparator<T> createDefaultComparator(T vector) { | |
51 | if (vector instanceof BaseFixedWidthVector) { | |
52 | if (vector instanceof TinyIntVector) { | |
53 | return (VectorValueComparator<T>) new ByteComparator(); | |
54 | } else if (vector instanceof SmallIntVector) { | |
55 | return (VectorValueComparator<T>) new ShortComparator(); | |
56 | } else if (vector instanceof IntVector) { | |
57 | return (VectorValueComparator<T>) new IntComparator(); | |
58 | } else if (vector instanceof BigIntVector) { | |
59 | return (VectorValueComparator<T>) new LongComparator(); | |
60 | } else if (vector instanceof Float4Vector) { | |
61 | return (VectorValueComparator<T>) new Float4Comparator(); | |
62 | } else if (vector instanceof Float8Vector) { | |
63 | return (VectorValueComparator<T>) new Float8Comparator(); | |
64 | } else if (vector instanceof UInt1Vector) { | |
65 | return (VectorValueComparator<T>) new UInt1Comparator(); | |
66 | } else if (vector instanceof UInt2Vector) { | |
67 | return (VectorValueComparator<T>) new UInt2Comparator(); | |
68 | } else if (vector instanceof UInt4Vector) { | |
69 | return (VectorValueComparator<T>) new UInt4Comparator(); | |
70 | } else if (vector instanceof UInt8Vector) { | |
71 | return (VectorValueComparator<T>) new UInt8Comparator(); | |
72 | } | |
73 | } else if (vector instanceof BaseVariableWidthVector) { | |
74 | return (VectorValueComparator<T>) new VariableWidthComparator(); | |
75 | } else if (vector instanceof BaseRepeatedValueVector) { | |
76 | VectorValueComparator<?> innerComparator = | |
77 | createDefaultComparator(((BaseRepeatedValueVector) vector).getDataVector()); | |
78 | return new RepeatedValueComparator(innerComparator); | |
79 | } | |
80 | ||
81 | throw new IllegalArgumentException("No default comparator for " + vector.getClass().getCanonicalName()); | |
82 | } | |
83 | ||
84 | /** | |
85 | * Default comparator for bytes. | |
86 | * The comparison is based on values, with null comes first. | |
87 | */ | |
88 | public static class ByteComparator extends VectorValueComparator<TinyIntVector> { | |
89 | ||
90 | public ByteComparator() { | |
91 | super(Byte.SIZE / 8); | |
92 | } | |
93 | ||
94 | @Override | |
95 | public int compareNotNull(int index1, int index2) { | |
96 | byte value1 = vector1.get(index1); | |
97 | byte value2 = vector2.get(index2); | |
98 | return value1 - value2; | |
99 | } | |
100 | ||
101 | @Override | |
102 | public VectorValueComparator<TinyIntVector> createNew() { | |
103 | return new ByteComparator(); | |
104 | } | |
105 | } | |
106 | ||
107 | /** | |
108 | * Default comparator for short integers. | |
109 | * The comparison is based on values, with null comes first. | |
110 | */ | |
111 | public static class ShortComparator extends VectorValueComparator<SmallIntVector> { | |
112 | ||
113 | public ShortComparator() { | |
114 | super(Short.SIZE / 8); | |
115 | } | |
116 | ||
117 | @Override | |
118 | public int compareNotNull(int index1, int index2) { | |
119 | short value1 = vector1.get(index1); | |
120 | short value2 = vector2.get(index2); | |
121 | return value1 - value2; | |
122 | } | |
123 | ||
124 | @Override | |
125 | public VectorValueComparator<SmallIntVector> createNew() { | |
126 | return new ShortComparator(); | |
127 | } | |
128 | } | |
129 | ||
130 | /** | |
131 | * Default comparator for 32-bit integers. | |
132 | * The comparison is based on int values, with null comes first. | |
133 | */ | |
134 | public static class IntComparator extends VectorValueComparator<IntVector> { | |
135 | ||
136 | public IntComparator() { | |
137 | super(Integer.SIZE / 8); | |
138 | } | |
139 | ||
140 | @Override | |
141 | public int compareNotNull(int index1, int index2) { | |
142 | int value1 = vector1.get(index1); | |
143 | int value2 = vector2.get(index2); | |
144 | return Integer.compare(value1, value2); | |
145 | } | |
146 | ||
147 | @Override | |
148 | public VectorValueComparator<IntVector> createNew() { | |
149 | return new IntComparator(); | |
150 | } | |
151 | } | |
152 | ||
153 | /** | |
154 | * Default comparator for long integers. | |
155 | * The comparison is based on values, with null comes first. | |
156 | */ | |
157 | public static class LongComparator extends VectorValueComparator<BigIntVector> { | |
158 | ||
159 | public LongComparator() { | |
160 | super(Long.SIZE / 8); | |
161 | } | |
162 | ||
163 | @Override | |
164 | public int compareNotNull(int index1, int index2) { | |
165 | long value1 = vector1.get(index1); | |
166 | long value2 = vector2.get(index2); | |
167 | ||
168 | return Long.compare(value1, value2); | |
169 | } | |
170 | ||
171 | @Override | |
172 | public VectorValueComparator<BigIntVector> createNew() { | |
173 | return new LongComparator(); | |
174 | } | |
175 | } | |
176 | ||
177 | /** | |
178 | * Default comparator for unsigned bytes. | |
179 | * The comparison is based on values, with null comes first. | |
180 | */ | |
181 | public static class UInt1Comparator extends VectorValueComparator<UInt1Vector> { | |
182 | ||
183 | public UInt1Comparator() { | |
184 | super(1); | |
185 | } | |
186 | ||
187 | @Override | |
188 | public int compareNotNull(int index1, int index2) { | |
189 | byte value1 = vector1.get(index1); | |
190 | byte value2 = vector2.get(index2); | |
191 | ||
192 | return (value1 & 0xff) - (value2 & 0xff); | |
193 | } | |
194 | ||
195 | @Override | |
196 | public VectorValueComparator<UInt1Vector> createNew() { | |
197 | return new UInt1Comparator(); | |
198 | } | |
199 | } | |
200 | ||
201 | /** | |
202 | * Default comparator for unsigned short integer. | |
203 | * The comparison is based on values, with null comes first. | |
204 | */ | |
205 | public static class UInt2Comparator extends VectorValueComparator<UInt2Vector> { | |
206 | ||
207 | public UInt2Comparator() { | |
208 | super(2); | |
209 | } | |
210 | ||
211 | @Override | |
212 | public int compareNotNull(int index1, int index2) { | |
213 | char value1 = vector1.get(index1); | |
214 | char value2 = vector2.get(index2); | |
215 | ||
216 | // please note that we should not use the built-in | |
217 | // Character#compare method here, as that method | |
218 | // essentially compares char values as signed integers. | |
219 | return (value1 & 0xffff) - (value2 & 0xffff); | |
220 | } | |
221 | ||
222 | @Override | |
223 | public VectorValueComparator<UInt2Vector> createNew() { | |
224 | return new UInt2Comparator(); | |
225 | } | |
226 | } | |
227 | ||
228 | /** | |
229 | * Default comparator for unsigned integer. | |
230 | * The comparison is based on values, with null comes first. | |
231 | */ | |
232 | public static class UInt4Comparator extends VectorValueComparator<UInt4Vector> { | |
233 | ||
234 | public UInt4Comparator() { | |
235 | super(4); | |
236 | } | |
237 | ||
238 | @Override | |
239 | public int compareNotNull(int index1, int index2) { | |
240 | int value1 = vector1.get(index1); | |
241 | int value2 = vector2.get(index2); | |
242 | return ByteFunctionHelpers.unsignedIntCompare(value1, value2); | |
243 | } | |
244 | ||
245 | @Override | |
246 | public VectorValueComparator<UInt4Vector> createNew() { | |
247 | return new UInt4Comparator(); | |
248 | } | |
249 | } | |
250 | ||
251 | /** | |
252 | * Default comparator for unsigned long integer. | |
253 | * The comparison is based on values, with null comes first. | |
254 | */ | |
255 | public static class UInt8Comparator extends VectorValueComparator<UInt8Vector> { | |
256 | ||
257 | public UInt8Comparator() { | |
258 | super(8); | |
259 | } | |
260 | ||
261 | @Override | |
262 | public int compareNotNull(int index1, int index2) { | |
263 | long value1 = vector1.get(index1); | |
264 | long value2 = vector2.get(index2); | |
265 | return ByteFunctionHelpers.unsignedLongCompare(value1, value2); | |
266 | } | |
267 | ||
268 | @Override | |
269 | public VectorValueComparator<UInt8Vector> createNew() { | |
270 | return new UInt8Comparator(); | |
271 | } | |
272 | } | |
273 | ||
274 | /** | |
275 | * Default comparator for float type. | |
276 | * The comparison is based on values, with null comes first. | |
277 | */ | |
278 | public static class Float4Comparator extends VectorValueComparator<Float4Vector> { | |
279 | ||
280 | public Float4Comparator() { | |
281 | super(Float.SIZE / 8); | |
282 | } | |
283 | ||
284 | @Override | |
285 | public int compareNotNull(int index1, int index2) { | |
286 | float value1 = vector1.get(index1); | |
287 | float value2 = vector2.get(index2); | |
288 | ||
289 | boolean isNan1 = Float.isNaN(value1); | |
290 | boolean isNan2 = Float.isNaN(value2); | |
291 | if (isNan1 || isNan2) { | |
292 | if (isNan1 && isNan2) { | |
293 | return 0; | |
294 | } else if (isNan1) { | |
295 | // nan is greater than any normal value | |
296 | return 1; | |
297 | } else { | |
298 | return -1; | |
299 | } | |
300 | } | |
301 | ||
302 | return (int) Math.signum(value1 - value2); | |
303 | } | |
304 | ||
305 | @Override | |
306 | public VectorValueComparator<Float4Vector> createNew() { | |
307 | return new Float4Comparator(); | |
308 | } | |
309 | } | |
310 | ||
311 | /** | |
312 | * Default comparator for double type. | |
313 | * The comparison is based on values, with null comes first. | |
314 | */ | |
315 | public static class Float8Comparator extends VectorValueComparator<Float8Vector> { | |
316 | ||
317 | public Float8Comparator() { | |
318 | super(Double.SIZE / 8); | |
319 | } | |
320 | ||
321 | @Override | |
322 | public int compareNotNull(int index1, int index2) { | |
323 | double value1 = vector1.get(index1); | |
324 | double value2 = vector2.get(index2); | |
325 | ||
326 | boolean isNan1 = Double.isNaN(value1); | |
327 | boolean isNan2 = Double.isNaN(value2); | |
328 | if (isNan1 || isNan2) { | |
329 | if (isNan1 && isNan2) { | |
330 | return 0; | |
331 | } else if (isNan1) { | |
332 | // nan is greater than any normal value | |
333 | return 1; | |
334 | } else { | |
335 | return -1; | |
336 | } | |
337 | } | |
338 | ||
339 | return (int) Math.signum(value1 - value2); | |
340 | } | |
341 | ||
342 | @Override | |
343 | public VectorValueComparator<Float8Vector> createNew() { | |
344 | return new Float8Comparator(); | |
345 | } | |
346 | } | |
347 | ||
348 | /** | |
349 | * Default comparator for {@link org.apache.arrow.vector.BaseVariableWidthVector}. | |
350 | * The comparison is in lexicographic order, with null comes first. | |
351 | */ | |
352 | public static class VariableWidthComparator extends VectorValueComparator<BaseVariableWidthVector> { | |
353 | ||
354 | private ArrowBufPointer reusablePointer1 = new ArrowBufPointer(); | |
355 | ||
356 | private ArrowBufPointer reusablePointer2 = new ArrowBufPointer(); | |
357 | ||
358 | @Override | |
359 | public int compare(int index1, int index2) { | |
360 | vector1.getDataPointer(index1, reusablePointer1); | |
361 | vector2.getDataPointer(index2, reusablePointer2); | |
362 | return reusablePointer1.compareTo(reusablePointer2); | |
363 | } | |
364 | ||
365 | @Override | |
366 | public int compareNotNull(int index1, int index2) { | |
367 | vector1.getDataPointer(index1, reusablePointer1); | |
368 | vector2.getDataPointer(index2, reusablePointer2); | |
369 | return reusablePointer1.compareTo(reusablePointer2); | |
370 | } | |
371 | ||
372 | @Override | |
373 | public VectorValueComparator<BaseVariableWidthVector> createNew() { | |
374 | return new VariableWidthComparator(); | |
375 | } | |
376 | } | |
377 | ||
378 | /** | |
379 | * Default comparator for {@link BaseRepeatedValueVector}. | |
380 | * It works by comparing the underlying vector in a lexicographic order. | |
381 | * @param <T> inner vector type. | |
382 | */ | |
383 | public static class RepeatedValueComparator<T extends ValueVector> | |
384 | extends VectorValueComparator<BaseRepeatedValueVector> { | |
385 | ||
386 | private VectorValueComparator<T> innerComparator; | |
387 | ||
388 | public RepeatedValueComparator(VectorValueComparator<T> innerComparator) { | |
389 | this.innerComparator = innerComparator; | |
390 | } | |
391 | ||
392 | @Override | |
393 | public int compareNotNull(int index1, int index2) { | |
394 | int startIdx1 = vector1.getOffsetBuffer().getInt(index1 * OFFSET_WIDTH); | |
395 | int startIdx2 = vector2.getOffsetBuffer().getInt(index2 * OFFSET_WIDTH); | |
396 | ||
397 | int endIdx1 = vector1.getOffsetBuffer().getInt((index1 + 1) * OFFSET_WIDTH); | |
398 | int endIdx2 = vector2.getOffsetBuffer().getInt((index2 + 1) * OFFSET_WIDTH); | |
399 | ||
400 | int length1 = endIdx1 - startIdx1; | |
401 | int length2 = endIdx2 - startIdx2; | |
402 | ||
403 | int length = length1 < length2 ? length1 : length2; | |
404 | ||
405 | for (int i = 0; i < length; i++) { | |
406 | int result = innerComparator.compare(startIdx1 + i, startIdx2 + i); | |
407 | if (result != 0) { | |
408 | return result; | |
409 | } | |
410 | } | |
411 | return length1 - length2; | |
412 | } | |
413 | ||
414 | @Override | |
415 | public VectorValueComparator<BaseRepeatedValueVector> createNew() { | |
416 | VectorValueComparator<T> newInnerComparator = innerComparator.createNew(); | |
417 | return new RepeatedValueComparator(newInnerComparator); | |
418 | } | |
419 | ||
420 | @Override | |
421 | public void attachVectors(BaseRepeatedValueVector vector1, BaseRepeatedValueVector vector2) { | |
422 | this.vector1 = vector1; | |
423 | this.vector2 = vector2; | |
424 | ||
425 | innerComparator.attachVectors((T) vector1.getDataVector(), (T) vector2.getDataVector()); | |
426 | } | |
427 | } | |
428 | ||
429 | private DefaultVectorComparators() { | |
430 | } | |
431 | } |