]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/r/src/compute.cpp
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / r / src / compute.cpp
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "./arrow_types.h"
19
20 #if defined(ARROW_R_WITH_ARROW)
21
22 #include <arrow/compute/api.h>
23 #include <arrow/record_batch.h>
24 #include <arrow/table.h>
25
26 std::shared_ptr<arrow::compute::CastOptions> make_cast_options(cpp11::list options);
27
28 arrow::compute::ExecContext* gc_context() {
29 static arrow::compute::ExecContext context(gc_memory_pool());
30 return &context;
31 }
32
33 // [[arrow::export]]
34 std::shared_ptr<arrow::RecordBatch> RecordBatch__cast(
35 const std::shared_ptr<arrow::RecordBatch>& batch,
36 const std::shared_ptr<arrow::Schema>& schema, cpp11::list options) {
37 auto opts = make_cast_options(options);
38 auto nc = batch->num_columns();
39
40 arrow::ArrayVector columns(nc);
41 for (int i = 0; i < nc; i++) {
42 columns[i] = ValueOrStop(
43 arrow::compute::Cast(*batch->column(i), schema->field(i)->type(), *opts));
44 }
45
46 return arrow::RecordBatch::Make(schema, batch->num_rows(), std::move(columns));
47 }
48
49 // [[arrow::export]]
50 std::shared_ptr<arrow::Table> Table__cast(const std::shared_ptr<arrow::Table>& table,
51 const std::shared_ptr<arrow::Schema>& schema,
52 cpp11::list options) {
53 auto opts = make_cast_options(options);
54 auto nc = table->num_columns();
55
56 using ColumnVector = std::vector<std::shared_ptr<arrow::ChunkedArray>>;
57 ColumnVector columns(nc);
58 for (int i = 0; i < nc; i++) {
59 arrow::Datum value(table->column(i));
60 arrow::Datum out =
61 ValueOrStop(arrow::compute::Cast(value, schema->field(i)->type(), *opts));
62 columns[i] = out.chunked_array();
63 }
64 return arrow::Table::Make(schema, std::move(columns), table->num_rows());
65 }
66
67 template <typename T>
68 std::shared_ptr<T> MaybeUnbox(const char* class_name, SEXP x) {
69 if (Rf_inherits(x, "ArrowObject") && Rf_inherits(x, class_name)) {
70 return cpp11::as_cpp<std::shared_ptr<T>>(x);
71 }
72 return nullptr;
73 }
74
75 namespace cpp11 {
76
77 template <>
78 arrow::Datum as_cpp<arrow::Datum>(SEXP x) {
79 if (auto array = MaybeUnbox<arrow::Array>("Array", x)) {
80 return array;
81 }
82
83 if (auto chunked_array = MaybeUnbox<arrow::ChunkedArray>("ChunkedArray", x)) {
84 return chunked_array;
85 }
86
87 if (auto batch = MaybeUnbox<arrow::RecordBatch>("RecordBatch", x)) {
88 return batch;
89 }
90
91 if (auto table = MaybeUnbox<arrow::Table>("Table", x)) {
92 return table;
93 }
94
95 if (auto scalar = MaybeUnbox<arrow::Scalar>("Scalar", x)) {
96 return scalar;
97 }
98
99 // This assumes that R objects have already been converted to Arrow objects;
100 // that seems right but should we do the wrapping here too/instead?
101 cpp11::stop("to_datum: Not implemented for type %s", Rf_type2char(TYPEOF(x)));
102 }
103 } // namespace cpp11
104
105 SEXP from_datum(arrow::Datum datum) {
106 switch (datum.kind()) {
107 case arrow::Datum::SCALAR:
108 return cpp11::to_r6(datum.scalar());
109
110 case arrow::Datum::ARRAY:
111 return cpp11::to_r6(datum.make_array());
112
113 case arrow::Datum::CHUNKED_ARRAY:
114 return cpp11::to_r6(datum.chunked_array());
115
116 case arrow::Datum::RECORD_BATCH:
117 return cpp11::to_r6(datum.record_batch());
118
119 case arrow::Datum::TABLE:
120 return cpp11::to_r6(datum.table());
121
122 default:
123 break;
124 }
125
126 cpp11::stop("from_datum: Not implemented for Datum %s", datum.ToString().c_str());
127 }
128
129 std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
130 std::string func_name, cpp11::list options) {
131 if (func_name == "filter") {
132 using Options = arrow::compute::FilterOptions;
133 auto out = std::make_shared<Options>(Options::Defaults());
134 SEXP keep_na = options["keep_na"];
135 if (!Rf_isNull(keep_na) && cpp11::as_cpp<bool>(keep_na)) {
136 out->null_selection_behavior = Options::EMIT_NULL;
137 }
138 return out;
139 }
140
141 if (func_name == "take") {
142 using Options = arrow::compute::TakeOptions;
143 auto out = std::make_shared<Options>(Options::Defaults());
144 return out;
145 }
146
147 if (func_name == "array_sort_indices") {
148 using Order = arrow::compute::SortOrder;
149 using Options = arrow::compute::ArraySortOptions;
150 // false means descending, true means ascending
151 auto order = cpp11::as_cpp<bool>(options["order"]);
152 auto out =
153 std::make_shared<Options>(Options(order ? Order::Descending : Order::Ascending));
154 return out;
155 }
156
157 if (func_name == "sort_indices") {
158 using Key = arrow::compute::SortKey;
159 using Order = arrow::compute::SortOrder;
160 using Options = arrow::compute::SortOptions;
161 auto names = cpp11::as_cpp<std::vector<std::string>>(options["names"]);
162 // false means descending, true means ascending
163 // cpp11 does not support bool here so use int
164 auto orders = cpp11::as_cpp<std::vector<int>>(options["orders"]);
165 std::vector<Key> keys;
166 for (size_t i = 0; i < names.size(); i++) {
167 keys.push_back(
168 Key(names[i], (orders[i] > 0) ? Order::Descending : Order::Ascending));
169 }
170 auto out = std::make_shared<Options>(Options(keys));
171 return out;
172 }
173
174 if (func_name == "all" || func_name == "hash_all" || func_name == "any" ||
175 func_name == "hash_any" || func_name == "approximate_median" ||
176 func_name == "hash_approximate_median" || func_name == "mean" ||
177 func_name == "hash_mean" || func_name == "min_max" || func_name == "hash_min_max" ||
178 func_name == "min" || func_name == "hash_min" || func_name == "max" ||
179 func_name == "hash_max" || func_name == "sum" || func_name == "hash_sum") {
180 using Options = arrow::compute::ScalarAggregateOptions;
181 auto out = std::make_shared<Options>(Options::Defaults());
182 if (!Rf_isNull(options["min_count"])) {
183 out->min_count = cpp11::as_cpp<int>(options["min_count"]);
184 }
185 if (!Rf_isNull(options["skip_nulls"])) {
186 out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]);
187 }
188 return out;
189 }
190
191 if (func_name == "tdigest" || func_name == "hash_tdigest") {
192 using Options = arrow::compute::TDigestOptions;
193 auto out = std::make_shared<Options>(Options::Defaults());
194 if (!Rf_isNull(options["q"])) {
195 out->q = cpp11::as_cpp<std::vector<double>>(options["q"]);
196 }
197 if (!Rf_isNull(options["skip_nulls"])) {
198 out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]);
199 }
200 return out;
201 }
202
203 if (func_name == "count") {
204 using Options = arrow::compute::CountOptions;
205 auto out = std::make_shared<Options>(Options::Defaults());
206 out->mode =
207 cpp11::as_cpp<bool>(options["na.rm"]) ? Options::ONLY_VALID : Options::ONLY_NULL;
208 return out;
209 }
210
211 if (func_name == "count_distinct" || func_name == "hash_count_distinct") {
212 using Options = arrow::compute::CountOptions;
213 auto out = std::make_shared<Options>(Options::Defaults());
214 out->mode =
215 cpp11::as_cpp<bool>(options["na.rm"]) ? Options::ONLY_VALID : Options::ALL;
216 return out;
217 }
218
219 if (func_name == "min_element_wise" || func_name == "max_element_wise") {
220 using Options = arrow::compute::ElementWiseAggregateOptions;
221 bool skip_nulls = true;
222 if (!Rf_isNull(options["skip_nulls"])) {
223 skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]);
224 }
225 return std::make_shared<Options>(skip_nulls);
226 }
227
228 if (func_name == "quantile") {
229 using Options = arrow::compute::QuantileOptions;
230 auto out = std::make_shared<Options>(Options::Defaults());
231 SEXP q = options["q"];
232 if (!Rf_isNull(q) && TYPEOF(q) == REALSXP) {
233 out->q = cpp11::as_cpp<std::vector<double>>(q);
234 }
235 SEXP interpolation = options["interpolation"];
236 if (!Rf_isNull(interpolation) && TYPEOF(interpolation) == INTSXP &&
237 XLENGTH(interpolation) == 1) {
238 out->interpolation =
239 cpp11::as_cpp<enum arrow::compute::QuantileOptions::Interpolation>(
240 interpolation);
241 }
242 if (!Rf_isNull(options["min_count"])) {
243 out->min_count = cpp11::as_cpp<int64_t>(options["min_count"]);
244 }
245 if (!Rf_isNull(options["skip_nulls"])) {
246 out->skip_nulls = cpp11::as_cpp<int64_t>(options["skip_nulls"]);
247 }
248 return out;
249 }
250
251 if (func_name == "is_in" || func_name == "index_in") {
252 using Options = arrow::compute::SetLookupOptions;
253 return std::make_shared<Options>(cpp11::as_cpp<arrow::Datum>(options["value_set"]),
254 cpp11::as_cpp<bool>(options["skip_nulls"]));
255 }
256
257 if (func_name == "index") {
258 using Options = arrow::compute::IndexOptions;
259 return std::make_shared<Options>(
260 cpp11::as_cpp<std::shared_ptr<arrow::Scalar>>(options["value"]));
261 }
262
263 if (func_name == "is_null") {
264 using Options = arrow::compute::NullOptions;
265 auto out = std::make_shared<Options>(Options::Defaults());
266 if (!Rf_isNull(options["nan_is_null"])) {
267 out->nan_is_null = cpp11::as_cpp<bool>(options["nan_is_null"]);
268 }
269 return out;
270 }
271
272 if (func_name == "dictionary_encode") {
273 using Options = arrow::compute::DictionaryEncodeOptions;
274 auto out = std::make_shared<Options>(Options::Defaults());
275 if (!Rf_isNull(options["null_encoding_behavior"])) {
276 out->null_encoding_behavior = cpp11::as_cpp<
277 enum arrow::compute::DictionaryEncodeOptions::NullEncodingBehavior>(
278 options["null_encoding_behavior"]);
279 }
280 return out;
281 }
282
283 if (func_name == "cast") {
284 return make_cast_options(options);
285 }
286
287 if (func_name == "binary_join_element_wise") {
288 using Options = arrow::compute::JoinOptions;
289 auto out = std::make_shared<Options>(Options::Defaults());
290 if (!Rf_isNull(options["null_handling"])) {
291 out->null_handling =
292 cpp11::as_cpp<enum arrow::compute::JoinOptions::NullHandlingBehavior>(
293 options["null_handling"]);
294 }
295 if (!Rf_isNull(options["null_replacement"])) {
296 out->null_replacement = cpp11::as_cpp<std::string>(options["null_replacement"]);
297 }
298 return out;
299 }
300
301 if (func_name == "make_struct") {
302 using Options = arrow::compute::MakeStructOptions;
303 // TODO (ARROW-13371): accept `field_nullability` and `field_metadata` options
304 return std::make_shared<Options>(
305 cpp11::as_cpp<std::vector<std::string>>(options["field_names"]));
306 }
307
308 if (func_name == "match_substring" || func_name == "match_substring_regex" ||
309 func_name == "find_substring" || func_name == "find_substring_regex" ||
310 func_name == "match_like" || func_name == "starts_with" ||
311 func_name == "ends_with" || func_name == "count_substring" ||
312 func_name == "count_substring_regex") {
313 using Options = arrow::compute::MatchSubstringOptions;
314 bool ignore_case = false;
315 if (!Rf_isNull(options["ignore_case"])) {
316 ignore_case = cpp11::as_cpp<bool>(options["ignore_case"]);
317 }
318 return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]),
319 ignore_case);
320 }
321
322 if (func_name == "replace_substring" || func_name == "replace_substring_regex") {
323 using Options = arrow::compute::ReplaceSubstringOptions;
324 int64_t max_replacements = -1;
325 if (!Rf_isNull(options["max_replacements"])) {
326 max_replacements = cpp11::as_cpp<int64_t>(options["max_replacements"]);
327 }
328 return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]),
329 cpp11::as_cpp<std::string>(options["replacement"]),
330 max_replacements);
331 }
332
333 if (func_name == "extract_regex") {
334 using Options = arrow::compute::ExtractRegexOptions;
335 return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]));
336 }
337
338 if (func_name == "day_of_week") {
339 using Options = arrow::compute::DayOfWeekOptions;
340 bool count_from_zero = false;
341 if (!Rf_isNull(options["count_from_zero"])) {
342 count_from_zero = cpp11::as_cpp<bool>(options["count_from_zero"]);
343 }
344 return std::make_shared<Options>(count_from_zero,
345 cpp11::as_cpp<uint32_t>(options["week_start"]));
346 }
347
348 if (func_name == "iso_week") {
349 return std::make_shared<arrow::compute::WeekOptions>(
350 arrow::compute::WeekOptions::ISODefaults());
351 }
352
353 if (func_name == "us_week") {
354 return std::make_shared<arrow::compute::WeekOptions>(
355 arrow::compute::WeekOptions::USDefaults());
356 }
357
358 if (func_name == "week") {
359 using Options = arrow::compute::WeekOptions;
360 bool week_starts_monday = true;
361 bool count_from_zero = false;
362 bool first_week_is_fully_in_year = false;
363 if (!Rf_isNull(options["week_starts_monday"])) {
364 week_starts_monday = cpp11::as_cpp<bool>(options["week_starts_monday"]);
365 }
366 if (!Rf_isNull(options["count_from_zero"])) {
367 count_from_zero = cpp11::as_cpp<bool>(options["count_from_zero"]);
368 }
369 if (!Rf_isNull(options["first_week_is_fully_in_year"])) {
370 count_from_zero = cpp11::as_cpp<bool>(options["first_week_is_fully_in_year"]);
371 }
372 return std::make_shared<Options>(week_starts_monday, count_from_zero,
373 first_week_is_fully_in_year);
374 }
375
376 if (func_name == "strptime") {
377 using Options = arrow::compute::StrptimeOptions;
378 return std::make_shared<Options>(
379 cpp11::as_cpp<std::string>(options["format"]),
380 cpp11::as_cpp<arrow::TimeUnit::type>(options["unit"]));
381 }
382
383 if (func_name == "strftime") {
384 using Options = arrow::compute::StrftimeOptions;
385 return std::make_shared<Options>(
386 Options(cpp11::as_cpp<std::string>(options["format"]),
387 cpp11::as_cpp<std::string>(options["locale"])));
388 }
389
390 if (func_name == "assume_timezone") {
391 using Options = arrow::compute::AssumeTimezoneOptions;
392 enum Options::Ambiguous ambiguous;
393 enum Options::Nonexistent nonexistent;
394
395 if (!Rf_isNull(options["ambiguous"])) {
396 ambiguous = cpp11::as_cpp<enum Options::Ambiguous>(options["ambiguous"]);
397 }
398 if (!Rf_isNull(options["nonexistent"])) {
399 nonexistent = cpp11::as_cpp<enum Options::Nonexistent>(options["nonexistent"]);
400 }
401
402 return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["timezone"]),
403 ambiguous, nonexistent);
404 }
405
406 if (func_name == "split_pattern" || func_name == "split_pattern_regex") {
407 using Options = arrow::compute::SplitPatternOptions;
408 int64_t max_splits = -1;
409 if (!Rf_isNull(options["max_splits"])) {
410 max_splits = cpp11::as_cpp<int64_t>(options["max_splits"]);
411 }
412 bool reverse = false;
413 if (!Rf_isNull(options["reverse"])) {
414 reverse = cpp11::as_cpp<bool>(options["reverse"]);
415 }
416 return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]),
417 max_splits, reverse);
418 }
419
420 if (func_name == "utf8_lpad" || func_name == "utf8_rpad" ||
421 func_name == "utf8_center" || func_name == "ascii_lpad" ||
422 func_name == "ascii_rpad" || func_name == "ascii_center") {
423 using Options = arrow::compute::PadOptions;
424 return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["width"]),
425 cpp11::as_cpp<std::string>(options["padding"]));
426 }
427
428 if (func_name == "utf8_split_whitespace" || func_name == "ascii_split_whitespace") {
429 using Options = arrow::compute::SplitOptions;
430 int64_t max_splits = -1;
431 if (!Rf_isNull(options["max_splits"])) {
432 max_splits = cpp11::as_cpp<int64_t>(options["max_splits"]);
433 }
434 bool reverse = false;
435 if (!Rf_isNull(options["reverse"])) {
436 reverse = cpp11::as_cpp<bool>(options["reverse"]);
437 }
438 return std::make_shared<Options>(max_splits, reverse);
439 }
440
441 if (func_name == "utf8_trim" || func_name == "utf8_ltrim" ||
442 func_name == "utf8_rtrim" || func_name == "ascii_trim" ||
443 func_name == "ascii_ltrim" || func_name == "ascii_rtrim") {
444 using Options = arrow::compute::TrimOptions;
445 return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["characters"]));
446 }
447
448 if (func_name == "utf8_slice_codeunits") {
449 using Options = arrow::compute::SliceOptions;
450
451 int64_t step = 1;
452 if (!Rf_isNull(options["step"])) {
453 step = cpp11::as_cpp<int64_t>(options["step"]);
454 }
455
456 int64_t stop = std::numeric_limits<int32_t>::max();
457 if (!Rf_isNull(options["stop"])) {
458 stop = cpp11::as_cpp<int64_t>(options["stop"]);
459 }
460
461 return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["start"]), stop,
462 step);
463 }
464
465 if (func_name == "utf8_replace_slice" || func_name == "binary_replace_slice") {
466 using Options = arrow::compute::ReplaceSliceOptions;
467
468 return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["start"]),
469 cpp11::as_cpp<int64_t>(options["stop"]),
470 cpp11::as_cpp<std::string>(options["replacement"]));
471 }
472
473 if (func_name == "variance" || func_name == "stddev" || func_name == "hash_variance" ||
474 func_name == "hash_stddev") {
475 using Options = arrow::compute::VarianceOptions;
476 auto out = std::make_shared<Options>();
477 out->ddof = cpp11::as_cpp<int64_t>(options["ddof"]);
478 if (!Rf_isNull(options["min_count"])) {
479 out->min_count = cpp11::as_cpp<int64_t>(options["min_count"]);
480 }
481 if (!Rf_isNull(options["skip_nulls"])) {
482 out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]);
483 }
484 return out;
485 }
486
487 if (func_name == "mode") {
488 using Options = arrow::compute::ModeOptions;
489 auto out = std::make_shared<Options>(Options::Defaults());
490 if (!Rf_isNull(options["n"])) {
491 out->n = cpp11::as_cpp<int64_t>(options["n"]);
492 }
493 if (!Rf_isNull(options["min_count"])) {
494 out->min_count = cpp11::as_cpp<uint32_t>(options["min_count"]);
495 }
496 if (!Rf_isNull(options["skip_nulls"])) {
497 out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]);
498 }
499 return out;
500 }
501
502 if (func_name == "partition_nth_indices") {
503 using Options = arrow::compute::PartitionNthOptions;
504 return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["pivot"]));
505 }
506
507 if (func_name == "round") {
508 using Options = arrow::compute::RoundOptions;
509 auto out = std::make_shared<Options>(Options::Defaults());
510 if (!Rf_isNull(options["ndigits"])) {
511 out->ndigits = cpp11::as_cpp<int64_t>(options["ndigits"]);
512 }
513 SEXP round_mode = options["round_mode"];
514 if (!Rf_isNull(round_mode)) {
515 out->round_mode = cpp11::as_cpp<enum arrow::compute::RoundMode>(round_mode);
516 }
517 return out;
518 }
519
520 if (func_name == "round_to_multiple") {
521 using Options = arrow::compute::RoundToMultipleOptions;
522 auto out = std::make_shared<Options>(Options::Defaults());
523 if (!Rf_isNull(options["multiple"])) {
524 out->multiple = std::make_shared<arrow::DoubleScalar>(
525 cpp11::as_cpp<double>(options["multiple"]));
526 }
527 SEXP round_mode = options["round_mode"];
528 if (!Rf_isNull(round_mode)) {
529 out->round_mode = cpp11::as_cpp<enum arrow::compute::RoundMode>(round_mode);
530 }
531 return out;
532 }
533
534 return nullptr;
535 }
536
537 std::shared_ptr<arrow::compute::CastOptions> make_cast_options(cpp11::list options) {
538 using Options = arrow::compute::CastOptions;
539 auto out = std::make_shared<Options>(true);
540 SEXP to_type = options["to_type"];
541 if (!Rf_isNull(to_type) && cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(to_type)) {
542 out->to_type = cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(to_type);
543 }
544
545 SEXP allow_float_truncate = options["allow_float_truncate"];
546 if (!Rf_isNull(allow_float_truncate) && cpp11::as_cpp<bool>(allow_float_truncate)) {
547 out->allow_float_truncate = cpp11::as_cpp<bool>(allow_float_truncate);
548 }
549
550 SEXP allow_time_truncate = options["allow_time_truncate"];
551 if (!Rf_isNull(allow_time_truncate) && cpp11::as_cpp<bool>(allow_time_truncate)) {
552 out->allow_time_truncate = cpp11::as_cpp<bool>(allow_time_truncate);
553 }
554
555 SEXP allow_int_overflow = options["allow_int_overflow"];
556 if (!Rf_isNull(allow_int_overflow) && cpp11::as_cpp<bool>(allow_int_overflow)) {
557 out->allow_int_overflow = cpp11::as_cpp<bool>(allow_int_overflow);
558 }
559 return out;
560 }
561
562 // [[arrow::export]]
563 SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list options) {
564 auto opts = make_compute_options(func_name, options);
565 auto datum_args = arrow::r::from_r_list<arrow::Datum>(args);
566 auto out = ValueOrStop(
567 arrow::compute::CallFunction(func_name, datum_args, opts.get(), gc_context()));
568 return from_datum(std::move(out));
569 }
570
571 // [[arrow::export]]
572 std::vector<std::string> compute__GetFunctionNames() {
573 return arrow::compute::GetFunctionRegistry()->GetFunctionNames();
574 }
575
576 #endif