1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
18 #include "./arrow_types.h"
20 #if defined(ARROW_R_WITH_ARROW)
22 #include <arrow/compute/api.h>
23 #include <arrow/record_batch.h>
24 #include <arrow/table.h>
26 std::shared_ptr
<arrow::compute::CastOptions
> make_cast_options(cpp11::list options
);
28 arrow::compute::ExecContext
* gc_context() {
29 static arrow::compute::ExecContext
context(gc_memory_pool());
34 std::shared_ptr
<arrow::RecordBatch
> RecordBatch__cast(
35 const std::shared_ptr
<arrow::RecordBatch
>& batch
,
36 const std::shared_ptr
<arrow::Schema
>& schema
, cpp11::list options
) {
37 auto opts
= make_cast_options(options
);
38 auto nc
= batch
->num_columns();
40 arrow::ArrayVector
columns(nc
);
41 for (int i
= 0; i
< nc
; i
++) {
42 columns
[i
] = ValueOrStop(
43 arrow::compute::Cast(*batch
->column(i
), schema
->field(i
)->type(), *opts
));
46 return arrow::RecordBatch::Make(schema
, batch
->num_rows(), std::move(columns
));
50 std::shared_ptr
<arrow::Table
> Table__cast(const std::shared_ptr
<arrow::Table
>& table
,
51 const std::shared_ptr
<arrow::Schema
>& schema
,
52 cpp11::list options
) {
53 auto opts
= make_cast_options(options
);
54 auto nc
= table
->num_columns();
56 using ColumnVector
= std::vector
<std::shared_ptr
<arrow::ChunkedArray
>>;
57 ColumnVector
columns(nc
);
58 for (int i
= 0; i
< nc
; i
++) {
59 arrow::Datum
value(table
->column(i
));
61 ValueOrStop(arrow::compute::Cast(value
, schema
->field(i
)->type(), *opts
));
62 columns
[i
] = out
.chunked_array();
64 return arrow::Table::Make(schema
, std::move(columns
), table
->num_rows());
68 std::shared_ptr
<T
> MaybeUnbox(const char* class_name
, SEXP x
) {
69 if (Rf_inherits(x
, "ArrowObject") && Rf_inherits(x
, class_name
)) {
70 return cpp11::as_cpp
<std::shared_ptr
<T
>>(x
);
78 arrow::Datum as_cpp
<arrow::Datum
>(SEXP x
) {
79 if (auto array
= MaybeUnbox
<arrow::Array
>("Array", x
)) {
83 if (auto chunked_array
= MaybeUnbox
<arrow::ChunkedArray
>("ChunkedArray", x
)) {
87 if (auto batch
= MaybeUnbox
<arrow::RecordBatch
>("RecordBatch", x
)) {
91 if (auto table
= MaybeUnbox
<arrow::Table
>("Table", x
)) {
95 if (auto scalar
= MaybeUnbox
<arrow::Scalar
>("Scalar", x
)) {
99 // This assumes that R objects have already been converted to Arrow objects;
100 // that seems right but should we do the wrapping here too/instead?
101 cpp11::stop("to_datum: Not implemented for type %s", Rf_type2char(TYPEOF(x
)));
105 SEXP
from_datum(arrow::Datum datum
) {
106 switch (datum
.kind()) {
107 case arrow::Datum::SCALAR
:
108 return cpp11::to_r6(datum
.scalar());
110 case arrow::Datum::ARRAY
:
111 return cpp11::to_r6(datum
.make_array());
113 case arrow::Datum::CHUNKED_ARRAY
:
114 return cpp11::to_r6(datum
.chunked_array());
116 case arrow::Datum::RECORD_BATCH
:
117 return cpp11::to_r6(datum
.record_batch());
119 case arrow::Datum::TABLE
:
120 return cpp11::to_r6(datum
.table());
126 cpp11::stop("from_datum: Not implemented for Datum %s", datum
.ToString().c_str());
129 std::shared_ptr
<arrow::compute::FunctionOptions
> make_compute_options(
130 std::string func_name
, cpp11::list options
) {
131 if (func_name
== "filter") {
132 using Options
= arrow::compute::FilterOptions
;
133 auto out
= std::make_shared
<Options
>(Options::Defaults());
134 SEXP keep_na
= options
["keep_na"];
135 if (!Rf_isNull(keep_na
) && cpp11::as_cpp
<bool>(keep_na
)) {
136 out
->null_selection_behavior
= Options::EMIT_NULL
;
141 if (func_name
== "take") {
142 using Options
= arrow::compute::TakeOptions
;
143 auto out
= std::make_shared
<Options
>(Options::Defaults());
147 if (func_name
== "array_sort_indices") {
148 using Order
= arrow::compute::SortOrder
;
149 using Options
= arrow::compute::ArraySortOptions
;
150 // false means descending, true means ascending
151 auto order
= cpp11::as_cpp
<bool>(options
["order"]);
153 std::make_shared
<Options
>(Options(order
? Order::Descending
: Order::Ascending
));
157 if (func_name
== "sort_indices") {
158 using Key
= arrow::compute::SortKey
;
159 using Order
= arrow::compute::SortOrder
;
160 using Options
= arrow::compute::SortOptions
;
161 auto names
= cpp11::as_cpp
<std::vector
<std::string
>>(options
["names"]);
162 // false means descending, true means ascending
163 // cpp11 does not support bool here so use int
164 auto orders
= cpp11::as_cpp
<std::vector
<int>>(options
["orders"]);
165 std::vector
<Key
> keys
;
166 for (size_t i
= 0; i
< names
.size(); i
++) {
168 Key(names
[i
], (orders
[i
] > 0) ? Order::Descending
: Order::Ascending
));
170 auto out
= std::make_shared
<Options
>(Options(keys
));
174 if (func_name
== "all" || func_name
== "hash_all" || func_name
== "any" ||
175 func_name
== "hash_any" || func_name
== "approximate_median" ||
176 func_name
== "hash_approximate_median" || func_name
== "mean" ||
177 func_name
== "hash_mean" || func_name
== "min_max" || func_name
== "hash_min_max" ||
178 func_name
== "min" || func_name
== "hash_min" || func_name
== "max" ||
179 func_name
== "hash_max" || func_name
== "sum" || func_name
== "hash_sum") {
180 using Options
= arrow::compute::ScalarAggregateOptions
;
181 auto out
= std::make_shared
<Options
>(Options::Defaults());
182 if (!Rf_isNull(options
["min_count"])) {
183 out
->min_count
= cpp11::as_cpp
<int>(options
["min_count"]);
185 if (!Rf_isNull(options
["skip_nulls"])) {
186 out
->skip_nulls
= cpp11::as_cpp
<bool>(options
["skip_nulls"]);
191 if (func_name
== "tdigest" || func_name
== "hash_tdigest") {
192 using Options
= arrow::compute::TDigestOptions
;
193 auto out
= std::make_shared
<Options
>(Options::Defaults());
194 if (!Rf_isNull(options
["q"])) {
195 out
->q
= cpp11::as_cpp
<std::vector
<double>>(options
["q"]);
197 if (!Rf_isNull(options
["skip_nulls"])) {
198 out
->skip_nulls
= cpp11::as_cpp
<bool>(options
["skip_nulls"]);
203 if (func_name
== "count") {
204 using Options
= arrow::compute::CountOptions
;
205 auto out
= std::make_shared
<Options
>(Options::Defaults());
207 cpp11::as_cpp
<bool>(options
["na.rm"]) ? Options::ONLY_VALID
: Options::ONLY_NULL
;
211 if (func_name
== "count_distinct" || func_name
== "hash_count_distinct") {
212 using Options
= arrow::compute::CountOptions
;
213 auto out
= std::make_shared
<Options
>(Options::Defaults());
215 cpp11::as_cpp
<bool>(options
["na.rm"]) ? Options::ONLY_VALID
: Options::ALL
;
219 if (func_name
== "min_element_wise" || func_name
== "max_element_wise") {
220 using Options
= arrow::compute::ElementWiseAggregateOptions
;
221 bool skip_nulls
= true;
222 if (!Rf_isNull(options
["skip_nulls"])) {
223 skip_nulls
= cpp11::as_cpp
<bool>(options
["skip_nulls"]);
225 return std::make_shared
<Options
>(skip_nulls
);
228 if (func_name
== "quantile") {
229 using Options
= arrow::compute::QuantileOptions
;
230 auto out
= std::make_shared
<Options
>(Options::Defaults());
231 SEXP q
= options
["q"];
232 if (!Rf_isNull(q
) && TYPEOF(q
) == REALSXP
) {
233 out
->q
= cpp11::as_cpp
<std::vector
<double>>(q
);
235 SEXP interpolation
= options
["interpolation"];
236 if (!Rf_isNull(interpolation
) && TYPEOF(interpolation
) == INTSXP
&&
237 XLENGTH(interpolation
) == 1) {
239 cpp11::as_cpp
<enum arrow::compute::QuantileOptions::Interpolation
>(
242 if (!Rf_isNull(options
["min_count"])) {
243 out
->min_count
= cpp11::as_cpp
<int64_t>(options
["min_count"]);
245 if (!Rf_isNull(options
["skip_nulls"])) {
246 out
->skip_nulls
= cpp11::as_cpp
<int64_t>(options
["skip_nulls"]);
251 if (func_name
== "is_in" || func_name
== "index_in") {
252 using Options
= arrow::compute::SetLookupOptions
;
253 return std::make_shared
<Options
>(cpp11::as_cpp
<arrow::Datum
>(options
["value_set"]),
254 cpp11::as_cpp
<bool>(options
["skip_nulls"]));
257 if (func_name
== "index") {
258 using Options
= arrow::compute::IndexOptions
;
259 return std::make_shared
<Options
>(
260 cpp11::as_cpp
<std::shared_ptr
<arrow::Scalar
>>(options
["value"]));
263 if (func_name
== "is_null") {
264 using Options
= arrow::compute::NullOptions
;
265 auto out
= std::make_shared
<Options
>(Options::Defaults());
266 if (!Rf_isNull(options
["nan_is_null"])) {
267 out
->nan_is_null
= cpp11::as_cpp
<bool>(options
["nan_is_null"]);
272 if (func_name
== "dictionary_encode") {
273 using Options
= arrow::compute::DictionaryEncodeOptions
;
274 auto out
= std::make_shared
<Options
>(Options::Defaults());
275 if (!Rf_isNull(options
["null_encoding_behavior"])) {
276 out
->null_encoding_behavior
= cpp11::as_cpp
<
277 enum arrow::compute::DictionaryEncodeOptions::NullEncodingBehavior
>(
278 options
["null_encoding_behavior"]);
283 if (func_name
== "cast") {
284 return make_cast_options(options
);
287 if (func_name
== "binary_join_element_wise") {
288 using Options
= arrow::compute::JoinOptions
;
289 auto out
= std::make_shared
<Options
>(Options::Defaults());
290 if (!Rf_isNull(options
["null_handling"])) {
292 cpp11::as_cpp
<enum arrow::compute::JoinOptions::NullHandlingBehavior
>(
293 options
["null_handling"]);
295 if (!Rf_isNull(options
["null_replacement"])) {
296 out
->null_replacement
= cpp11::as_cpp
<std::string
>(options
["null_replacement"]);
301 if (func_name
== "make_struct") {
302 using Options
= arrow::compute::MakeStructOptions
;
303 // TODO (ARROW-13371): accept `field_nullability` and `field_metadata` options
304 return std::make_shared
<Options
>(
305 cpp11::as_cpp
<std::vector
<std::string
>>(options
["field_names"]));
308 if (func_name
== "match_substring" || func_name
== "match_substring_regex" ||
309 func_name
== "find_substring" || func_name
== "find_substring_regex" ||
310 func_name
== "match_like" || func_name
== "starts_with" ||
311 func_name
== "ends_with" || func_name
== "count_substring" ||
312 func_name
== "count_substring_regex") {
313 using Options
= arrow::compute::MatchSubstringOptions
;
314 bool ignore_case
= false;
315 if (!Rf_isNull(options
["ignore_case"])) {
316 ignore_case
= cpp11::as_cpp
<bool>(options
["ignore_case"]);
318 return std::make_shared
<Options
>(cpp11::as_cpp
<std::string
>(options
["pattern"]),
322 if (func_name
== "replace_substring" || func_name
== "replace_substring_regex") {
323 using Options
= arrow::compute::ReplaceSubstringOptions
;
324 int64_t max_replacements
= -1;
325 if (!Rf_isNull(options
["max_replacements"])) {
326 max_replacements
= cpp11::as_cpp
<int64_t>(options
["max_replacements"]);
328 return std::make_shared
<Options
>(cpp11::as_cpp
<std::string
>(options
["pattern"]),
329 cpp11::as_cpp
<std::string
>(options
["replacement"]),
333 if (func_name
== "extract_regex") {
334 using Options
= arrow::compute::ExtractRegexOptions
;
335 return std::make_shared
<Options
>(cpp11::as_cpp
<std::string
>(options
["pattern"]));
338 if (func_name
== "day_of_week") {
339 using Options
= arrow::compute::DayOfWeekOptions
;
340 bool count_from_zero
= false;
341 if (!Rf_isNull(options
["count_from_zero"])) {
342 count_from_zero
= cpp11::as_cpp
<bool>(options
["count_from_zero"]);
344 return std::make_shared
<Options
>(count_from_zero
,
345 cpp11::as_cpp
<uint32_t>(options
["week_start"]));
348 if (func_name
== "iso_week") {
349 return std::make_shared
<arrow::compute::WeekOptions
>(
350 arrow::compute::WeekOptions::ISODefaults());
353 if (func_name
== "us_week") {
354 return std::make_shared
<arrow::compute::WeekOptions
>(
355 arrow::compute::WeekOptions::USDefaults());
358 if (func_name
== "week") {
359 using Options
= arrow::compute::WeekOptions
;
360 bool week_starts_monday
= true;
361 bool count_from_zero
= false;
362 bool first_week_is_fully_in_year
= false;
363 if (!Rf_isNull(options
["week_starts_monday"])) {
364 week_starts_monday
= cpp11::as_cpp
<bool>(options
["week_starts_monday"]);
366 if (!Rf_isNull(options
["count_from_zero"])) {
367 count_from_zero
= cpp11::as_cpp
<bool>(options
["count_from_zero"]);
369 if (!Rf_isNull(options
["first_week_is_fully_in_year"])) {
370 count_from_zero
= cpp11::as_cpp
<bool>(options
["first_week_is_fully_in_year"]);
372 return std::make_shared
<Options
>(week_starts_monday
, count_from_zero
,
373 first_week_is_fully_in_year
);
376 if (func_name
== "strptime") {
377 using Options
= arrow::compute::StrptimeOptions
;
378 return std::make_shared
<Options
>(
379 cpp11::as_cpp
<std::string
>(options
["format"]),
380 cpp11::as_cpp
<arrow::TimeUnit::type
>(options
["unit"]));
383 if (func_name
== "strftime") {
384 using Options
= arrow::compute::StrftimeOptions
;
385 return std::make_shared
<Options
>(
386 Options(cpp11::as_cpp
<std::string
>(options
["format"]),
387 cpp11::as_cpp
<std::string
>(options
["locale"])));
390 if (func_name
== "assume_timezone") {
391 using Options
= arrow::compute::AssumeTimezoneOptions
;
392 enum Options::Ambiguous ambiguous
;
393 enum Options::Nonexistent nonexistent
;
395 if (!Rf_isNull(options
["ambiguous"])) {
396 ambiguous
= cpp11::as_cpp
<enum Options::Ambiguous
>(options
["ambiguous"]);
398 if (!Rf_isNull(options
["nonexistent"])) {
399 nonexistent
= cpp11::as_cpp
<enum Options::Nonexistent
>(options
["nonexistent"]);
402 return std::make_shared
<Options
>(cpp11::as_cpp
<std::string
>(options
["timezone"]),
403 ambiguous
, nonexistent
);
406 if (func_name
== "split_pattern" || func_name
== "split_pattern_regex") {
407 using Options
= arrow::compute::SplitPatternOptions
;
408 int64_t max_splits
= -1;
409 if (!Rf_isNull(options
["max_splits"])) {
410 max_splits
= cpp11::as_cpp
<int64_t>(options
["max_splits"]);
412 bool reverse
= false;
413 if (!Rf_isNull(options
["reverse"])) {
414 reverse
= cpp11::as_cpp
<bool>(options
["reverse"]);
416 return std::make_shared
<Options
>(cpp11::as_cpp
<std::string
>(options
["pattern"]),
417 max_splits
, reverse
);
420 if (func_name
== "utf8_lpad" || func_name
== "utf8_rpad" ||
421 func_name
== "utf8_center" || func_name
== "ascii_lpad" ||
422 func_name
== "ascii_rpad" || func_name
== "ascii_center") {
423 using Options
= arrow::compute::PadOptions
;
424 return std::make_shared
<Options
>(cpp11::as_cpp
<int64_t>(options
["width"]),
425 cpp11::as_cpp
<std::string
>(options
["padding"]));
428 if (func_name
== "utf8_split_whitespace" || func_name
== "ascii_split_whitespace") {
429 using Options
= arrow::compute::SplitOptions
;
430 int64_t max_splits
= -1;
431 if (!Rf_isNull(options
["max_splits"])) {
432 max_splits
= cpp11::as_cpp
<int64_t>(options
["max_splits"]);
434 bool reverse
= false;
435 if (!Rf_isNull(options
["reverse"])) {
436 reverse
= cpp11::as_cpp
<bool>(options
["reverse"]);
438 return std::make_shared
<Options
>(max_splits
, reverse
);
441 if (func_name
== "utf8_trim" || func_name
== "utf8_ltrim" ||
442 func_name
== "utf8_rtrim" || func_name
== "ascii_trim" ||
443 func_name
== "ascii_ltrim" || func_name
== "ascii_rtrim") {
444 using Options
= arrow::compute::TrimOptions
;
445 return std::make_shared
<Options
>(cpp11::as_cpp
<std::string
>(options
["characters"]));
448 if (func_name
== "utf8_slice_codeunits") {
449 using Options
= arrow::compute::SliceOptions
;
452 if (!Rf_isNull(options
["step"])) {
453 step
= cpp11::as_cpp
<int64_t>(options
["step"]);
456 int64_t stop
= std::numeric_limits
<int32_t>::max();
457 if (!Rf_isNull(options
["stop"])) {
458 stop
= cpp11::as_cpp
<int64_t>(options
["stop"]);
461 return std::make_shared
<Options
>(cpp11::as_cpp
<int64_t>(options
["start"]), stop
,
465 if (func_name
== "utf8_replace_slice" || func_name
== "binary_replace_slice") {
466 using Options
= arrow::compute::ReplaceSliceOptions
;
468 return std::make_shared
<Options
>(cpp11::as_cpp
<int64_t>(options
["start"]),
469 cpp11::as_cpp
<int64_t>(options
["stop"]),
470 cpp11::as_cpp
<std::string
>(options
["replacement"]));
473 if (func_name
== "variance" || func_name
== "stddev" || func_name
== "hash_variance" ||
474 func_name
== "hash_stddev") {
475 using Options
= arrow::compute::VarianceOptions
;
476 auto out
= std::make_shared
<Options
>();
477 out
->ddof
= cpp11::as_cpp
<int64_t>(options
["ddof"]);
478 if (!Rf_isNull(options
["min_count"])) {
479 out
->min_count
= cpp11::as_cpp
<int64_t>(options
["min_count"]);
481 if (!Rf_isNull(options
["skip_nulls"])) {
482 out
->skip_nulls
= cpp11::as_cpp
<bool>(options
["skip_nulls"]);
487 if (func_name
== "mode") {
488 using Options
= arrow::compute::ModeOptions
;
489 auto out
= std::make_shared
<Options
>(Options::Defaults());
490 if (!Rf_isNull(options
["n"])) {
491 out
->n
= cpp11::as_cpp
<int64_t>(options
["n"]);
493 if (!Rf_isNull(options
["min_count"])) {
494 out
->min_count
= cpp11::as_cpp
<uint32_t>(options
["min_count"]);
496 if (!Rf_isNull(options
["skip_nulls"])) {
497 out
->skip_nulls
= cpp11::as_cpp
<bool>(options
["skip_nulls"]);
502 if (func_name
== "partition_nth_indices") {
503 using Options
= arrow::compute::PartitionNthOptions
;
504 return std::make_shared
<Options
>(cpp11::as_cpp
<int64_t>(options
["pivot"]));
507 if (func_name
== "round") {
508 using Options
= arrow::compute::RoundOptions
;
509 auto out
= std::make_shared
<Options
>(Options::Defaults());
510 if (!Rf_isNull(options
["ndigits"])) {
511 out
->ndigits
= cpp11::as_cpp
<int64_t>(options
["ndigits"]);
513 SEXP round_mode
= options
["round_mode"];
514 if (!Rf_isNull(round_mode
)) {
515 out
->round_mode
= cpp11::as_cpp
<enum arrow::compute::RoundMode
>(round_mode
);
520 if (func_name
== "round_to_multiple") {
521 using Options
= arrow::compute::RoundToMultipleOptions
;
522 auto out
= std::make_shared
<Options
>(Options::Defaults());
523 if (!Rf_isNull(options
["multiple"])) {
524 out
->multiple
= std::make_shared
<arrow::DoubleScalar
>(
525 cpp11::as_cpp
<double>(options
["multiple"]));
527 SEXP round_mode
= options
["round_mode"];
528 if (!Rf_isNull(round_mode
)) {
529 out
->round_mode
= cpp11::as_cpp
<enum arrow::compute::RoundMode
>(round_mode
);
537 std::shared_ptr
<arrow::compute::CastOptions
> make_cast_options(cpp11::list options
) {
538 using Options
= arrow::compute::CastOptions
;
539 auto out
= std::make_shared
<Options
>(true);
540 SEXP to_type
= options
["to_type"];
541 if (!Rf_isNull(to_type
) && cpp11::as_cpp
<std::shared_ptr
<arrow::DataType
>>(to_type
)) {
542 out
->to_type
= cpp11::as_cpp
<std::shared_ptr
<arrow::DataType
>>(to_type
);
545 SEXP allow_float_truncate
= options
["allow_float_truncate"];
546 if (!Rf_isNull(allow_float_truncate
) && cpp11::as_cpp
<bool>(allow_float_truncate
)) {
547 out
->allow_float_truncate
= cpp11::as_cpp
<bool>(allow_float_truncate
);
550 SEXP allow_time_truncate
= options
["allow_time_truncate"];
551 if (!Rf_isNull(allow_time_truncate
) && cpp11::as_cpp
<bool>(allow_time_truncate
)) {
552 out
->allow_time_truncate
= cpp11::as_cpp
<bool>(allow_time_truncate
);
555 SEXP allow_int_overflow
= options
["allow_int_overflow"];
556 if (!Rf_isNull(allow_int_overflow
) && cpp11::as_cpp
<bool>(allow_int_overflow
)) {
557 out
->allow_int_overflow
= cpp11::as_cpp
<bool>(allow_int_overflow
);
563 SEXP
compute__CallFunction(std::string func_name
, cpp11::list args
, cpp11::list options
) {
564 auto opts
= make_compute_options(func_name
, options
);
565 auto datum_args
= arrow::r::from_r_list
<arrow::Datum
>(args
);
566 auto out
= ValueOrStop(
567 arrow::compute::CallFunction(func_name
, datum_args
, opts
.get(), gc_context()));
568 return from_datum(std::move(out
));
572 std::vector
<std::string
> compute__GetFunctionNames() {
573 return arrow::compute::GetFunctionRegistry()->GetFunctionNames();