]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / kernels / scalar_arithmetic.cc
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include <algorithm>
19 #include <cmath>
20 #include <limits>
21 #include <utility>
22 #include <vector>
23
24 #include "arrow/compare.h"
25 #include "arrow/compute/api_scalar.h"
26 #include "arrow/compute/kernels/common.h"
27 #include "arrow/compute/kernels/util_internal.h"
28 #include "arrow/type.h"
29 #include "arrow/type_traits.h"
30 #include "arrow/util/decimal.h"
31 #include "arrow/util/int_util_internal.h"
32 #include "arrow/util/macros.h"
33
34 namespace arrow {
35
36 using internal::AddWithOverflow;
37 using internal::DivideWithOverflow;
38 using internal::MultiplyWithOverflow;
39 using internal::NegateWithOverflow;
40 using internal::SubtractWithOverflow;
41
42 namespace compute {
43 namespace internal {
44
45 using applicator::ScalarBinaryEqualTypes;
46 using applicator::ScalarBinaryNotNullEqualTypes;
47 using applicator::ScalarUnary;
48 using applicator::ScalarUnaryNotNull;
49 using applicator::ScalarUnaryNotNullStateful;
50
51 namespace {
52
53 // N.B. take care not to conflict with type_traits.h as that can cause surprises in a
54 // unity build
55
56 template <typename T>
57 using is_unsigned_integer = std::integral_constant<bool, std::is_integral<T>::value &&
58 std::is_unsigned<T>::value>;
59
60 template <typename T>
61 using is_signed_integer =
62 std::integral_constant<bool, std::is_integral<T>::value && std::is_signed<T>::value>;
63
64 template <typename T, typename R = T>
65 using enable_if_signed_c_integer = enable_if_t<is_signed_integer<T>::value, R>;
66
67 template <typename T, typename R = T>
68 using enable_if_unsigned_c_integer = enable_if_t<is_unsigned_integer<T>::value, R>;
69
70 template <typename T, typename R = T>
71 using enable_if_c_integer =
72 enable_if_t<is_signed_integer<T>::value || is_unsigned_integer<T>::value, R>;
73
74 template <typename T, typename R = T>
75 using enable_if_floating_point = enable_if_t<std::is_floating_point<T>::value, R>;
76
77 template <typename T, typename R = T>
78 using enable_if_decimal_value =
79 enable_if_t<std::is_same<Decimal128, T>::value || std::is_same<Decimal256, T>::value,
80 R>;
81
82 struct AbsoluteValue {
83 template <typename T, typename Arg>
84 static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
85 Status*) {
86 return std::fabs(arg);
87 }
88
89 template <typename T, typename Arg>
90 static constexpr enable_if_unsigned_c_integer<Arg, T> Call(KernelContext*, Arg arg,
91 Status*) {
92 return arg;
93 }
94
95 template <typename T, typename Arg>
96 static constexpr enable_if_signed_c_integer<Arg, T> Call(KernelContext*, Arg arg,
97 Status* st) {
98 return (arg < 0) ? arrow::internal::SafeSignedNegate(arg) : arg;
99 }
100 };
101
102 struct AbsoluteValueChecked {
103 template <typename T, typename Arg>
104 static enable_if_signed_c_integer<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
105 static_assert(std::is_same<T, Arg>::value, "");
106 if (arg == std::numeric_limits<Arg>::min()) {
107 *st = Status::Invalid("overflow");
108 return arg;
109 }
110 return std::abs(arg);
111 }
112
113 template <typename T, typename Arg>
114 static enable_if_unsigned_c_integer<Arg, T> Call(KernelContext* ctx, Arg arg,
115 Status* st) {
116 static_assert(std::is_same<T, Arg>::value, "");
117 return arg;
118 }
119
120 template <typename T, typename Arg>
121 static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
122 Status* st) {
123 static_assert(std::is_same<T, Arg>::value, "");
124 return std::fabs(arg);
125 }
126 };
127
128 struct Add {
129 template <typename T, typename Arg0, typename Arg1>
130 static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
131 Status*) {
132 return left + right;
133 }
134
135 template <typename T, typename Arg0, typename Arg1>
136 static constexpr enable_if_unsigned_c_integer<T> Call(KernelContext*, Arg0 left,
137 Arg1 right, Status*) {
138 return left + right;
139 }
140
141 template <typename T, typename Arg0, typename Arg1>
142 static constexpr enable_if_signed_c_integer<T> Call(KernelContext*, Arg0 left,
143 Arg1 right, Status*) {
144 return arrow::internal::SafeSignedAdd(left, right);
145 }
146
147 template <typename T, typename Arg0, typename Arg1>
148 static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
149 return left + right;
150 }
151 };
152
153 struct AddChecked {
154 template <typename T, typename Arg0, typename Arg1>
155 static enable_if_c_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
156 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
157 T result = 0;
158 if (ARROW_PREDICT_FALSE(AddWithOverflow(left, right, &result))) {
159 *st = Status::Invalid("overflow");
160 }
161 return result;
162 }
163
164 template <typename T, typename Arg0, typename Arg1>
165 static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
166 Status*) {
167 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
168 return left + right;
169 }
170
171 template <typename T, typename Arg0, typename Arg1>
172 static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
173 return left + right;
174 }
175 };
176
177 struct Subtract {
178 template <typename T, typename Arg0, typename Arg1>
179 static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
180 Status*) {
181 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
182 return left - right;
183 }
184
185 template <typename T, typename Arg0, typename Arg1>
186 static constexpr enable_if_unsigned_c_integer<T> Call(KernelContext*, Arg0 left,
187 Arg1 right, Status*) {
188 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
189 return left - right;
190 }
191
192 template <typename T, typename Arg0, typename Arg1>
193 static constexpr enable_if_signed_c_integer<T> Call(KernelContext*, Arg0 left,
194 Arg1 right, Status*) {
195 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
196 return arrow::internal::SafeSignedSubtract(left, right);
197 }
198
199 template <typename T, typename Arg0, typename Arg1>
200 static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
201 return left + (-right);
202 }
203 };
204
205 struct SubtractChecked {
206 template <typename T, typename Arg0, typename Arg1>
207 static enable_if_c_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
208 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
209 T result = 0;
210 if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) {
211 *st = Status::Invalid("overflow");
212 }
213 return result;
214 }
215
216 template <typename T, typename Arg0, typename Arg1>
217 static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
218 Status*) {
219 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
220 return left - right;
221 }
222
223 template <typename T, typename Arg0, typename Arg1>
224 static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
225 return left + (-right);
226 }
227 };
228
229 struct Multiply {
230 static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value, "");
231 static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value, "");
232 static_assert(std::is_same<decltype(int16_t() * int16_t()), int32_t>::value, "");
233 static_assert(std::is_same<decltype(uint16_t() * uint16_t()), int32_t>::value, "");
234 static_assert(std::is_same<decltype(int32_t() * int32_t()), int32_t>::value, "");
235 static_assert(std::is_same<decltype(uint32_t() * uint32_t()), uint32_t>::value, "");
236 static_assert(std::is_same<decltype(int64_t() * int64_t()), int64_t>::value, "");
237 static_assert(std::is_same<decltype(uint64_t() * uint64_t()), uint64_t>::value, "");
238
239 template <typename T, typename Arg0, typename Arg1>
240 static constexpr enable_if_floating_point<T> Call(KernelContext*, T left, T right,
241 Status*) {
242 return left * right;
243 }
244
245 template <typename T, typename Arg0, typename Arg1>
246 static constexpr enable_if_t<
247 is_unsigned_integer<T>::value && !std::is_same<T, uint16_t>::value, T>
248 Call(KernelContext*, T left, T right, Status*) {
249 return left * right;
250 }
251
252 template <typename T, typename Arg0, typename Arg1>
253 static constexpr enable_if_t<
254 is_signed_integer<T>::value && !std::is_same<T, int16_t>::value, T>
255 Call(KernelContext*, T left, T right, Status*) {
256 return to_unsigned(left) * to_unsigned(right);
257 }
258
259 // Multiplication of 16 bit integer types implicitly promotes to signed 32 bit
260 // integer. However, some inputs may nevertheless overflow (which triggers undefined
261 // behaviour). Therefore we first cast to 32 bit unsigned integers where overflow is
262 // well defined.
263 template <typename T, typename Arg0, typename Arg1>
264 static constexpr enable_if_same<T, int16_t, T> Call(KernelContext*, int16_t left,
265 int16_t right, Status*) {
266 return static_cast<uint32_t>(left) * static_cast<uint32_t>(right);
267 }
268 template <typename T, typename Arg0, typename Arg1>
269 static constexpr enable_if_same<T, uint16_t, T> Call(KernelContext*, uint16_t left,
270 uint16_t right, Status*) {
271 return static_cast<uint32_t>(left) * static_cast<uint32_t>(right);
272 }
273
274 template <typename T, typename Arg0, typename Arg1>
275 static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
276 return left * right;
277 }
278 };
279
280 struct MultiplyChecked {
281 template <typename T, typename Arg0, typename Arg1>
282 static enable_if_c_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
283 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
284 T result = 0;
285 if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(left, right, &result))) {
286 *st = Status::Invalid("overflow");
287 }
288 return result;
289 }
290
291 template <typename T, typename Arg0, typename Arg1>
292 static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
293 Status*) {
294 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
295 return left * right;
296 }
297
298 template <typename T, typename Arg0, typename Arg1>
299 static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
300 return left * right;
301 }
302 };
303
304 struct Divide {
305 template <typename T, typename Arg0, typename Arg1>
306 static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
307 Status*) {
308 return left / right;
309 }
310
311 template <typename T, typename Arg0, typename Arg1>
312 static enable_if_c_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
313 T result;
314 if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) {
315 if (right == 0) {
316 *st = Status::Invalid("divide by zero");
317 } else {
318 result = 0;
319 }
320 }
321 return result;
322 }
323
324 template <typename T, typename Arg0, typename Arg1>
325 static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
326 Status* st) {
327 if (right == Arg1()) {
328 *st = Status::Invalid("Divide by zero");
329 return T();
330 } else {
331 return left / right;
332 }
333 }
334 };
335
336 struct DivideChecked {
337 template <typename T, typename Arg0, typename Arg1>
338 static enable_if_c_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
339 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
340 T result;
341 if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) {
342 if (right == 0) {
343 *st = Status::Invalid("divide by zero");
344 } else {
345 *st = Status::Invalid("overflow");
346 }
347 }
348 return result;
349 }
350
351 template <typename T, typename Arg0, typename Arg1>
352 static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
353 Status* st) {
354 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
355 if (ARROW_PREDICT_FALSE(right == 0)) {
356 *st = Status::Invalid("divide by zero");
357 return 0;
358 }
359 return left / right;
360 }
361
362 template <typename T, typename Arg0, typename Arg1>
363 static enable_if_decimal_value<T> Call(KernelContext* ctx, Arg0 left, Arg1 right,
364 Status* st) {
365 return Divide::Call<T>(ctx, left, right, st);
366 }
367 };
368
369 struct Negate {
370 template <typename T, typename Arg>
371 static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg arg, Status*) {
372 return -arg;
373 }
374
375 template <typename T, typename Arg>
376 static constexpr enable_if_unsigned_c_integer<T> Call(KernelContext*, Arg arg,
377 Status*) {
378 return ~arg + 1;
379 }
380
381 template <typename T, typename Arg>
382 static constexpr enable_if_signed_c_integer<T> Call(KernelContext*, Arg arg, Status*) {
383 return arrow::internal::SafeSignedNegate(arg);
384 }
385 };
386
387 struct NegateChecked {
388 template <typename T, typename Arg>
389 static enable_if_signed_c_integer<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
390 static_assert(std::is_same<T, Arg>::value, "");
391 T result = 0;
392 if (ARROW_PREDICT_FALSE(NegateWithOverflow(arg, &result))) {
393 *st = Status::Invalid("overflow");
394 }
395 return result;
396 }
397
398 template <typename T, typename Arg>
399 static enable_if_unsigned_c_integer<Arg, T> Call(KernelContext* ctx, Arg arg,
400 Status* st) {
401 static_assert(std::is_same<T, Arg>::value, "");
402 DCHECK(false) << "This is included only for the purposes of instantiability from the "
403 "arithmetic kernel generator";
404 return 0;
405 }
406
407 template <typename T, typename Arg>
408 static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
409 Status* st) {
410 static_assert(std::is_same<T, Arg>::value, "");
411 return -arg;
412 }
413 };
414
415 struct Power {
416 ARROW_NOINLINE
417 static uint64_t IntegerPower(uint64_t base, uint64_t exp) {
418 // right to left O(logn) power
419 uint64_t pow = 1;
420 while (exp) {
421 pow *= (exp & 1) ? base : 1;
422 base *= base;
423 exp >>= 1;
424 }
425 return pow;
426 }
427
428 template <typename T, typename Arg0, typename Arg1>
429 static enable_if_c_integer<T> Call(KernelContext*, T base, T exp, Status* st) {
430 if (exp < 0) {
431 *st = Status::Invalid("integers to negative integer powers are not allowed");
432 return 0;
433 }
434 return static_cast<T>(IntegerPower(base, exp));
435 }
436
437 template <typename T, typename Arg0, typename Arg1>
438 static enable_if_floating_point<T> Call(KernelContext*, T base, T exp, Status*) {
439 return std::pow(base, exp);
440 }
441 };
442
443 struct PowerChecked {
444 template <typename T, typename Arg0, typename Arg1>
445 static enable_if_c_integer<T> Call(KernelContext*, Arg0 base, Arg1 exp, Status* st) {
446 if (exp < 0) {
447 *st = Status::Invalid("integers to negative integer powers are not allowed");
448 return 0;
449 } else if (exp == 0) {
450 return 1;
451 }
452 // left to right O(logn) power with overflow checks
453 bool overflow = false;
454 uint64_t bitmask =
455 1ULL << (63 - BitUtil::CountLeadingZeros(static_cast<uint64_t>(exp)));
456 T pow = 1;
457 while (bitmask) {
458 overflow |= MultiplyWithOverflow(pow, pow, &pow);
459 if (exp & bitmask) {
460 overflow |= MultiplyWithOverflow(pow, base, &pow);
461 }
462 bitmask >>= 1;
463 }
464 if (overflow) {
465 *st = Status::Invalid("overflow");
466 }
467 return pow;
468 }
469
470 template <typename T, typename Arg0, typename Arg1>
471 static enable_if_floating_point<T> Call(KernelContext*, Arg0 base, Arg1 exp, Status*) {
472 static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
473 return std::pow(base, exp);
474 }
475 };
476
477 struct Sign {
478 template <typename T, typename Arg>
479 static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
480 Status*) {
481 return std::isnan(arg) ? arg : ((arg == 0) ? 0 : (std::signbit(arg) ? -1 : 1));
482 }
483
484 template <typename T, typename Arg>
485 static constexpr enable_if_unsigned_c_integer<Arg, T> Call(KernelContext*, Arg arg,
486 Status*) {
487 return (arg > 0) ? 1 : 0;
488 }
489
490 template <typename T, typename Arg>
491 static constexpr enable_if_signed_c_integer<Arg, T> Call(KernelContext*, Arg arg,
492 Status*) {
493 return (arg > 0) ? 1 : ((arg == 0) ? 0 : -1);
494 }
495 };
496
497 // Bitwise operations
498
499 struct BitWiseNot {
500 template <typename T, typename Arg>
501 static T Call(KernelContext*, Arg arg, Status*) {
502 return ~arg;
503 }
504 };
505
506 struct BitWiseAnd {
507 template <typename T, typename Arg0, typename Arg1>
508 static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
509 return lhs & rhs;
510 }
511 };
512
513 struct BitWiseOr {
514 template <typename T, typename Arg0, typename Arg1>
515 static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
516 return lhs | rhs;
517 }
518 };
519
520 struct BitWiseXor {
521 template <typename T, typename Arg0, typename Arg1>
522 static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
523 return lhs ^ rhs;
524 }
525 };
526
527 struct ShiftLeft {
528 template <typename T, typename Arg0, typename Arg1>
529 static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
530 using Unsigned = typename std::make_unsigned<Arg0>::type;
531 static_assert(std::is_same<T, Arg0>::value, "");
532 if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
533 return lhs;
534 }
535 return static_cast<T>(static_cast<Unsigned>(lhs) << static_cast<Unsigned>(rhs));
536 }
537 };
538
539 // See SEI CERT C Coding Standard rule INT34-C
540 struct ShiftLeftChecked {
541 template <typename T, typename Arg0, typename Arg1>
542 static enable_if_unsigned_c_integer<T> Call(KernelContext*, Arg0 lhs, Arg1 rhs,
543 Status* st) {
544 static_assert(std::is_same<T, Arg0>::value, "");
545 if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
546 *st = Status::Invalid("shift amount must be >= 0 and less than precision of type");
547 return lhs;
548 }
549 return lhs << rhs;
550 }
551
552 template <typename T, typename Arg0, typename Arg1>
553 static enable_if_signed_c_integer<T> Call(KernelContext*, Arg0 lhs, Arg1 rhs,
554 Status* st) {
555 using Unsigned = typename std::make_unsigned<Arg0>::type;
556 static_assert(std::is_same<T, Arg0>::value, "");
557 if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
558 *st = Status::Invalid("shift amount must be >= 0 and less than precision of type");
559 return lhs;
560 }
561 // In C/C++ left shift of a negative number is undefined (C++11 standard 5.8.2)
562 // Mimic Java/etc. and treat left shift as based on two's complement representation
563 // Assumes two's complement machine
564 return static_cast<T>(static_cast<Unsigned>(lhs) << static_cast<Unsigned>(rhs));
565 }
566 };
567
568 struct ShiftRight {
569 template <typename T, typename Arg0, typename Arg1>
570 static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
571 static_assert(std::is_same<T, Arg0>::value, "");
572 // Logical right shift when Arg0 is unsigned
573 // Arithmetic otherwise (this is implementation-defined but GCC and MSVC document this
574 // as arithmetic right shift)
575 // https://gcc.gnu.org/onlinedocs/gcc/Integers-implementation.html#Integers-implementation
576 // https://docs.microsoft.com/en-us/cpp/cpp/left-shift-and-right-shift-operators-input-and-output?view=msvc-160
577 // Clang doesn't document their behavior.
578 if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
579 return lhs;
580 }
581 return lhs >> rhs;
582 }
583 };
584
585 struct ShiftRightChecked {
586 template <typename T, typename Arg0, typename Arg1>
587 static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status* st) {
588 static_assert(std::is_same<T, Arg0>::value, "");
589 if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
590 *st = Status::Invalid("shift amount must be >= 0 and less than precision of type");
591 return lhs;
592 }
593 return lhs >> rhs;
594 }
595 };
596
597 struct Sin {
598 template <typename T, typename Arg0>
599 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
600 static_assert(std::is_same<T, Arg0>::value, "");
601 return std::sin(val);
602 }
603 };
604
605 struct SinChecked {
606 template <typename T, typename Arg0>
607 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status* st) {
608 static_assert(std::is_same<T, Arg0>::value, "");
609 if (ARROW_PREDICT_FALSE(std::isinf(val))) {
610 *st = Status::Invalid("domain error");
611 return val;
612 }
613 return std::sin(val);
614 }
615 };
616
617 struct Cos {
618 template <typename T, typename Arg0>
619 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
620 static_assert(std::is_same<T, Arg0>::value, "");
621 return std::cos(val);
622 }
623 };
624
625 struct CosChecked {
626 template <typename T, typename Arg0>
627 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status* st) {
628 static_assert(std::is_same<T, Arg0>::value, "");
629 if (ARROW_PREDICT_FALSE(std::isinf(val))) {
630 *st = Status::Invalid("domain error");
631 return val;
632 }
633 return std::cos(val);
634 }
635 };
636
637 struct Tan {
638 template <typename T, typename Arg0>
639 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
640 static_assert(std::is_same<T, Arg0>::value, "");
641 return std::tan(val);
642 }
643 };
644
645 struct TanChecked {
646 template <typename T, typename Arg0>
647 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status* st) {
648 static_assert(std::is_same<T, Arg0>::value, "");
649 if (ARROW_PREDICT_FALSE(std::isinf(val))) {
650 *st = Status::Invalid("domain error");
651 return val;
652 }
653 // Cannot raise range errors (overflow) since PI/2 is not exactly representable
654 return std::tan(val);
655 }
656 };
657
658 struct Asin {
659 template <typename T, typename Arg0>
660 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
661 static_assert(std::is_same<T, Arg0>::value, "");
662 if (ARROW_PREDICT_FALSE(val < -1.0 || val > 1.0)) {
663 return std::numeric_limits<T>::quiet_NaN();
664 }
665 return std::asin(val);
666 }
667 };
668
669 struct AsinChecked {
670 template <typename T, typename Arg0>
671 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status* st) {
672 static_assert(std::is_same<T, Arg0>::value, "");
673 if (ARROW_PREDICT_FALSE(val < -1.0 || val > 1.0)) {
674 *st = Status::Invalid("domain error");
675 return val;
676 }
677 return std::asin(val);
678 }
679 };
680
681 struct Acos {
682 template <typename T, typename Arg0>
683 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
684 static_assert(std::is_same<T, Arg0>::value, "");
685 if (ARROW_PREDICT_FALSE((val < -1.0 || val > 1.0))) {
686 return std::numeric_limits<T>::quiet_NaN();
687 }
688 return std::acos(val);
689 }
690 };
691
692 struct AcosChecked {
693 template <typename T, typename Arg0>
694 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status* st) {
695 static_assert(std::is_same<T, Arg0>::value, "");
696 if (ARROW_PREDICT_FALSE((val < -1.0 || val > 1.0))) {
697 *st = Status::Invalid("domain error");
698 return val;
699 }
700 return std::acos(val);
701 }
702 };
703
704 struct Atan {
705 template <typename T, typename Arg0>
706 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
707 static_assert(std::is_same<T, Arg0>::value, "");
708 return std::atan(val);
709 }
710 };
711
712 struct Atan2 {
713 template <typename T, typename Arg0, typename Arg1>
714 static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 y, Arg1 x, Status*) {
715 static_assert(std::is_same<T, Arg0>::value, "");
716 static_assert(std::is_same<Arg0, Arg1>::value, "");
717 return std::atan2(y, x);
718 }
719 };
720
721 struct LogNatural {
722 template <typename T, typename Arg>
723 static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status*) {
724 static_assert(std::is_same<T, Arg>::value, "");
725 if (arg == 0.0) {
726 return -std::numeric_limits<T>::infinity();
727 } else if (arg < 0.0) {
728 return std::numeric_limits<T>::quiet_NaN();
729 }
730 return std::log(arg);
731 }
732 };
733
734 struct LogNaturalChecked {
735 template <typename T, typename Arg>
736 static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
737 static_assert(std::is_same<T, Arg>::value, "");
738 if (arg == 0.0) {
739 *st = Status::Invalid("logarithm of zero");
740 return arg;
741 } else if (arg < 0.0) {
742 *st = Status::Invalid("logarithm of negative number");
743 return arg;
744 }
745 return std::log(arg);
746 }
747 };
748
749 struct Log10 {
750 template <typename T, typename Arg>
751 static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status*) {
752 static_assert(std::is_same<T, Arg>::value, "");
753 if (arg == 0.0) {
754 return -std::numeric_limits<T>::infinity();
755 } else if (arg < 0.0) {
756 return std::numeric_limits<T>::quiet_NaN();
757 }
758 return std::log10(arg);
759 }
760 };
761
762 struct Log10Checked {
763 template <typename T, typename Arg>
764 static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
765 static_assert(std::is_same<T, Arg>::value, "");
766 if (arg == 0) {
767 *st = Status::Invalid("logarithm of zero");
768 return arg;
769 } else if (arg < 0) {
770 *st = Status::Invalid("logarithm of negative number");
771 return arg;
772 }
773 return std::log10(arg);
774 }
775 };
776
777 struct Log2 {
778 template <typename T, typename Arg>
779 static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status*) {
780 static_assert(std::is_same<T, Arg>::value, "");
781 if (arg == 0.0) {
782 return -std::numeric_limits<T>::infinity();
783 } else if (arg < 0.0) {
784 return std::numeric_limits<T>::quiet_NaN();
785 }
786 return std::log2(arg);
787 }
788 };
789
790 struct Log2Checked {
791 template <typename T, typename Arg>
792 static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
793 static_assert(std::is_same<T, Arg>::value, "");
794 if (arg == 0.0) {
795 *st = Status::Invalid("logarithm of zero");
796 return arg;
797 } else if (arg < 0.0) {
798 *st = Status::Invalid("logarithm of negative number");
799 return arg;
800 }
801 return std::log2(arg);
802 }
803 };
804
805 struct Log1p {
806 template <typename T, typename Arg>
807 static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status*) {
808 static_assert(std::is_same<T, Arg>::value, "");
809 if (arg == -1) {
810 return -std::numeric_limits<T>::infinity();
811 } else if (arg < -1) {
812 return std::numeric_limits<T>::quiet_NaN();
813 }
814 return std::log1p(arg);
815 }
816 };
817
818 struct Log1pChecked {
819 template <typename T, typename Arg>
820 static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
821 static_assert(std::is_same<T, Arg>::value, "");
822 if (arg == -1) {
823 *st = Status::Invalid("logarithm of zero");
824 return arg;
825 } else if (arg < -1) {
826 *st = Status::Invalid("logarithm of negative number");
827 return arg;
828 }
829 return std::log1p(arg);
830 }
831 };
832
833 struct Logb {
834 template <typename T, typename Arg0, typename Arg1>
835 static enable_if_floating_point<T> Call(KernelContext*, Arg0 x, Arg1 base, Status*) {
836 static_assert(std::is_same<T, Arg0>::value, "");
837 static_assert(std::is_same<Arg0, Arg1>::value, "");
838 if (x == 0.0) {
839 if (base == 0.0 || base < 0.0) {
840 return std::numeric_limits<T>::quiet_NaN();
841 } else {
842 return -std::numeric_limits<T>::infinity();
843 }
844 } else if (x < 0.0) {
845 return std::numeric_limits<T>::quiet_NaN();
846 }
847 return std::log(x) / std::log(base);
848 }
849 };
850
851 struct LogbChecked {
852 template <typename T, typename Arg0, typename Arg1>
853 static enable_if_floating_point<T> Call(KernelContext*, Arg0 x, Arg1 base, Status* st) {
854 static_assert(std::is_same<T, Arg0>::value, "");
855 static_assert(std::is_same<Arg0, Arg1>::value, "");
856 if (x == 0.0 || base == 0.0) {
857 *st = Status::Invalid("logarithm of zero");
858 return x;
859 } else if (x < 0.0 || base < 0.0) {
860 *st = Status::Invalid("logarithm of negative number");
861 return x;
862 }
863 return std::log(x) / std::log(base);
864 }
865 };
866
867 struct RoundUtil {
868 // Calculate powers of ten with arbitrary integer exponent
869 template <typename T = double>
870 static enable_if_floating_point<T> Pow10(int64_t power) {
871 static constexpr T lut[] = {1e0F, 1e1F, 1e2F, 1e3F, 1e4F, 1e5F, 1e6F, 1e7F,
872 1e8F, 1e9F, 1e10F, 1e11F, 1e12F, 1e13F, 1e14F, 1e15F};
873 int64_t lut_size = (sizeof(lut) / sizeof(*lut));
874 int64_t abs_power = std::abs(power);
875 auto pow10 = lut[std::min(abs_power, lut_size - 1)];
876 while (abs_power-- >= lut_size) {
877 pow10 *= 1e1F;
878 }
879 return (power >= 0) ? pow10 : (1 / pow10);
880 }
881 };
882
883 // Specializations of rounding implementations for round kernels
884 template <typename Type, RoundMode>
885 struct RoundImpl;
886
887 template <typename Type>
888 struct RoundImpl<Type, RoundMode::DOWN> {
889 template <typename T = Type>
890 static constexpr enable_if_floating_point<T> Round(const T val) {
891 return std::floor(val);
892 }
893
894 template <typename T = Type>
895 static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
896 const T& pow10, const int32_t scale) {
897 (*val) -= remainder;
898 if (remainder.Sign() < 0) {
899 (*val) -= pow10;
900 }
901 }
902 };
903
904 template <typename Type>
905 struct RoundImpl<Type, RoundMode::UP> {
906 template <typename T = Type>
907 static constexpr enable_if_floating_point<T> Round(const T val) {
908 return std::ceil(val);
909 }
910
911 template <typename T = Type>
912 static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
913 const T& pow10, const int32_t scale) {
914 (*val) -= remainder;
915 if (remainder.Sign() > 0 && remainder != 0) {
916 (*val) += pow10;
917 }
918 }
919 };
920
921 template <typename Type>
922 struct RoundImpl<Type, RoundMode::TOWARDS_ZERO> {
923 template <typename T = Type>
924 static constexpr enable_if_floating_point<T> Round(const T val) {
925 return std::trunc(val);
926 }
927
928 template <typename T = Type>
929 static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
930 const T& pow10, const int32_t scale) {
931 (*val) -= remainder;
932 }
933 };
934
935 template <typename Type>
936 struct RoundImpl<Type, RoundMode::TOWARDS_INFINITY> {
937 template <typename T = Type>
938 static constexpr enable_if_floating_point<T> Round(const T val) {
939 return std::signbit(val) ? std::floor(val) : std::ceil(val);
940 }
941
942 template <typename T = Type>
943 static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
944 const T& pow10, const int32_t scale) {
945 (*val) -= remainder;
946 if (remainder.Sign() < 0) {
947 (*val) -= pow10;
948 } else if (remainder.Sign() > 0 && remainder != 0) {
949 (*val) += pow10;
950 }
951 }
952 };
953
954 // NOTE: RoundImpl variants for the HALF_* rounding modes are only
955 // invoked when the fractional part is equal to 0.5 (std::round is invoked
956 // otherwise).
957
958 template <typename Type>
959 struct RoundImpl<Type, RoundMode::HALF_DOWN> {
960 template <typename T = Type>
961 static constexpr enable_if_floating_point<T> Round(const T val) {
962 return RoundImpl<T, RoundMode::DOWN>::Round(val);
963 }
964
965 template <typename T = Type>
966 static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
967 const T& pow10, const int32_t scale) {
968 RoundImpl<T, RoundMode::DOWN>::Round(val, remainder, pow10, scale);
969 }
970 };
971
972 template <typename Type>
973 struct RoundImpl<Type, RoundMode::HALF_UP> {
974 template <typename T = Type>
975 static constexpr enable_if_floating_point<T> Round(const T val) {
976 return RoundImpl<T, RoundMode::UP>::Round(val);
977 }
978
979 template <typename T = Type>
980 static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
981 const T& pow10, const int32_t scale) {
982 RoundImpl<T, RoundMode::UP>::Round(val, remainder, pow10, scale);
983 }
984 };
985
986 template <typename Type>
987 struct RoundImpl<Type, RoundMode::HALF_TOWARDS_ZERO> {
988 template <typename T = Type>
989 static constexpr enable_if_floating_point<T> Round(const T val) {
990 return RoundImpl<T, RoundMode::TOWARDS_ZERO>::Round(val);
991 }
992
993 template <typename T = Type>
994 static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
995 const T& pow10, const int32_t scale) {
996 RoundImpl<T, RoundMode::TOWARDS_ZERO>::Round(val, remainder, pow10, scale);
997 }
998 };
999
1000 template <typename Type>
1001 struct RoundImpl<Type, RoundMode::HALF_TOWARDS_INFINITY> {
1002 template <typename T = Type>
1003 static constexpr enable_if_floating_point<T> Round(const T val) {
1004 return RoundImpl<T, RoundMode::TOWARDS_INFINITY>::Round(val);
1005 }
1006
1007 template <typename T = Type>
1008 static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
1009 const T& pow10, const int32_t scale) {
1010 RoundImpl<T, RoundMode::TOWARDS_INFINITY>::Round(val, remainder, pow10, scale);
1011 }
1012 };
1013
1014 template <typename Type>
1015 struct RoundImpl<Type, RoundMode::HALF_TO_EVEN> {
1016 template <typename T = Type>
1017 static constexpr enable_if_floating_point<T> Round(const T val) {
1018 return std::round(val * T(0.5)) * 2;
1019 }
1020
1021 template <typename T = Type>
1022 static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
1023 const T& pow10, const int32_t scale) {
1024 auto scaled = val->ReduceScaleBy(scale, /*round=*/false);
1025 if (scaled.low_bits() % 2 != 0) {
1026 scaled += remainder.Sign() >= 0 ? 1 : -1;
1027 }
1028 *val = scaled.IncreaseScaleBy(scale);
1029 }
1030 };
1031
1032 template <typename Type>
1033 struct RoundImpl<Type, RoundMode::HALF_TO_ODD> {
1034 template <typename T = Type>
1035 static constexpr enable_if_floating_point<T> Round(const T val) {
1036 return std::floor(val * T(0.5)) + std::ceil(val * T(0.5));
1037 }
1038
1039 template <typename T = Type>
1040 static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
1041 const T& pow10, const int32_t scale) {
1042 auto scaled = val->ReduceScaleBy(scale, /*round=*/false);
1043 if (scaled.low_bits() % 2 == 0) {
1044 scaled += remainder.Sign() ? 1 : -1;
1045 }
1046 *val = scaled.IncreaseScaleBy(scale);
1047 }
1048 };
1049
// Kernel state for the rounding kernels. Only the per-options-type explicit
// specializations that follow are defined.
template <typename Options>
struct RoundOptionsWrapper;
1053
1054 template <>
1055 struct RoundOptionsWrapper<RoundOptions> : public OptionsWrapper<RoundOptions> {
1056 using OptionsType = RoundOptions;
1057 using State = RoundOptionsWrapper<OptionsType>;
1058 double pow10;
1059
1060 explicit RoundOptionsWrapper(OptionsType options) : OptionsWrapper(std::move(options)) {
1061 // Only positive exponents for powers of 10 are used because combining
1062 // multiply and division operations produced more stable rounding than
1063 // using multiply-only. Refer to NumPy's round implementation:
1064 // https://github.com/numpy/numpy/blob/7b2f20b406d27364c812f7a81a9c901afbd3600c/numpy/core/src/multiarray/calculation.c#L589
1065 pow10 = RoundUtil::Pow10(std::abs(options.ndigits));
1066 }
1067
1068 static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
1069 const KernelInitArgs& args) {
1070 if (auto options = static_cast<const OptionsType*>(args.options)) {
1071 return ::arrow::internal::make_unique<State>(*options);
1072 }
1073 return Status::Invalid(
1074 "Attempted to initialize KernelState from null FunctionOptions");
1075 }
1076 };
1077
1078 template <>
1079 struct RoundOptionsWrapper<RoundToMultipleOptions>
1080 : public OptionsWrapper<RoundToMultipleOptions> {
1081 using OptionsType = RoundToMultipleOptions;
1082 using State = RoundOptionsWrapper<OptionsType>;
1083 using OptionsWrapper::OptionsWrapper;
1084
1085 static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
1086 const KernelInitArgs& args) {
1087 std::unique_ptr<State> state;
1088 if (auto options = static_cast<const OptionsType*>(args.options)) {
1089 state = ::arrow::internal::make_unique<State>(*options);
1090 } else {
1091 return Status::Invalid(
1092 "Attempted to initialize KernelState from null FunctionOptions");
1093 }
1094
1095 auto options = Get(*state);
1096 const auto& type = *args.inputs[0].type;
1097 if (!options.multiple || !options.multiple->is_valid) {
1098 return Status::Invalid("Rounding multiple must be non-null and valid");
1099 }
1100 if (is_floating(type.id())) {
1101 switch (options.multiple->type->id()) {
1102 case Type::FLOAT: {
1103 if (UnboxScalar<FloatType>::Unbox(*options.multiple) < 0) {
1104 return Status::Invalid("Rounding multiple must be positive");
1105 }
1106 break;
1107 }
1108 case Type::DOUBLE: {
1109 if (UnboxScalar<DoubleType>::Unbox(*options.multiple) < 0) {
1110 return Status::Invalid("Rounding multiple must be positive");
1111 }
1112 break;
1113 }
1114 case Type::HALF_FLOAT:
1115 return Status::NotImplemented("Half-float values are not supported");
1116 default:
1117 return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
1118 *options.multiple->type);
1119 }
1120 } else {
1121 DCHECK(is_decimal(type.id()));
1122 if (!type.Equals(*options.multiple->type)) {
1123 return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
1124 *options.multiple->type);
1125 }
1126 switch (options.multiple->type->id()) {
1127 case Type::DECIMAL128: {
1128 if (UnboxScalar<Decimal128Type>::Unbox(*options.multiple) <= 0) {
1129 return Status::Invalid("Rounding multiple must be positive");
1130 }
1131 break;
1132 }
1133 case Type::DECIMAL256: {
1134 if (UnboxScalar<Decimal256Type>::Unbox(*options.multiple) <= 0) {
1135 return Status::Invalid("Rounding multiple must be positive");
1136 }
1137 break;
1138 }
1139 default:
1140 // This shouldn't happen
1141 return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
1142 *options.multiple->type);
1143 }
1144 }
1145 return std::move(state);
1146 }
1147 };
1148
1149 template <typename ArrowType, RoundMode RndMode, typename Enable = void>
1150 struct Round {
1151 using CType = typename TypeTraits<ArrowType>::CType;
1152 using State = RoundOptionsWrapper<RoundOptions>;
1153
1154 CType pow10;
1155 int64_t ndigits;
1156
1157 explicit Round(const State& state, const DataType& out_ty)
1158 : pow10(static_cast<CType>(state.pow10)), ndigits(state.options.ndigits) {}
1159
1160 template <typename T = ArrowType, typename CType = typename TypeTraits<T>::CType>
1161 enable_if_floating_point<CType> Call(KernelContext* ctx, CType arg, Status* st) const {
1162 // Do not process Inf or NaN because they will trigger the overflow error at end of
1163 // function.
1164 if (!std::isfinite(arg)) {
1165 return arg;
1166 }
1167 auto round_val = ndigits >= 0 ? (arg * pow10) : (arg / pow10);
1168 auto frac = round_val - std::floor(round_val);
1169 if (frac != T(0)) {
1170 // Use std::round() if in tie-breaking mode and scaled value is not 0.5.
1171 if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) {
1172 round_val = std::round(round_val);
1173 } else {
1174 round_val = RoundImpl<CType, RndMode>::Round(round_val);
1175 }
1176 // Equality check is ommitted so that the common case of 10^0 (integer rounding)
1177 // uses multiply-only
1178 round_val = ndigits > 0 ? (round_val / pow10) : (round_val * pow10);
1179 if (!std::isfinite(round_val)) {
1180 *st = Status::Invalid("overflow occurred during rounding");
1181 return arg;
1182 }
1183 } else {
1184 // If scaled value is an integer, then no rounding is needed.
1185 round_val = arg;
1186 }
1187 return round_val;
1188 }
1189 };
1190
1191 template <typename ArrowType, RoundMode kRoundMode>
1192 struct Round<ArrowType, kRoundMode, enable_if_decimal<ArrowType>> {
1193 using CType = typename TypeTraits<ArrowType>::CType;
1194 using State = RoundOptionsWrapper<RoundOptions>;
1195
1196 const ArrowType& ty;
1197 int64_t ndigits;
1198 int32_t pow;
1199 // pow10 is "1" for the given decimal scale. Similarly half_pow10 is "0.5".
1200 CType pow10, half_pow10, neg_half_pow10;
1201
1202 explicit Round(const State& state, const DataType& out_ty)
1203 : Round(state.options.ndigits, out_ty) {}
1204
1205 explicit Round(int64_t ndigits, const DataType& out_ty)
1206 : ty(checked_cast<const ArrowType&>(out_ty)),
1207 ndigits(ndigits),
1208 pow(static_cast<int32_t>(ty.scale() - ndigits)) {
1209 if (pow >= ty.precision() || pow < 0) {
1210 pow10 = half_pow10 = neg_half_pow10 = 0;
1211 } else {
1212 pow10 = CType::GetScaleMultiplier(pow);
1213 half_pow10 = CType::GetHalfScaleMultiplier(pow);
1214 neg_half_pow10 = -half_pow10;
1215 }
1216 }
1217
1218 template <typename T = ArrowType, typename CType = typename TypeTraits<T>::CType>
1219 enable_if_decimal_value<CType> Call(KernelContext* ctx, CType arg, Status* st) const {
1220 if (pow >= ty.precision()) {
1221 *st = Status::Invalid("Rounding to ", ndigits,
1222 " digits will not fit in precision of ", ty);
1223 return arg;
1224 } else if (pow < 0) {
1225 // no-op, copy output to input
1226 return arg;
1227 }
1228
1229 std::pair<CType, CType> pair;
1230 *st = arg.Divide(pow10).Value(&pair);
1231 if (!st->ok()) return arg;
1232 // The remainder is effectively the scaled fractional part after division.
1233 const auto& remainder = pair.second;
1234 if (remainder == 0) return arg;
1235 if (kRoundMode >= RoundMode::HALF_DOWN) {
1236 if (remainder == half_pow10 || remainder == neg_half_pow10) {
1237 // On the halfway point, use tiebreaker
1238 RoundImpl<CType, kRoundMode>::Round(&arg, remainder, pow10, pow);
1239 } else if (remainder.Sign() >= 0) {
1240 // Positive, round up/down
1241 arg -= remainder;
1242 if (remainder > half_pow10) {
1243 arg += pow10;
1244 }
1245 } else {
1246 // Negative, round up/down
1247 arg -= remainder;
1248 if (remainder < neg_half_pow10) {
1249 arg -= pow10;
1250 }
1251 }
1252 } else {
1253 RoundImpl<CType, kRoundMode>::Round(&arg, remainder, pow10, pow);
1254 }
1255 if (!arg.FitsInPrecision(ty.precision())) {
1256 *st = Status::Invalid("Rounded value ", arg.ToString(ty.scale()),
1257 " does not fit in precision of ", ty);
1258 return 0;
1259 }
1260 return arg;
1261 }
1262 };
1263
1264 template <typename DecimalType, RoundMode kMode, int32_t kDigits>
1265 Status FixedRoundDecimalExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
1266 using Op = Round<DecimalType, kMode>;
1267 return ScalarUnaryNotNullStateful<DecimalType, DecimalType, Op>(
1268 Op(kDigits, *out->type()))
1269 .Exec(ctx, batch, out);
1270 }
1271
1272 template <typename ArrowType, RoundMode kRoundMode, typename Enable = void>
1273 struct RoundToMultiple {
1274 using CType = typename TypeTraits<ArrowType>::CType;
1275 using State = RoundOptionsWrapper<RoundToMultipleOptions>;
1276
1277 CType multiple;
1278
1279 explicit RoundToMultiple(const State& state, const DataType& out_ty) {
1280 const auto& options = state.options;
1281 DCHECK(options.multiple);
1282 DCHECK(options.multiple->is_valid);
1283 DCHECK(is_floating(options.multiple->type->id()));
1284 switch (options.multiple->type->id()) {
1285 case Type::FLOAT:
1286 multiple = static_cast<CType>(UnboxScalar<FloatType>::Unbox(*options.multiple));
1287 break;
1288 case Type::DOUBLE:
1289 multiple = static_cast<CType>(UnboxScalar<DoubleType>::Unbox(*options.multiple));
1290 break;
1291 default:
1292 DCHECK(false);
1293 }
1294 }
1295
1296 template <typename T = ArrowType, typename CType = typename TypeTraits<T>::CType>
1297 enable_if_floating_point<CType> Call(KernelContext* ctx, CType arg, Status* st) const {
1298 // Do not process Inf or NaN because they will trigger the overflow error at end of
1299 // function.
1300 if (!std::isfinite(arg)) {
1301 return arg;
1302 }
1303 auto round_val = arg / multiple;
1304 auto frac = round_val - std::floor(round_val);
1305 if (frac != T(0)) {
1306 // Use std::round() if in tie-breaking mode and scaled value is not 0.5.
1307 if ((kRoundMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) {
1308 round_val = std::round(round_val);
1309 } else {
1310 round_val = RoundImpl<CType, kRoundMode>::Round(round_val);
1311 }
1312 round_val *= multiple;
1313 if (!std::isfinite(round_val)) {
1314 *st = Status::Invalid("overflow occurred during rounding");
1315 return arg;
1316 }
1317 } else {
1318 // If scaled value is an integer, then no rounding is needed.
1319 round_val = arg;
1320 }
1321 return round_val;
1322 }
1323 };
1324
1325 template <typename ArrowType, RoundMode kRoundMode>
1326 struct RoundToMultiple<ArrowType, kRoundMode, enable_if_decimal<ArrowType>> {
1327 using CType = typename TypeTraits<ArrowType>::CType;
1328 using State = RoundOptionsWrapper<RoundToMultipleOptions>;
1329
1330 const ArrowType& ty;
1331 CType multiple, half_multiple, neg_half_multiple;
1332 bool has_halfway_point;
1333
1334 explicit RoundToMultiple(const State& state, const DataType& out_ty)
1335 : ty(checked_cast<const ArrowType&>(out_ty)) {
1336 const auto& options = state.options;
1337 DCHECK(options.multiple);
1338 DCHECK(options.multiple->is_valid);
1339 DCHECK(options.multiple->type->Equals(out_ty));
1340 multiple = UnboxScalar<ArrowType>::Unbox(*options.multiple);
1341 half_multiple = multiple;
1342 half_multiple /= 2;
1343 neg_half_multiple = -half_multiple;
1344 has_halfway_point = multiple.low_bits() % 2 == 0;
1345 }
1346
1347 template <typename T = ArrowType, typename CType = typename TypeTraits<T>::CType>
1348 enable_if_decimal_value<CType> Call(KernelContext* ctx, CType arg, Status* st) const {
1349 std::pair<CType, CType> pair;
1350 *st = arg.Divide(multiple).Value(&pair);
1351 if (!st->ok()) return arg;
1352 const auto& remainder = pair.second;
1353 if (remainder == 0) return arg;
1354 if (kRoundMode >= RoundMode::HALF_DOWN) {
1355 if (has_halfway_point &&
1356 (remainder == half_multiple || remainder == neg_half_multiple)) {
1357 // On the halfway point, use tiebreaker
1358 // Manually implement rounding since we're not actually rounding a
1359 // decimal value, but rather manipulating the multiple
1360 switch (kRoundMode) {
1361 case RoundMode::HALF_DOWN:
1362 if (remainder.Sign() < 0) pair.first -= 1;
1363 break;
1364 case RoundMode::HALF_UP:
1365 if (remainder.Sign() >= 0) pair.first += 1;
1366 break;
1367 case RoundMode::HALF_TOWARDS_ZERO:
1368 // Do nothing
1369 break;
1370 case RoundMode::HALF_TOWARDS_INFINITY:
1371 if (remainder.Sign() >= 0) {
1372 pair.first += 1;
1373 } else {
1374 pair.first -= 1;
1375 }
1376 break;
1377 case RoundMode::HALF_TO_EVEN:
1378 if (pair.first.low_bits() % 2 != 0) {
1379 pair.first += remainder.Sign() >= 0 ? 1 : -1;
1380 }
1381 break;
1382 case RoundMode::HALF_TO_ODD:
1383 if (pair.first.low_bits() % 2 == 0) {
1384 pair.first += remainder.Sign() >= 0 ? 1 : -1;
1385 }
1386 break;
1387 default:
1388 DCHECK(false);
1389 }
1390 } else if (remainder.Sign() >= 0) {
1391 // Positive, round up/down
1392 if (remainder > half_multiple) {
1393 pair.first += 1;
1394 }
1395 } else {
1396 // Negative, round up/down
1397 if (remainder < neg_half_multiple) {
1398 pair.first -= 1;
1399 }
1400 }
1401 } else {
1402 // Manually implement rounding since we're not actually rounding a
1403 // decimal value, but rather manipulating the multiple
1404 switch (kRoundMode) {
1405 case RoundMode::DOWN:
1406 if (remainder.Sign() < 0) pair.first -= 1;
1407 break;
1408 case RoundMode::UP:
1409 if (remainder.Sign() >= 0) pair.first += 1;
1410 break;
1411 case RoundMode::TOWARDS_ZERO:
1412 // Do nothing
1413 break;
1414 case RoundMode::TOWARDS_INFINITY:
1415 if (remainder.Sign() >= 0) {
1416 pair.first += 1;
1417 } else {
1418 pair.first -= 1;
1419 }
1420 break;
1421 default:
1422 DCHECK(false);
1423 }
1424 }
1425 CType round_val = pair.first * multiple;
1426 if (!round_val.FitsInPrecision(ty.precision())) {
1427 *st = Status::Invalid("Rounded value ", round_val.ToString(ty.scale()),
1428 " does not fit in precision of ", ty);
1429 return 0;
1430 }
1431 return round_val;
1432 }
1433 };
1434
1435 struct Floor {
1436 template <typename T, typename Arg>
1437 static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
1438 Status*) {
1439 static_assert(std::is_same<T, Arg>::value, "");
1440 return RoundImpl<T, RoundMode::DOWN>::Round(arg);
1441 }
1442 };
1443
1444 struct Ceil {
1445 template <typename T, typename Arg>
1446 static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
1447 Status*) {
1448 static_assert(std::is_same<T, Arg>::value, "");
1449 return RoundImpl<T, RoundMode::UP>::Round(arg);
1450 }
1451 };
1452
1453 struct Trunc {
1454 template <typename T, typename Arg>
1455 static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
1456 Status*) {
1457 static_assert(std::is_same<T, Arg>::value, "");
1458 return RoundImpl<T, RoundMode::TOWARDS_ZERO>::Round(arg);
1459 }
1460 };
1461
1462 // Generate a kernel given an arithmetic functor
1463 template <template <typename... Args> class KernelGenerator, typename Op>
1464 ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) {
1465 switch (get_id.id) {
1466 case Type::INT8:
1467 return KernelGenerator<Int8Type, Int8Type, Op>::Exec;
1468 case Type::UINT8:
1469 return KernelGenerator<UInt8Type, UInt8Type, Op>::Exec;
1470 case Type::INT16:
1471 return KernelGenerator<Int16Type, Int16Type, Op>::Exec;
1472 case Type::UINT16:
1473 return KernelGenerator<UInt16Type, UInt16Type, Op>::Exec;
1474 case Type::INT32:
1475 return KernelGenerator<Int32Type, Int32Type, Op>::Exec;
1476 case Type::UINT32:
1477 return KernelGenerator<UInt32Type, UInt32Type, Op>::Exec;
1478 case Type::INT64:
1479 case Type::TIMESTAMP:
1480 return KernelGenerator<Int64Type, Int64Type, Op>::Exec;
1481 case Type::UINT64:
1482 return KernelGenerator<UInt64Type, UInt64Type, Op>::Exec;
1483 case Type::FLOAT:
1484 return KernelGenerator<FloatType, FloatType, Op>::Exec;
1485 case Type::DOUBLE:
1486 return KernelGenerator<DoubleType, DoubleType, Op>::Exec;
1487 default:
1488 DCHECK(false);
1489 return ExecFail;
1490 }
1491 }
1492
1493 // Generate a kernel given a bitwise arithmetic functor. Assumes the
1494 // functor treats all integer types of equal width identically
1495 template <template <typename... Args> class KernelGenerator, typename Op>
1496 ArrayKernelExec TypeAgnosticBitWiseExecFromOp(detail::GetTypeId get_id) {
1497 switch (get_id.id) {
1498 case Type::INT8:
1499 case Type::UINT8:
1500 return KernelGenerator<UInt8Type, UInt8Type, Op>::Exec;
1501 case Type::INT16:
1502 case Type::UINT16:
1503 return KernelGenerator<UInt16Type, UInt16Type, Op>::Exec;
1504 case Type::INT32:
1505 case Type::UINT32:
1506 return KernelGenerator<UInt32Type, UInt32Type, Op>::Exec;
1507 case Type::INT64:
1508 case Type::UINT64:
1509 return KernelGenerator<UInt64Type, UInt64Type, Op>::Exec;
1510 default:
1511 DCHECK(false);
1512 return ExecFail;
1513 }
1514 }
1515
1516 template <template <typename... Args> class KernelGenerator, typename Op>
1517 ArrayKernelExec ShiftExecFromOp(detail::GetTypeId get_id) {
1518 switch (get_id.id) {
1519 case Type::INT8:
1520 return KernelGenerator<Int8Type, Int8Type, Op>::Exec;
1521 case Type::UINT8:
1522 return KernelGenerator<UInt8Type, UInt8Type, Op>::Exec;
1523 case Type::INT16:
1524 return KernelGenerator<Int16Type, Int16Type, Op>::Exec;
1525 case Type::UINT16:
1526 return KernelGenerator<UInt16Type, UInt16Type, Op>::Exec;
1527 case Type::INT32:
1528 return KernelGenerator<Int32Type, Int32Type, Op>::Exec;
1529 case Type::UINT32:
1530 return KernelGenerator<UInt32Type, UInt32Type, Op>::Exec;
1531 case Type::INT64:
1532 return KernelGenerator<Int64Type, Int64Type, Op>::Exec;
1533 case Type::UINT64:
1534 return KernelGenerator<UInt64Type, UInt64Type, Op>::Exec;
1535 default:
1536 DCHECK(false);
1537 return ExecFail;
1538 }
1539 }
1540
1541 template <template <typename... Args> class KernelGenerator, typename Op>
1542 ArrayKernelExec GenerateArithmeticFloatingPoint(detail::GetTypeId get_id) {
1543 switch (get_id.id) {
1544 case Type::FLOAT:
1545 return KernelGenerator<FloatType, FloatType, Op>::Exec;
1546 case Type::DOUBLE:
1547 return KernelGenerator<DoubleType, DoubleType, Op>::Exec;
1548 default:
1549 DCHECK(false);
1550 return ExecFail;
1551 }
1552 }
1553
1554 // resolve decimal binary operation output type per *casted* args
1555 template <typename OutputGetter>
1556 Result<ValueDescr> ResolveDecimalBinaryOperationOutput(
1557 const std::vector<ValueDescr>& args, OutputGetter&& getter) {
1558 // casted args should be same size decimals
1559 auto left_type = checked_cast<const DecimalType*>(args[0].type.get());
1560 auto right_type = checked_cast<const DecimalType*>(args[1].type.get());
1561 DCHECK_EQ(left_type->id(), right_type->id());
1562
1563 int32_t precision, scale;
1564 std::tie(precision, scale) = getter(left_type->precision(), left_type->scale(),
1565 right_type->precision(), right_type->scale());
1566 ARROW_ASSIGN_OR_RAISE(auto type, DecimalType::Make(left_type->id(), precision, scale));
1567 return ValueDescr(std::move(type), GetBroadcastShape(args));
1568 }
1569
1570 Result<ValueDescr> ResolveDecimalAdditionOrSubtractionOutput(
1571 KernelContext*, const std::vector<ValueDescr>& args) {
1572 return ResolveDecimalBinaryOperationOutput(
1573 args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) {
1574 DCHECK_EQ(s1, s2);
1575 const int32_t scale = s1;
1576 const int32_t precision = std::max(p1 - s1, p2 - s2) + scale + 1;
1577 return std::make_pair(precision, scale);
1578 });
1579 }
1580
1581 Result<ValueDescr> ResolveDecimalMultiplicationOutput(
1582 KernelContext*, const std::vector<ValueDescr>& args) {
1583 return ResolveDecimalBinaryOperationOutput(
1584 args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) {
1585 const int32_t scale = s1 + s2;
1586 const int32_t precision = p1 + p2 + 1;
1587 return std::make_pair(precision, scale);
1588 });
1589 }
1590
1591 Result<ValueDescr> ResolveDecimalDivisionOutput(KernelContext*,
1592 const std::vector<ValueDescr>& args) {
1593 return ResolveDecimalBinaryOperationOutput(
1594 args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) {
1595 DCHECK_GE(s1, s2);
1596 const int32_t scale = s1 - s2;
1597 const int32_t precision = p1;
1598 return std::make_pair(precision, scale);
1599 });
1600 }
1601
1602 template <typename Op>
1603 void AddDecimalBinaryKernels(const std::string& name,
1604 std::shared_ptr<ScalarFunction>* func) {
1605 OutputType out_type(null());
1606 const std::string op = name.substr(0, name.find("_"));
1607 if (op == "add" || op == "subtract") {
1608 out_type = OutputType(ResolveDecimalAdditionOrSubtractionOutput);
1609 } else if (op == "multiply") {
1610 out_type = OutputType(ResolveDecimalMultiplicationOutput);
1611 } else if (op == "divide") {
1612 out_type = OutputType(ResolveDecimalDivisionOutput);
1613 } else {
1614 DCHECK(false);
1615 }
1616
1617 auto in_type128 = InputType(Type::DECIMAL128);
1618 auto in_type256 = InputType(Type::DECIMAL256);
1619 auto exec128 = ScalarBinaryNotNullEqualTypes<Decimal128Type, Decimal128Type, Op>::Exec;
1620 auto exec256 = ScalarBinaryNotNullEqualTypes<Decimal256Type, Decimal256Type, Op>::Exec;
1621 DCHECK_OK((*func)->AddKernel({in_type128, in_type128}, out_type, exec128));
1622 DCHECK_OK((*func)->AddKernel({in_type256, in_type256}, out_type, exec256));
1623 }
1624
1625 // Generate a kernel given an arithmetic functor
1626 template <template <typename...> class KernelGenerator, typename OutType, typename Op>
1627 ArrayKernelExec GenerateArithmeticWithFixedIntOutType(detail::GetTypeId get_id) {
1628 switch (get_id.id) {
1629 case Type::INT8:
1630 return KernelGenerator<OutType, Int8Type, Op>::Exec;
1631 case Type::UINT8:
1632 return KernelGenerator<OutType, UInt8Type, Op>::Exec;
1633 case Type::INT16:
1634 return KernelGenerator<OutType, Int16Type, Op>::Exec;
1635 case Type::UINT16:
1636 return KernelGenerator<OutType, UInt16Type, Op>::Exec;
1637 case Type::INT32:
1638 return KernelGenerator<OutType, Int32Type, Op>::Exec;
1639 case Type::UINT32:
1640 return KernelGenerator<OutType, UInt32Type, Op>::Exec;
1641 case Type::INT64:
1642 case Type::TIMESTAMP:
1643 return KernelGenerator<OutType, Int64Type, Op>::Exec;
1644 case Type::UINT64:
1645 return KernelGenerator<OutType, UInt64Type, Op>::Exec;
1646 case Type::FLOAT:
1647 return KernelGenerator<FloatType, FloatType, Op>::Exec;
1648 case Type::DOUBLE:
1649 return KernelGenerator<DoubleType, DoubleType, Op>::Exec;
1650 default:
1651 DCHECK(false);
1652 return ExecFail;
1653 }
1654 }
1655
1656 struct ArithmeticFunction : ScalarFunction {
1657 using ScalarFunction::ScalarFunction;
1658
1659 Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
1660 RETURN_NOT_OK(CheckArity(*values));
1661
1662 RETURN_NOT_OK(CheckDecimals(values));
1663
1664 using arrow::compute::detail::DispatchExactImpl;
1665 if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
1666
1667 EnsureDictionaryDecoded(values);
1668
1669 // Only promote types for binary functions
1670 if (values->size() == 2) {
1671 ReplaceNullWithOtherType(values);
1672
1673 if (auto type = CommonNumeric(*values)) {
1674 ReplaceTypes(type, values);
1675 }
1676 }
1677
1678 if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
1679 return arrow::compute::detail::NoMatchingKernel(this, *values);
1680 }
1681
1682 Status CheckDecimals(std::vector<ValueDescr>* values) const {
1683 if (!HasDecimal(*values)) return Status::OK();
1684
1685 if (values->size() == 2) {
1686 // "add_checked" -> "add"
1687 const auto func_name = name();
1688 const std::string op = func_name.substr(0, func_name.find("_"));
1689 if (op == "add" || op == "subtract") {
1690 return CastBinaryDecimalArgs(DecimalPromotion::kAdd, values);
1691 } else if (op == "multiply") {
1692 return CastBinaryDecimalArgs(DecimalPromotion::kMultiply, values);
1693 } else if (op == "divide") {
1694 return CastBinaryDecimalArgs(DecimalPromotion::kDivide, values);
1695 } else {
1696 return Status::Invalid("Invalid decimal function: ", func_name);
1697 }
1698 }
1699 return Status::OK();
1700 }
1701 };
1702
1703 /// An ArithmeticFunction that promotes only integer arguments to double.
1704 struct ArithmeticIntegerToFloatingPointFunction : public ArithmeticFunction {
1705 using ArithmeticFunction::ArithmeticFunction;
1706
1707 Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
1708 RETURN_NOT_OK(CheckArity(*values));
1709 RETURN_NOT_OK(CheckDecimals(values));
1710
1711 using arrow::compute::detail::DispatchExactImpl;
1712 if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
1713
1714 EnsureDictionaryDecoded(values);
1715
1716 if (values->size() == 2) {
1717 ReplaceNullWithOtherType(values);
1718 }
1719
1720 for (auto& descr : *values) {
1721 if (is_integer(descr.type->id())) {
1722 descr.type = float64();
1723 }
1724 }
1725 if (auto type = CommonNumeric(*values)) {
1726 ReplaceTypes(type, values);
1727 }
1728
1729 if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
1730 return arrow::compute::detail::NoMatchingKernel(this, *values);
1731 }
1732 };
1733
1734 /// An ArithmeticFunction that promotes integer arguments to double.
1735 struct ArithmeticFloatingPointFunction : public ArithmeticFunction {
1736 using ArithmeticFunction::ArithmeticFunction;
1737
1738 Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
1739 RETURN_NOT_OK(CheckArity(*values));
1740 RETURN_NOT_OK(CheckDecimals(values));
1741
1742 using arrow::compute::detail::DispatchExactImpl;
1743 if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
1744
1745 EnsureDictionaryDecoded(values);
1746
1747 if (values->size() == 2) {
1748 ReplaceNullWithOtherType(values);
1749 }
1750
1751 for (auto& descr : *values) {
1752 if (is_integer(descr.type->id())) {
1753 descr.type = float64();
1754 }
1755 }
1756 if (auto type = CommonNumeric(*values)) {
1757 ReplaceTypes(type, values);
1758 }
1759
1760 if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
1761 return arrow::compute::detail::NoMatchingKernel(this, *values);
1762 }
1763 };
1764
1765 // A scalar kernel that ignores (assumed all-null) inputs and returns null.
1766 Status NullToNullExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
1767 return Status::OK();
1768 }
1769
1770 void AddNullExec(ScalarFunction* func) {
1771 std::vector<InputType> input_types(func->arity().num_args, InputType(Type::NA));
1772 DCHECK_OK(func->AddKernel(std::move(input_types), OutputType(null()), NullToNullExec));
1773 }
1774
1775 template <typename Op>
1776 std::shared_ptr<ScalarFunction> MakeArithmeticFunction(std::string name,
1777 const FunctionDoc* doc) {
1778 auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
1779 for (const auto& ty : NumericTypes()) {
1780 auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Op>(ty);
1781 DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
1782 }
1783 AddNullExec(func.get());
1784 return func;
1785 }
1786
1787 // Like MakeArithmeticFunction, but for arithmetic ops that need to run
1788 // only on non-null output.
1789 template <typename Op>
1790 std::shared_ptr<ScalarFunction> MakeArithmeticFunctionNotNull(std::string name,
1791 const FunctionDoc* doc) {
1792 auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
1793 for (const auto& ty : NumericTypes()) {
1794 auto exec = ArithmeticExecFromOp<ScalarBinaryNotNullEqualTypes, Op>(ty);
1795 DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
1796 }
1797 AddNullExec(func.get());
1798 return func;
1799 }
1800
1801 template <typename Op>
1802 std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunction(std::string name,
1803 const FunctionDoc* doc) {
1804 auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
1805 for (const auto& ty : NumericTypes()) {
1806 auto exec = ArithmeticExecFromOp<ScalarUnary, Op>(ty);
1807 DCHECK_OK(func->AddKernel({ty}, ty, exec));
1808 }
1809 AddNullExec(func.get());
1810 return func;
1811 }
1812
1813 // Like MakeUnaryArithmeticFunction, but for unary arithmetic ops with a fixed
1814 // output type for integral inputs.
1815 template <typename Op, typename IntOutType>
1816 std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionWithFixedIntOutType(
1817 std::string name, const FunctionDoc* doc) {
1818 auto int_out_ty = TypeTraits<IntOutType>::type_singleton();
1819 auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
1820 for (const auto& ty : NumericTypes()) {
1821 auto out_ty = arrow::is_floating(ty->id()) ? ty : int_out_ty;
1822 auto exec = GenerateArithmeticWithFixedIntOutType<ScalarUnary, IntOutType, Op>(ty);
1823 DCHECK_OK(func->AddKernel({ty}, out_ty, exec));
1824 }
1825 AddNullExec(func.get());
1826 return func;
1827 }
1828
1829 // Like MakeUnaryArithmeticFunction, but for arithmetic ops that need to run
1830 // only on non-null output.
1831 template <typename Op>
1832 std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionNotNull(
1833 std::string name, const FunctionDoc* doc) {
1834 auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
1835 for (const auto& ty : NumericTypes()) {
1836 auto exec = ArithmeticExecFromOp<ScalarUnaryNotNull, Op>(ty);
1837 DCHECK_OK(func->AddKernel({ty}, ty, exec));
1838 }
1839 AddNullExec(func.get());
1840 return func;
1841 }
1842
1843 // Exec the round kernel for the given types
1844 template <typename Type, typename OptionsType,
1845 template <typename, RoundMode, typename...> class OpImpl>
1846 Status ExecRound(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
1847 using State = RoundOptionsWrapper<OptionsType>;
1848 const auto& state = static_cast<const State&>(*ctx->state());
1849 switch (state.options.round_mode) {
1850 case RoundMode::DOWN: {
1851 using Op = OpImpl<Type, RoundMode::DOWN>;
1852 return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
1853 .Exec(ctx, batch, out);
1854 }
1855 case RoundMode::UP: {
1856 using Op = OpImpl<Type, RoundMode::UP>;
1857 return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
1858 .Exec(ctx, batch, out);
1859 }
1860 case RoundMode::TOWARDS_ZERO: {
1861 using Op = OpImpl<Type, RoundMode::TOWARDS_ZERO>;
1862 return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
1863 .Exec(ctx, batch, out);
1864 }
1865 case RoundMode::TOWARDS_INFINITY: {
1866 using Op = OpImpl<Type, RoundMode::TOWARDS_INFINITY>;
1867 return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
1868 .Exec(ctx, batch, out);
1869 }
1870 case RoundMode::HALF_DOWN: {
1871 using Op = OpImpl<Type, RoundMode::HALF_DOWN>;
1872 return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
1873 .Exec(ctx, batch, out);
1874 }
1875 case RoundMode::HALF_UP: {
1876 using Op = OpImpl<Type, RoundMode::HALF_UP>;
1877 return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
1878 .Exec(ctx, batch, out);
1879 }
1880 case RoundMode::HALF_TOWARDS_ZERO: {
1881 using Op = OpImpl<Type, RoundMode::HALF_TOWARDS_ZERO>;
1882 return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
1883 .Exec(ctx, batch, out);
1884 }
1885 case RoundMode::HALF_TOWARDS_INFINITY: {
1886 using Op = OpImpl<Type, RoundMode::HALF_TOWARDS_INFINITY>;
1887 return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
1888 .Exec(ctx, batch, out);
1889 }
1890 case RoundMode::HALF_TO_EVEN: {
1891 using Op = OpImpl<Type, RoundMode::HALF_TO_EVEN>;
1892 return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
1893 .Exec(ctx, batch, out);
1894 }
1895 case RoundMode::HALF_TO_ODD: {
1896 using Op = OpImpl<Type, RoundMode::HALF_TO_ODD>;
1897 return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
1898 .Exec(ctx, batch, out);
1899 }
1900 }
1901 DCHECK(false);
1902 return Status::NotImplemented(
1903 "Internal implementation error: round mode not implemented: ",
1904 state.options.ToString());
1905 }
1906
1907 // Like MakeUnaryArithmeticFunction, but for unary rounding functions that control
1908 // kernel dispatch based on RoundMode, only on non-null output.
1909 template <template <typename, RoundMode, typename...> class Op, typename OptionsType>
1910 std::shared_ptr<ScalarFunction> MakeUnaryRoundFunction(std::string name,
1911 const FunctionDoc* doc) {
1912 using State = RoundOptionsWrapper<OptionsType>;
1913 static const OptionsType kDefaultOptions = OptionsType::Defaults();
1914 auto func = std::make_shared<ArithmeticIntegerToFloatingPointFunction>(
1915 name, Arity::Unary(), doc, &kDefaultOptions);
1916 for (const auto& ty : {float32(), float64(), decimal128(1, 0), decimal256(1, 0)}) {
1917 auto type_id = ty->id();
1918 auto exec = [type_id](KernelContext* ctx, const ExecBatch& batch, Datum* out) {
1919 switch (type_id) {
1920 case Type::FLOAT:
1921 return ExecRound<FloatType, OptionsType, Op>(ctx, batch, out);
1922 case Type::DOUBLE:
1923 return ExecRound<DoubleType, OptionsType, Op>(ctx, batch, out);
1924 case Type::DECIMAL128:
1925 return ExecRound<Decimal128Type, OptionsType, Op>(ctx, batch, out);
1926 case Type::DECIMAL256:
1927 return ExecRound<Decimal256Type, OptionsType, Op>(ctx, batch, out);
1928 default: {
1929 DCHECK(false);
1930 return ExecFail(ctx, batch, out);
1931 }
1932 }
1933 };
1934 DCHECK_OK(func->AddKernel(
1935 {InputType(type_id)},
1936 is_decimal(type_id) ? OutputType(FirstType) : OutputType(ty), exec, State::Init));
1937 }
1938 AddNullExec(func.get());
1939 return func;
1940 }
1941
1942 // Like MakeUnaryArithmeticFunction, but for signed arithmetic ops that need to run
1943 // only on non-null output.
1944 template <typename Op>
1945 std::shared_ptr<ScalarFunction> MakeUnarySignedArithmeticFunctionNotNull(
1946 std::string name, const FunctionDoc* doc) {
1947 auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
1948 for (const auto& ty : NumericTypes()) {
1949 if (!arrow::is_unsigned_integer(ty->id())) {
1950 auto exec = ArithmeticExecFromOp<ScalarUnaryNotNull, Op>(ty);
1951 DCHECK_OK(func->AddKernel({ty}, ty, exec));
1952 }
1953 }
1954 AddNullExec(func.get());
1955 return func;
1956 }
1957
1958 template <typename Op>
1959 std::shared_ptr<ScalarFunction> MakeBitWiseFunctionNotNull(std::string name,
1960 const FunctionDoc* doc) {
1961 auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
1962 for (const auto& ty : IntTypes()) {
1963 auto exec = TypeAgnosticBitWiseExecFromOp<ScalarBinaryNotNullEqualTypes, Op>(ty);
1964 DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
1965 }
1966 AddNullExec(func.get());
1967 return func;
1968 }
1969
1970 template <typename Op>
1971 std::shared_ptr<ScalarFunction> MakeShiftFunctionNotNull(std::string name,
1972 const FunctionDoc* doc) {
1973 auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
1974 for (const auto& ty : IntTypes()) {
1975 auto exec = ShiftExecFromOp<ScalarBinaryNotNullEqualTypes, Op>(ty);
1976 DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
1977 }
1978 AddNullExec(func.get());
1979 return func;
1980 }
1981
1982 template <typename Op, typename FunctionImpl = ArithmeticFloatingPointFunction>
1983 std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPoint(
1984 std::string name, const FunctionDoc* doc) {
1985 auto func = std::make_shared<FunctionImpl>(name, Arity::Unary(), doc);
1986 for (const auto& ty : FloatingPointTypes()) {
1987 auto exec = GenerateArithmeticFloatingPoint<ScalarUnary, Op>(ty);
1988 DCHECK_OK(func->AddKernel({ty}, ty, exec));
1989 }
1990 AddNullExec(func.get());
1991 return func;
1992 }
1993
1994 template <typename Op>
1995 std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPointNotNull(
1996 std::string name, const FunctionDoc* doc) {
1997 auto func =
1998 std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Unary(), doc);
1999 for (const auto& ty : FloatingPointTypes()) {
2000 auto exec = GenerateArithmeticFloatingPoint<ScalarUnaryNotNull, Op>(ty);
2001 DCHECK_OK(func->AddKernel({ty}, ty, exec));
2002 }
2003 AddNullExec(func.get());
2004 return func;
2005 }
2006
2007 template <typename Op>
2008 std::shared_ptr<ScalarFunction> MakeArithmeticFunctionFloatingPoint(
2009 std::string name, const FunctionDoc* doc) {
2010 auto func =
2011 std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Binary(), doc);
2012 for (const auto& ty : FloatingPointTypes()) {
2013 auto exec = GenerateArithmeticFloatingPoint<ScalarBinaryEqualTypes, Op>(ty);
2014 DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
2015 }
2016 AddNullExec(func.get());
2017 return func;
2018 }
2019
2020 template <typename Op>
2021 std::shared_ptr<ScalarFunction> MakeArithmeticFunctionFloatingPointNotNull(
2022 std::string name, const FunctionDoc* doc) {
2023 auto func =
2024 std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Binary(), doc);
2025 for (const auto& ty : FloatingPointTypes()) {
2026 auto output = is_integer(ty->id()) ? float64() : ty;
2027 auto exec = GenerateArithmeticFloatingPoint<ScalarBinaryNotNullEqualTypes, Op>(ty);
2028 DCHECK_OK(func->AddKernel({ty, ty}, output, exec));
2029 }
2030 AddNullExec(func.get());
2031 return func;
2032 }
2033
2034 const FunctionDoc absolute_value_doc{
2035 "Calculate the absolute value of the argument element-wise",
2036 ("Results will wrap around on integer overflow.\n"
2037 "Use function \"abs_checked\" if you want overflow\n"
2038 "to return an error."),
2039 {"x"}};
2040
2041 const FunctionDoc absolute_value_checked_doc{
2042 "Calculate the absolute value of the argument element-wise",
2043 ("This function returns an error on overflow. For a variant that\n"
2044 "doesn't fail on overflow, use function \"abs\"."),
2045 {"x"}};
2046
2047 const FunctionDoc add_doc{"Add the arguments element-wise",
2048 ("Results will wrap around on integer overflow.\n"
2049 "Use function \"add_checked\" if you want overflow\n"
2050 "to return an error."),
2051 {"x", "y"}};
2052
2053 const FunctionDoc add_checked_doc{
2054 "Add the arguments element-wise",
2055 ("This function returns an error on overflow. For a variant that\n"
2056 "doesn't fail on overflow, use function \"add\"."),
2057 {"x", "y"}};
2058
2059 const FunctionDoc sub_doc{"Subtract the arguments element-wise",
2060 ("Results will wrap around on integer overflow.\n"
2061 "Use function \"subtract_checked\" if you want overflow\n"
2062 "to return an error."),
2063 {"x", "y"}};
2064
2065 const FunctionDoc sub_checked_doc{
2066 "Subtract the arguments element-wise",
2067 ("This function returns an error on overflow. For a variant that\n"
2068 "doesn't fail on overflow, use function \"subtract\"."),
2069 {"x", "y"}};
2070
2071 const FunctionDoc mul_doc{"Multiply the arguments element-wise",
2072 ("Results will wrap around on integer overflow.\n"
2073 "Use function \"multiply_checked\" if you want overflow\n"
2074 "to return an error."),
2075 {"x", "y"}};
2076
2077 const FunctionDoc mul_checked_doc{
2078 "Multiply the arguments element-wise",
2079 ("This function returns an error on overflow. For a variant that\n"
2080 "doesn't fail on overflow, use function \"multiply\"."),
2081 {"x", "y"}};
2082
2083 const FunctionDoc div_doc{
2084 "Divide the arguments element-wise",
2085 ("Integer division by zero returns an error. However, integer overflow\n"
2086 "wraps around, and floating-point division by zero returns an infinite.\n"
2087 "Use function \"divide_checked\" if you want to get an error\n"
2088 "in all the aforementioned cases."),
2089 {"dividend", "divisor"}};
2090
2091 const FunctionDoc div_checked_doc{
2092 "Divide the arguments element-wise",
2093 ("An error is returned when trying to divide by zero, or when\n"
2094 "integer overflow is encountered."),
2095 {"dividend", "divisor"}};
2096
2097 const FunctionDoc negate_doc{"Negate the argument element-wise",
2098 ("Results will wrap around on integer overflow.\n"
2099 "Use function \"negate_checked\" if you want overflow\n"
2100 "to return an error."),
2101 {"x"}};
2102
2103 const FunctionDoc negate_checked_doc{
2104 "Negate the arguments element-wise",
2105 ("This function returns an error on overflow. For a variant that\n"
2106 "doesn't fail on overflow, use function \"negate\"."),
2107 {"x"}};
2108
2109 const FunctionDoc pow_doc{
2110 "Raise arguments to power element-wise",
2111 ("Integer to negative integer power returns an error. However, integer overflow\n"
2112 "wraps around. If either base or exponent is null the result will be null."),
2113 {"base", "exponent"}};
2114
2115 const FunctionDoc pow_checked_doc{
2116 "Raise arguments to power element-wise",
2117 ("An error is returned when integer to negative integer power is encountered,\n"
2118 "or integer overflow is encountered."),
2119 {"base", "exponent"}};
2120
2121 const FunctionDoc sign_doc{
2122 "Get the signedness of the arguments element-wise",
2123 ("Output is any of (-1,1) for nonzero inputs and 0 for zero input.\n"
2124 "NaN values return NaN. Integral values return signedness as Int8 and\n"
2125 "floating-point values return it with the same type as the input values."),
2126 {"x"}};
2127
2128 const FunctionDoc bit_wise_not_doc{
2129 "Bit-wise negate the arguments element-wise", "Null values return null.", {"x"}};
2130
2131 const FunctionDoc bit_wise_and_doc{
2132 "Bit-wise AND the arguments element-wise", "Null values return null.", {"x", "y"}};
2133
2134 const FunctionDoc bit_wise_or_doc{
2135 "Bit-wise OR the arguments element-wise", "Null values return null.", {"x", "y"}};
2136
2137 const FunctionDoc bit_wise_xor_doc{
2138 "Bit-wise XOR the arguments element-wise", "Null values return null.", {"x", "y"}};
2139
2140 const FunctionDoc shift_left_doc{
2141 "Left shift `x` by `y`",
2142 ("This function will return `x` if `y` (the amount to shift by) is: "
2143 "(1) negative or (2) greater than or equal to the precision of `x`.\n"
2144 "The shift operates as if on the two's complement representation of the number. "
2145 "In other words, this is equivalent to multiplying `x` by 2 to the power `y`, "
2146 "even if overflow occurs.\n"
2147 "Use function \"shift_left_checked\" if you want an invalid shift amount to "
2148 "return an error."),
2149 {"x", "y"}};
2150
2151 const FunctionDoc shift_left_checked_doc{
2152 "Left shift `x` by `y` with invalid shift check",
2153 ("This function will raise an error if `y` (the amount to shift by) is: "
2154 "(1) negative or (2) greater than or equal to the precision of `x`. "
2155 "The shift operates as if on the two's complement representation of the number. "
2156 "In other words, this is equivalent to multiplying `x` by 2 to the power `y`, "
2157 "even if overflow occurs.\n"
2158 "See \"shift_left\" for a variant that doesn't fail for an invalid shift amount."),
2159 {"x", "y"}};
2160
2161 const FunctionDoc shift_right_doc{
2162 "Right shift `x` by `y`",
2163 ("Perform a logical shift for unsigned `x` and an arithmetic shift for signed `x`.\n"
2164 "This function will return `x` if `y` (the amount to shift by) is: "
2165 "(1) negative or (2) greater than or equal to the precision of `x`.\n"
2166 "Use function \"shift_right_checked\" if you want an invalid shift amount to return "
2167 "an error."),
2168 {"x", "y"}};
2169
2170 const FunctionDoc shift_right_checked_doc{
2171 "Right shift `x` by `y` with invalid shift check",
2172 ("Perform a logical shift for unsigned `x` and an arithmetic shift for signed `x`.\n"
2173 "This function will raise an error if `y` (the amount to shift by) is: "
2174 "(1) negative or (2) greater than or equal to the precision of `x`.\n"
2175 "See \"shift_right\" for a variant that doesn't fail for an invalid shift amount"),
2176 {"x", "y"}};
2177
2178 const FunctionDoc sin_doc{"Compute the sine of the elements argument-wise",
2179 ("Integer arguments return double values. "
2180 "This function returns NaN on values outside its domain. "
2181 "To raise an error instead, see \"sin_checked\"."),
2182 {"x"}};
2183
2184 const FunctionDoc sin_checked_doc{
2185 "Compute the sine of the elements argument-wise",
2186 ("Integer arguments return double values. "
2187 "This function raises an error on values outside its domain. "
2188 "To return NaN instead, see \"sin\"."),
2189 {"x"}};
2190
2191 const FunctionDoc cos_doc{"Compute the cosine of the elements argument-wise",
2192 ("Integer arguments return double values. "
2193 "This function returns NaN on values outside its domain. "
2194 "To raise an error instead, see \"cos_checked\"."),
2195 {"x"}};
2196
2197 const FunctionDoc cos_checked_doc{
2198 "Compute the cosine of the elements argument-wise",
2199 ("Integer arguments return double values. "
2200 "This function raises an error on values outside its domain. "
2201 "To return NaN instead, see \"cos\"."),
2202 {"x"}};
2203
2204 const FunctionDoc tan_doc{"Compute the tangent of the elements argument-wise",
2205 ("Integer arguments return double values. "
2206 "This function returns NaN on values outside its domain. "
2207 "To raise an error instead, see \"tan_checked\"."),
2208 {"x"}};
2209
2210 const FunctionDoc tan_checked_doc{
2211 "Compute the tangent of the elements argument-wise",
2212 ("Integer arguments return double values. "
2213 "This function raises an error on values outside its domain. "
2214 "To return NaN instead, see \"tan\"."),
2215 {"x"}};
2216
2217 const FunctionDoc asin_doc{"Compute the inverse sine of the elements argument-wise",
2218 ("Integer arguments return double values. "
2219 "This function returns NaN on values outside its domain. "
2220 "To raise an error instead, see \"asin_checked\"."),
2221 {"x"}};
2222
2223 const FunctionDoc asin_checked_doc{
2224 "Compute the inverse sine of the elements argument-wise",
2225 ("Integer arguments return double values. "
2226 "This function raises an error on values outside its domain. "
2227 "To return NaN instead, see \"asin\"."),
2228 {"x"}};
2229
2230 const FunctionDoc acos_doc{"Compute the inverse cosine of the elements argument-wise",
2231 ("Integer arguments return double values. "
2232 "This function returns NaN on values outside its domain. "
2233 "To raise an error instead, see \"acos_checked\"."),
2234 {"x"}};
2235
2236 const FunctionDoc acos_checked_doc{
2237 "Compute the inverse cosine of the elements argument-wise",
2238 ("Integer arguments return double values. "
2239 "This function raises an error on values outside its domain. "
2240 "To return NaN instead, see \"acos\"."),
2241 {"x"}};
2242
2243 const FunctionDoc atan_doc{"Compute the principal value of the inverse tangent",
2244 "Integer arguments return double values.",
2245 {"x"}};
2246
2247 const FunctionDoc atan2_doc{
2248 "Compute the inverse tangent using argument signs to determine the quadrant",
2249 "Integer arguments return double values.",
2250 {"y", "x"}};
2251
2252 const FunctionDoc ln_doc{
2253 "Compute natural log of arguments element-wise",
2254 ("Non-positive values return -inf or NaN. Null values return null.\n"
2255 "Use function \"ln_checked\" if you want non-positive values to raise an error."),
2256 {"x"}};
2257
2258 const FunctionDoc ln_checked_doc{
2259 "Compute natural log of arguments element-wise",
2260 ("Non-positive values return -inf or NaN. Null values return null.\n"
2261 "Use function \"ln\" if you want non-positive values to return "
2262 "-inf or NaN."),
2263 {"x"}};
2264
2265 const FunctionDoc log10_doc{
2266 "Compute log base 10 of arguments element-wise",
2267 ("Non-positive values return -inf or NaN. Null values return null.\n"
2268 "Use function \"log10_checked\" if you want non-positive values to raise an error."),
2269 {"x"}};
2270
2271 const FunctionDoc log10_checked_doc{
2272 "Compute log base 10 of arguments element-wise",
2273 ("Non-positive values return -inf or NaN. Null values return null.\n"
2274 "Use function \"log10\" if you want non-positive values to return "
2275 "-inf or NaN."),
2276 {"x"}};
2277
2278 const FunctionDoc log2_doc{
2279 "Compute log base 2 of arguments element-wise",
2280 ("Non-positive values return -inf or NaN. Null values return null.\n"
2281 "Use function \"log2_checked\" if you want non-positive values to raise an error."),
2282 {"x"}};
2283
2284 const FunctionDoc log2_checked_doc{
2285 "Compute log base 2 of arguments element-wise",
2286 ("Non-positive values return -inf or NaN. Null values return null.\n"
2287 "Use function \"log2\" if you want non-positive values to return "
2288 "-inf or NaN."),
2289 {"x"}};
2290
2291 const FunctionDoc log1p_doc{
2292 "Compute natural log of (1+x) element-wise",
2293 ("Values <= -1 return -inf or NaN. Null values return null.\n"
2294 "This function may be more precise than log(1 + x) for x close to zero."
2295 "Use function \"log1p_checked\" if you want non-positive values to raise an error."),
2296 {"x"}};
2297
2298 const FunctionDoc log1p_checked_doc{
2299 "Compute natural log of (1+x) element-wise",
2300 ("Values <= -1 return -inf or NaN. Null values return null.\n"
2301 "This function may be more precise than log(1 + x) for x close to zero."
2302 "Use function \"log1p\" if you want non-positive values to return "
2303 "-inf or NaN."),
2304 {"x"}};
2305
2306 const FunctionDoc logb_doc{
2307 "Compute log of x to base b of arguments element-wise",
2308 ("Values <= 0 return -inf or NaN. Null values return null.\n"
2309 "Use function \"logb_checked\" if you want non-positive values to raise an error."),
2310 {"x", "b"}};
2311
2312 const FunctionDoc logb_checked_doc{
2313 "Compute log of x to base b of arguments element-wise",
2314 ("Values <= 0 return -inf or NaN. Null values return null.\n"
2315 "Use function \"logb\" if you want non-positive values to return "
2316 "-inf or NaN."),
2317 {"x", "b"}};
2318
2319 const FunctionDoc floor_doc{
2320 "Round down to the nearest integer",
2321 ("Calculate the nearest integer less than or equal in magnitude to the "
2322 "argument element-wise"),
2323 {"x"}};
2324
2325 const FunctionDoc ceil_doc{
2326 "Round up to the nearest integer",
2327 ("Calculate the nearest integer greater than or equal in magnitude to the "
2328 "argument element-wise"),
2329 {"x"}};
2330
2331 const FunctionDoc trunc_doc{
2332 "Get the integral part without fractional digits",
2333 ("Calculate the nearest integer not greater in magnitude than to the "
2334 "argument element-wise."),
2335 {"x"}};
2336
2337 const FunctionDoc round_doc{
2338 "Round to a given precision",
2339 ("Options are used to control the number of digits and rounding mode.\n"
2340 "Default behavior is to round to the nearest integer and use half-to-even "
2341 "rule to break ties."),
2342 {"x"},
2343 "RoundOptions"};
2344
2345 const FunctionDoc round_to_multiple_doc{
2346 "Round to a given multiple",
2347 ("Options are used to control the rounding multiple and rounding mode.\n"
2348 "Default behavior is to round to the nearest integer and use half-to-even "
2349 "rule to break ties."),
2350 {"x"},
2351 "RoundToMultipleOptions"};
2352 } // namespace
2353
2354 void RegisterScalarArithmetic(FunctionRegistry* registry) {
2355 // ----------------------------------------------------------------------
2356 auto absolute_value =
2357 MakeUnaryArithmeticFunction<AbsoluteValue>("abs", &absolute_value_doc);
2358 DCHECK_OK(registry->AddFunction(std::move(absolute_value)));
2359
2360 // ----------------------------------------------------------------------
2361 auto absolute_value_checked = MakeUnaryArithmeticFunctionNotNull<AbsoluteValueChecked>(
2362 "abs_checked", &absolute_value_checked_doc);
2363 DCHECK_OK(registry->AddFunction(std::move(absolute_value_checked)));
2364
2365 // ----------------------------------------------------------------------
2366 auto add = MakeArithmeticFunction<Add>("add", &add_doc);
2367 AddDecimalBinaryKernels<Add>("add", &add);
2368 DCHECK_OK(registry->AddFunction(std::move(add)));
2369
2370 // ----------------------------------------------------------------------
2371 auto add_checked =
2372 MakeArithmeticFunctionNotNull<AddChecked>("add_checked", &add_checked_doc);
2373 AddDecimalBinaryKernels<AddChecked>("add_checked", &add_checked);
2374 DCHECK_OK(registry->AddFunction(std::move(add_checked)));
2375
2376 // ----------------------------------------------------------------------
2377 auto subtract = MakeArithmeticFunction<Subtract>("subtract", &sub_doc);
2378 AddDecimalBinaryKernels<Subtract>("subtract", &subtract);
2379
2380 // Add subtract(timestamp, timestamp) -> duration
2381 for (auto unit : TimeUnit::values()) {
2382 InputType in_type(match::TimestampTypeUnit(unit));
2383 auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Subtract>(Type::TIMESTAMP);
2384 DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
2385 }
2386
2387 DCHECK_OK(registry->AddFunction(std::move(subtract)));
2388
2389 // ----------------------------------------------------------------------
2390 auto subtract_checked = MakeArithmeticFunctionNotNull<SubtractChecked>(
2391 "subtract_checked", &sub_checked_doc);
2392 AddDecimalBinaryKernels<SubtractChecked>("subtract_checked", &subtract_checked);
2393 DCHECK_OK(registry->AddFunction(std::move(subtract_checked)));
2394
2395 // ----------------------------------------------------------------------
2396 auto multiply = MakeArithmeticFunction<Multiply>("multiply", &mul_doc);
2397 AddDecimalBinaryKernels<Multiply>("multiply", &multiply);
2398 DCHECK_OK(registry->AddFunction(std::move(multiply)));
2399
2400 // ----------------------------------------------------------------------
2401 auto multiply_checked = MakeArithmeticFunctionNotNull<MultiplyChecked>(
2402 "multiply_checked", &mul_checked_doc);
2403 AddDecimalBinaryKernels<MultiplyChecked>("multiply_checked", &multiply_checked);
2404 DCHECK_OK(registry->AddFunction(std::move(multiply_checked)));
2405
2406 // ----------------------------------------------------------------------
2407 auto divide = MakeArithmeticFunctionNotNull<Divide>("divide", &div_doc);
2408 AddDecimalBinaryKernels<Divide>("divide", &divide);
2409 DCHECK_OK(registry->AddFunction(std::move(divide)));
2410
2411 // ----------------------------------------------------------------------
2412 auto divide_checked =
2413 MakeArithmeticFunctionNotNull<DivideChecked>("divide_checked", &div_checked_doc);
2414 AddDecimalBinaryKernels<DivideChecked>("divide_checked", &divide_checked);
2415 DCHECK_OK(registry->AddFunction(std::move(divide_checked)));
2416
2417 // ----------------------------------------------------------------------
2418 auto negate = MakeUnaryArithmeticFunction<Negate>("negate", &negate_doc);
2419 DCHECK_OK(registry->AddFunction(std::move(negate)));
2420
2421 // ----------------------------------------------------------------------
2422 auto negate_checked = MakeUnarySignedArithmeticFunctionNotNull<NegateChecked>(
2423 "negate_checked", &negate_checked_doc);
2424 DCHECK_OK(registry->AddFunction(std::move(negate_checked)));
2425
2426 // ----------------------------------------------------------------------
2427 auto power = MakeArithmeticFunction<Power>("power", &pow_doc);
2428 DCHECK_OK(registry->AddFunction(std::move(power)));
2429
2430 // ----------------------------------------------------------------------
2431 auto power_checked =
2432 MakeArithmeticFunctionNotNull<PowerChecked>("power_checked", &pow_checked_doc);
2433 DCHECK_OK(registry->AddFunction(std::move(power_checked)));
2434
2435 // ----------------------------------------------------------------------
2436 auto sign =
2437 MakeUnaryArithmeticFunctionWithFixedIntOutType<Sign, Int8Type>("sign", &sign_doc);
2438 DCHECK_OK(registry->AddFunction(std::move(sign)));
2439
2440 // ----------------------------------------------------------------------
2441 // Bitwise functions
2442 {
2443 auto bit_wise_not = std::make_shared<ArithmeticFunction>(
2444 "bit_wise_not", Arity::Unary(), &bit_wise_not_doc);
2445 for (const auto& ty : IntTypes()) {
2446 auto exec = TypeAgnosticBitWiseExecFromOp<ScalarUnaryNotNull, BitWiseNot>(ty);
2447 DCHECK_OK(bit_wise_not->AddKernel({ty}, ty, exec));
2448 }
2449 AddNullExec(bit_wise_not.get());
2450 DCHECK_OK(registry->AddFunction(std::move(bit_wise_not)));
2451 }
2452
2453 auto bit_wise_and =
2454 MakeBitWiseFunctionNotNull<BitWiseAnd>("bit_wise_and", &bit_wise_and_doc);
2455 DCHECK_OK(registry->AddFunction(std::move(bit_wise_and)));
2456
2457 auto bit_wise_or =
2458 MakeBitWiseFunctionNotNull<BitWiseOr>("bit_wise_or", &bit_wise_or_doc);
2459 DCHECK_OK(registry->AddFunction(std::move(bit_wise_or)));
2460
2461 auto bit_wise_xor =
2462 MakeBitWiseFunctionNotNull<BitWiseXor>("bit_wise_xor", &bit_wise_xor_doc);
2463 DCHECK_OK(registry->AddFunction(std::move(bit_wise_xor)));
2464
2465 auto shift_left = MakeShiftFunctionNotNull<ShiftLeft>("shift_left", &shift_left_doc);
2466 DCHECK_OK(registry->AddFunction(std::move(shift_left)));
2467
2468 auto shift_left_checked = MakeShiftFunctionNotNull<ShiftLeftChecked>(
2469 "shift_left_checked", &shift_left_checked_doc);
2470 DCHECK_OK(registry->AddFunction(std::move(shift_left_checked)));
2471
2472 auto shift_right =
2473 MakeShiftFunctionNotNull<ShiftRight>("shift_right", &shift_right_doc);
2474 DCHECK_OK(registry->AddFunction(std::move(shift_right)));
2475
2476 auto shift_right_checked = MakeShiftFunctionNotNull<ShiftRightChecked>(
2477 "shift_right_checked", &shift_right_checked_doc);
2478 DCHECK_OK(registry->AddFunction(std::move(shift_right_checked)));
2479
2480 // ----------------------------------------------------------------------
2481 // Trig functions
2482 auto sin = MakeUnaryArithmeticFunctionFloatingPoint<Sin>("sin", &sin_doc);
2483 DCHECK_OK(registry->AddFunction(std::move(sin)));
2484
2485 auto sin_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<SinChecked>(
2486 "sin_checked", &sin_checked_doc);
2487 DCHECK_OK(registry->AddFunction(std::move(sin_checked)));
2488
2489 auto cos = MakeUnaryArithmeticFunctionFloatingPoint<Cos>("cos", &cos_doc);
2490 DCHECK_OK(registry->AddFunction(std::move(cos)));
2491
2492 auto cos_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<CosChecked>(
2493 "cos_checked", &cos_checked_doc);
2494 DCHECK_OK(registry->AddFunction(std::move(cos_checked)));
2495
2496 auto tan = MakeUnaryArithmeticFunctionFloatingPoint<Tan>("tan", &tan_doc);
2497 DCHECK_OK(registry->AddFunction(std::move(tan)));
2498
2499 auto tan_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<TanChecked>(
2500 "tan_checked", &tan_checked_doc);
2501 DCHECK_OK(registry->AddFunction(std::move(tan_checked)));
2502
2503 auto asin = MakeUnaryArithmeticFunctionFloatingPoint<Asin>("asin", &asin_doc);
2504 DCHECK_OK(registry->AddFunction(std::move(asin)));
2505
2506 auto asin_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<AsinChecked>(
2507 "asin_checked", &asin_checked_doc);
2508 DCHECK_OK(registry->AddFunction(std::move(asin_checked)));
2509
2510 auto acos = MakeUnaryArithmeticFunctionFloatingPoint<Acos>("acos", &acos_doc);
2511 DCHECK_OK(registry->AddFunction(std::move(acos)));
2512
2513 auto acos_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<AcosChecked>(
2514 "acos_checked", &acos_checked_doc);
2515 DCHECK_OK(registry->AddFunction(std::move(acos_checked)));
2516
2517 auto atan = MakeUnaryArithmeticFunctionFloatingPoint<Atan>("atan", &atan_doc);
2518 DCHECK_OK(registry->AddFunction(std::move(atan)));
2519
2520 auto atan2 = MakeArithmeticFunctionFloatingPoint<Atan2>("atan2", &atan2_doc);
2521 DCHECK_OK(registry->AddFunction(std::move(atan2)));
2522
2523 // ----------------------------------------------------------------------
2524 // Logarithms
2525 auto ln = MakeUnaryArithmeticFunctionFloatingPoint<LogNatural>("ln", &ln_doc);
2526 DCHECK_OK(registry->AddFunction(std::move(ln)));
2527
2528 auto ln_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<LogNaturalChecked>(
2529 "ln_checked", &ln_checked_doc);
2530 DCHECK_OK(registry->AddFunction(std::move(ln_checked)));
2531
2532 auto log10 = MakeUnaryArithmeticFunctionFloatingPoint<Log10>("log10", &log10_doc);
2533 DCHECK_OK(registry->AddFunction(std::move(log10)));
2534
2535 auto log10_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<Log10Checked>(
2536 "log10_checked", &log10_checked_doc);
2537 DCHECK_OK(registry->AddFunction(std::move(log10_checked)));
2538
2539 auto log2 = MakeUnaryArithmeticFunctionFloatingPoint<Log2>("log2", &log2_doc);
2540 DCHECK_OK(registry->AddFunction(std::move(log2)));
2541
2542 auto log2_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<Log2Checked>(
2543 "log2_checked", &log2_checked_doc);
2544 DCHECK_OK(registry->AddFunction(std::move(log2_checked)));
2545
2546 auto log1p = MakeUnaryArithmeticFunctionFloatingPoint<Log1p>("log1p", &log1p_doc);
2547 DCHECK_OK(registry->AddFunction(std::move(log1p)));
2548
2549 auto log1p_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<Log1pChecked>(
2550 "log1p_checked", &log1p_checked_doc);
2551 DCHECK_OK(registry->AddFunction(std::move(log1p_checked)));
2552
2553 auto logb = MakeArithmeticFunctionFloatingPoint<Logb>("logb", &logb_doc);
2554 DCHECK_OK(registry->AddFunction(std::move(logb)));
2555
2556 auto logb_checked = MakeArithmeticFunctionFloatingPointNotNull<LogbChecked>(
2557 "logb_checked", &logb_checked_doc);
2558 DCHECK_OK(registry->AddFunction(std::move(logb_checked)));
2559
2560 // ----------------------------------------------------------------------
2561 // Rounding functions
2562 auto floor =
2563 MakeUnaryArithmeticFunctionFloatingPoint<Floor,
2564 ArithmeticIntegerToFloatingPointFunction>(
2565 "floor", &floor_doc);
2566 DCHECK_OK(floor->AddKernel(
2567 {InputType(Type::DECIMAL128)}, OutputType(FirstType),
2568 FixedRoundDecimalExec<Decimal128Type, RoundMode::DOWN, /*ndigits=*/0>));
2569 DCHECK_OK(floor->AddKernel(
2570 {InputType(Type::DECIMAL256)}, OutputType(FirstType),
2571 FixedRoundDecimalExec<Decimal256Type, RoundMode::DOWN, /*ndigits=*/0>));
2572 DCHECK_OK(registry->AddFunction(std::move(floor)));
2573
2574 auto ceil =
2575 MakeUnaryArithmeticFunctionFloatingPoint<Ceil,
2576 ArithmeticIntegerToFloatingPointFunction>(
2577 "ceil", &ceil_doc);
2578 DCHECK_OK(ceil->AddKernel(
2579 {InputType(Type::DECIMAL128)}, OutputType(FirstType),
2580 FixedRoundDecimalExec<Decimal128Type, RoundMode::UP, /*ndigits=*/0>));
2581 DCHECK_OK(ceil->AddKernel(
2582 {InputType(Type::DECIMAL256)}, OutputType(FirstType),
2583 FixedRoundDecimalExec<Decimal256Type, RoundMode::UP, /*ndigits=*/0>));
2584 DCHECK_OK(registry->AddFunction(std::move(ceil)));
2585
2586 auto trunc =
2587 MakeUnaryArithmeticFunctionFloatingPoint<Trunc,
2588 ArithmeticIntegerToFloatingPointFunction>(
2589 "trunc", &trunc_doc);
2590 DCHECK_OK(trunc->AddKernel(
2591 {InputType(Type::DECIMAL128)}, OutputType(FirstType),
2592 FixedRoundDecimalExec<Decimal128Type, RoundMode::TOWARDS_ZERO, /*ndigits=*/0>));
2593 DCHECK_OK(trunc->AddKernel(
2594 {InputType(Type::DECIMAL256)}, OutputType(FirstType),
2595 FixedRoundDecimalExec<Decimal256Type, RoundMode::TOWARDS_ZERO, /*ndigits=*/0>));
2596 DCHECK_OK(registry->AddFunction(std::move(trunc)));
2597
2598 auto round = MakeUnaryRoundFunction<Round, RoundOptions>("round", &round_doc);
2599 DCHECK_OK(registry->AddFunction(std::move(round)));
2600
2601 auto round_to_multiple =
2602 MakeUnaryRoundFunction<RoundToMultiple, RoundToMultipleOptions>(
2603 "round_to_multiple", &round_to_multiple_doc);
2604 DCHECK_OK(registry->AddFunction(std::move(round_to_multiple)));
2605 }
2606
2607 } // namespace internal
2608 } // namespace compute
2609 } // namespace arrow