]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include "./arrow_types.h" | |
19 | ||
20 | #if defined(ARROW_R_WITH_ARROW) | |
21 | ||
22 | #include <arrow/compute/api.h> | |
23 | #include <arrow/record_batch.h> | |
24 | #include <arrow/table.h> | |
25 | ||
26 | std::shared_ptr<arrow::compute::CastOptions> make_cast_options(cpp11::list options); | |
27 | ||
28 | arrow::compute::ExecContext* gc_context() { | |
29 | static arrow::compute::ExecContext context(gc_memory_pool()); | |
30 | return &context; | |
31 | } | |
32 | ||
33 | // [[arrow::export]] | |
34 | std::shared_ptr<arrow::RecordBatch> RecordBatch__cast( | |
35 | const std::shared_ptr<arrow::RecordBatch>& batch, | |
36 | const std::shared_ptr<arrow::Schema>& schema, cpp11::list options) { | |
37 | auto opts = make_cast_options(options); | |
38 | auto nc = batch->num_columns(); | |
39 | ||
40 | arrow::ArrayVector columns(nc); | |
41 | for (int i = 0; i < nc; i++) { | |
42 | columns[i] = ValueOrStop( | |
43 | arrow::compute::Cast(*batch->column(i), schema->field(i)->type(), *opts)); | |
44 | } | |
45 | ||
46 | return arrow::RecordBatch::Make(schema, batch->num_rows(), std::move(columns)); | |
47 | } | |
48 | ||
49 | // [[arrow::export]] | |
50 | std::shared_ptr<arrow::Table> Table__cast(const std::shared_ptr<arrow::Table>& table, | |
51 | const std::shared_ptr<arrow::Schema>& schema, | |
52 | cpp11::list options) { | |
53 | auto opts = make_cast_options(options); | |
54 | auto nc = table->num_columns(); | |
55 | ||
56 | using ColumnVector = std::vector<std::shared_ptr<arrow::ChunkedArray>>; | |
57 | ColumnVector columns(nc); | |
58 | for (int i = 0; i < nc; i++) { | |
59 | arrow::Datum value(table->column(i)); | |
60 | arrow::Datum out = | |
61 | ValueOrStop(arrow::compute::Cast(value, schema->field(i)->type(), *opts)); | |
62 | columns[i] = out.chunked_array(); | |
63 | } | |
64 | return arrow::Table::Make(schema, std::move(columns), table->num_rows()); | |
65 | } | |
66 | ||
67 | template <typename T> | |
68 | std::shared_ptr<T> MaybeUnbox(const char* class_name, SEXP x) { | |
69 | if (Rf_inherits(x, "ArrowObject") && Rf_inherits(x, class_name)) { | |
70 | return cpp11::as_cpp<std::shared_ptr<T>>(x); | |
71 | } | |
72 | return nullptr; | |
73 | } | |
74 | ||
75 | namespace cpp11 { | |
76 | ||
77 | template <> | |
78 | arrow::Datum as_cpp<arrow::Datum>(SEXP x) { | |
79 | if (auto array = MaybeUnbox<arrow::Array>("Array", x)) { | |
80 | return array; | |
81 | } | |
82 | ||
83 | if (auto chunked_array = MaybeUnbox<arrow::ChunkedArray>("ChunkedArray", x)) { | |
84 | return chunked_array; | |
85 | } | |
86 | ||
87 | if (auto batch = MaybeUnbox<arrow::RecordBatch>("RecordBatch", x)) { | |
88 | return batch; | |
89 | } | |
90 | ||
91 | if (auto table = MaybeUnbox<arrow::Table>("Table", x)) { | |
92 | return table; | |
93 | } | |
94 | ||
95 | if (auto scalar = MaybeUnbox<arrow::Scalar>("Scalar", x)) { | |
96 | return scalar; | |
97 | } | |
98 | ||
99 | // This assumes that R objects have already been converted to Arrow objects; | |
100 | // that seems right but should we do the wrapping here too/instead? | |
101 | cpp11::stop("to_datum: Not implemented for type %s", Rf_type2char(TYPEOF(x))); | |
102 | } | |
103 | } // namespace cpp11 | |
104 | ||
105 | SEXP from_datum(arrow::Datum datum) { | |
106 | switch (datum.kind()) { | |
107 | case arrow::Datum::SCALAR: | |
108 | return cpp11::to_r6(datum.scalar()); | |
109 | ||
110 | case arrow::Datum::ARRAY: | |
111 | return cpp11::to_r6(datum.make_array()); | |
112 | ||
113 | case arrow::Datum::CHUNKED_ARRAY: | |
114 | return cpp11::to_r6(datum.chunked_array()); | |
115 | ||
116 | case arrow::Datum::RECORD_BATCH: | |
117 | return cpp11::to_r6(datum.record_batch()); | |
118 | ||
119 | case arrow::Datum::TABLE: | |
120 | return cpp11::to_r6(datum.table()); | |
121 | ||
122 | default: | |
123 | break; | |
124 | } | |
125 | ||
126 | cpp11::stop("from_datum: Not implemented for Datum %s", datum.ToString().c_str()); | |
127 | } | |
128 | ||
129 | std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options( | |
130 | std::string func_name, cpp11::list options) { | |
131 | if (func_name == "filter") { | |
132 | using Options = arrow::compute::FilterOptions; | |
133 | auto out = std::make_shared<Options>(Options::Defaults()); | |
134 | SEXP keep_na = options["keep_na"]; | |
135 | if (!Rf_isNull(keep_na) && cpp11::as_cpp<bool>(keep_na)) { | |
136 | out->null_selection_behavior = Options::EMIT_NULL; | |
137 | } | |
138 | return out; | |
139 | } | |
140 | ||
141 | if (func_name == "take") { | |
142 | using Options = arrow::compute::TakeOptions; | |
143 | auto out = std::make_shared<Options>(Options::Defaults()); | |
144 | return out; | |
145 | } | |
146 | ||
147 | if (func_name == "array_sort_indices") { | |
148 | using Order = arrow::compute::SortOrder; | |
149 | using Options = arrow::compute::ArraySortOptions; | |
150 | // false means descending, true means ascending | |
151 | auto order = cpp11::as_cpp<bool>(options["order"]); | |
152 | auto out = | |
153 | std::make_shared<Options>(Options(order ? Order::Descending : Order::Ascending)); | |
154 | return out; | |
155 | } | |
156 | ||
157 | if (func_name == "sort_indices") { | |
158 | using Key = arrow::compute::SortKey; | |
159 | using Order = arrow::compute::SortOrder; | |
160 | using Options = arrow::compute::SortOptions; | |
161 | auto names = cpp11::as_cpp<std::vector<std::string>>(options["names"]); | |
162 | // false means descending, true means ascending | |
163 | // cpp11 does not support bool here so use int | |
164 | auto orders = cpp11::as_cpp<std::vector<int>>(options["orders"]); | |
165 | std::vector<Key> keys; | |
166 | for (size_t i = 0; i < names.size(); i++) { | |
167 | keys.push_back( | |
168 | Key(names[i], (orders[i] > 0) ? Order::Descending : Order::Ascending)); | |
169 | } | |
170 | auto out = std::make_shared<Options>(Options(keys)); | |
171 | return out; | |
172 | } | |
173 | ||
174 | if (func_name == "all" || func_name == "hash_all" || func_name == "any" || | |
175 | func_name == "hash_any" || func_name == "approximate_median" || | |
176 | func_name == "hash_approximate_median" || func_name == "mean" || | |
177 | func_name == "hash_mean" || func_name == "min_max" || func_name == "hash_min_max" || | |
178 | func_name == "min" || func_name == "hash_min" || func_name == "max" || | |
179 | func_name == "hash_max" || func_name == "sum" || func_name == "hash_sum") { | |
180 | using Options = arrow::compute::ScalarAggregateOptions; | |
181 | auto out = std::make_shared<Options>(Options::Defaults()); | |
182 | if (!Rf_isNull(options["min_count"])) { | |
183 | out->min_count = cpp11::as_cpp<int>(options["min_count"]); | |
184 | } | |
185 | if (!Rf_isNull(options["skip_nulls"])) { | |
186 | out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]); | |
187 | } | |
188 | return out; | |
189 | } | |
190 | ||
191 | if (func_name == "tdigest" || func_name == "hash_tdigest") { | |
192 | using Options = arrow::compute::TDigestOptions; | |
193 | auto out = std::make_shared<Options>(Options::Defaults()); | |
194 | if (!Rf_isNull(options["q"])) { | |
195 | out->q = cpp11::as_cpp<std::vector<double>>(options["q"]); | |
196 | } | |
197 | if (!Rf_isNull(options["skip_nulls"])) { | |
198 | out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]); | |
199 | } | |
200 | return out; | |
201 | } | |
202 | ||
203 | if (func_name == "count") { | |
204 | using Options = arrow::compute::CountOptions; | |
205 | auto out = std::make_shared<Options>(Options::Defaults()); | |
206 | out->mode = | |
207 | cpp11::as_cpp<bool>(options["na.rm"]) ? Options::ONLY_VALID : Options::ONLY_NULL; | |
208 | return out; | |
209 | } | |
210 | ||
211 | if (func_name == "count_distinct" || func_name == "hash_count_distinct") { | |
212 | using Options = arrow::compute::CountOptions; | |
213 | auto out = std::make_shared<Options>(Options::Defaults()); | |
214 | out->mode = | |
215 | cpp11::as_cpp<bool>(options["na.rm"]) ? Options::ONLY_VALID : Options::ALL; | |
216 | return out; | |
217 | } | |
218 | ||
219 | if (func_name == "min_element_wise" || func_name == "max_element_wise") { | |
220 | using Options = arrow::compute::ElementWiseAggregateOptions; | |
221 | bool skip_nulls = true; | |
222 | if (!Rf_isNull(options["skip_nulls"])) { | |
223 | skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]); | |
224 | } | |
225 | return std::make_shared<Options>(skip_nulls); | |
226 | } | |
227 | ||
228 | if (func_name == "quantile") { | |
229 | using Options = arrow::compute::QuantileOptions; | |
230 | auto out = std::make_shared<Options>(Options::Defaults()); | |
231 | SEXP q = options["q"]; | |
232 | if (!Rf_isNull(q) && TYPEOF(q) == REALSXP) { | |
233 | out->q = cpp11::as_cpp<std::vector<double>>(q); | |
234 | } | |
235 | SEXP interpolation = options["interpolation"]; | |
236 | if (!Rf_isNull(interpolation) && TYPEOF(interpolation) == INTSXP && | |
237 | XLENGTH(interpolation) == 1) { | |
238 | out->interpolation = | |
239 | cpp11::as_cpp<enum arrow::compute::QuantileOptions::Interpolation>( | |
240 | interpolation); | |
241 | } | |
242 | if (!Rf_isNull(options["min_count"])) { | |
243 | out->min_count = cpp11::as_cpp<int64_t>(options["min_count"]); | |
244 | } | |
245 | if (!Rf_isNull(options["skip_nulls"])) { | |
246 | out->skip_nulls = cpp11::as_cpp<int64_t>(options["skip_nulls"]); | |
247 | } | |
248 | return out; | |
249 | } | |
250 | ||
251 | if (func_name == "is_in" || func_name == "index_in") { | |
252 | using Options = arrow::compute::SetLookupOptions; | |
253 | return std::make_shared<Options>(cpp11::as_cpp<arrow::Datum>(options["value_set"]), | |
254 | cpp11::as_cpp<bool>(options["skip_nulls"])); | |
255 | } | |
256 | ||
257 | if (func_name == "index") { | |
258 | using Options = arrow::compute::IndexOptions; | |
259 | return std::make_shared<Options>( | |
260 | cpp11::as_cpp<std::shared_ptr<arrow::Scalar>>(options["value"])); | |
261 | } | |
262 | ||
263 | if (func_name == "is_null") { | |
264 | using Options = arrow::compute::NullOptions; | |
265 | auto out = std::make_shared<Options>(Options::Defaults()); | |
266 | if (!Rf_isNull(options["nan_is_null"])) { | |
267 | out->nan_is_null = cpp11::as_cpp<bool>(options["nan_is_null"]); | |
268 | } | |
269 | return out; | |
270 | } | |
271 | ||
272 | if (func_name == "dictionary_encode") { | |
273 | using Options = arrow::compute::DictionaryEncodeOptions; | |
274 | auto out = std::make_shared<Options>(Options::Defaults()); | |
275 | if (!Rf_isNull(options["null_encoding_behavior"])) { | |
276 | out->null_encoding_behavior = cpp11::as_cpp< | |
277 | enum arrow::compute::DictionaryEncodeOptions::NullEncodingBehavior>( | |
278 | options["null_encoding_behavior"]); | |
279 | } | |
280 | return out; | |
281 | } | |
282 | ||
283 | if (func_name == "cast") { | |
284 | return make_cast_options(options); | |
285 | } | |
286 | ||
287 | if (func_name == "binary_join_element_wise") { | |
288 | using Options = arrow::compute::JoinOptions; | |
289 | auto out = std::make_shared<Options>(Options::Defaults()); | |
290 | if (!Rf_isNull(options["null_handling"])) { | |
291 | out->null_handling = | |
292 | cpp11::as_cpp<enum arrow::compute::JoinOptions::NullHandlingBehavior>( | |
293 | options["null_handling"]); | |
294 | } | |
295 | if (!Rf_isNull(options["null_replacement"])) { | |
296 | out->null_replacement = cpp11::as_cpp<std::string>(options["null_replacement"]); | |
297 | } | |
298 | return out; | |
299 | } | |
300 | ||
301 | if (func_name == "make_struct") { | |
302 | using Options = arrow::compute::MakeStructOptions; | |
303 | // TODO (ARROW-13371): accept `field_nullability` and `field_metadata` options | |
304 | return std::make_shared<Options>( | |
305 | cpp11::as_cpp<std::vector<std::string>>(options["field_names"])); | |
306 | } | |
307 | ||
308 | if (func_name == "match_substring" || func_name == "match_substring_regex" || | |
309 | func_name == "find_substring" || func_name == "find_substring_regex" || | |
310 | func_name == "match_like" || func_name == "starts_with" || | |
311 | func_name == "ends_with" || func_name == "count_substring" || | |
312 | func_name == "count_substring_regex") { | |
313 | using Options = arrow::compute::MatchSubstringOptions; | |
314 | bool ignore_case = false; | |
315 | if (!Rf_isNull(options["ignore_case"])) { | |
316 | ignore_case = cpp11::as_cpp<bool>(options["ignore_case"]); | |
317 | } | |
318 | return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]), | |
319 | ignore_case); | |
320 | } | |
321 | ||
322 | if (func_name == "replace_substring" || func_name == "replace_substring_regex") { | |
323 | using Options = arrow::compute::ReplaceSubstringOptions; | |
324 | int64_t max_replacements = -1; | |
325 | if (!Rf_isNull(options["max_replacements"])) { | |
326 | max_replacements = cpp11::as_cpp<int64_t>(options["max_replacements"]); | |
327 | } | |
328 | return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]), | |
329 | cpp11::as_cpp<std::string>(options["replacement"]), | |
330 | max_replacements); | |
331 | } | |
332 | ||
333 | if (func_name == "extract_regex") { | |
334 | using Options = arrow::compute::ExtractRegexOptions; | |
335 | return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"])); | |
336 | } | |
337 | ||
338 | if (func_name == "day_of_week") { | |
339 | using Options = arrow::compute::DayOfWeekOptions; | |
340 | bool count_from_zero = false; | |
341 | if (!Rf_isNull(options["count_from_zero"])) { | |
342 | count_from_zero = cpp11::as_cpp<bool>(options["count_from_zero"]); | |
343 | } | |
344 | return std::make_shared<Options>(count_from_zero, | |
345 | cpp11::as_cpp<uint32_t>(options["week_start"])); | |
346 | } | |
347 | ||
348 | if (func_name == "iso_week") { | |
349 | return std::make_shared<arrow::compute::WeekOptions>( | |
350 | arrow::compute::WeekOptions::ISODefaults()); | |
351 | } | |
352 | ||
353 | if (func_name == "us_week") { | |
354 | return std::make_shared<arrow::compute::WeekOptions>( | |
355 | arrow::compute::WeekOptions::USDefaults()); | |
356 | } | |
357 | ||
358 | if (func_name == "week") { | |
359 | using Options = arrow::compute::WeekOptions; | |
360 | bool week_starts_monday = true; | |
361 | bool count_from_zero = false; | |
362 | bool first_week_is_fully_in_year = false; | |
363 | if (!Rf_isNull(options["week_starts_monday"])) { | |
364 | week_starts_monday = cpp11::as_cpp<bool>(options["week_starts_monday"]); | |
365 | } | |
366 | if (!Rf_isNull(options["count_from_zero"])) { | |
367 | count_from_zero = cpp11::as_cpp<bool>(options["count_from_zero"]); | |
368 | } | |
369 | if (!Rf_isNull(options["first_week_is_fully_in_year"])) { | |
370 | count_from_zero = cpp11::as_cpp<bool>(options["first_week_is_fully_in_year"]); | |
371 | } | |
372 | return std::make_shared<Options>(week_starts_monday, count_from_zero, | |
373 | first_week_is_fully_in_year); | |
374 | } | |
375 | ||
376 | if (func_name == "strptime") { | |
377 | using Options = arrow::compute::StrptimeOptions; | |
378 | return std::make_shared<Options>( | |
379 | cpp11::as_cpp<std::string>(options["format"]), | |
380 | cpp11::as_cpp<arrow::TimeUnit::type>(options["unit"])); | |
381 | } | |
382 | ||
383 | if (func_name == "strftime") { | |
384 | using Options = arrow::compute::StrftimeOptions; | |
385 | return std::make_shared<Options>( | |
386 | Options(cpp11::as_cpp<std::string>(options["format"]), | |
387 | cpp11::as_cpp<std::string>(options["locale"]))); | |
388 | } | |
389 | ||
390 | if (func_name == "assume_timezone") { | |
391 | using Options = arrow::compute::AssumeTimezoneOptions; | |
392 | enum Options::Ambiguous ambiguous; | |
393 | enum Options::Nonexistent nonexistent; | |
394 | ||
395 | if (!Rf_isNull(options["ambiguous"])) { | |
396 | ambiguous = cpp11::as_cpp<enum Options::Ambiguous>(options["ambiguous"]); | |
397 | } | |
398 | if (!Rf_isNull(options["nonexistent"])) { | |
399 | nonexistent = cpp11::as_cpp<enum Options::Nonexistent>(options["nonexistent"]); | |
400 | } | |
401 | ||
402 | return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["timezone"]), | |
403 | ambiguous, nonexistent); | |
404 | } | |
405 | ||
406 | if (func_name == "split_pattern" || func_name == "split_pattern_regex") { | |
407 | using Options = arrow::compute::SplitPatternOptions; | |
408 | int64_t max_splits = -1; | |
409 | if (!Rf_isNull(options["max_splits"])) { | |
410 | max_splits = cpp11::as_cpp<int64_t>(options["max_splits"]); | |
411 | } | |
412 | bool reverse = false; | |
413 | if (!Rf_isNull(options["reverse"])) { | |
414 | reverse = cpp11::as_cpp<bool>(options["reverse"]); | |
415 | } | |
416 | return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]), | |
417 | max_splits, reverse); | |
418 | } | |
419 | ||
420 | if (func_name == "utf8_lpad" || func_name == "utf8_rpad" || | |
421 | func_name == "utf8_center" || func_name == "ascii_lpad" || | |
422 | func_name == "ascii_rpad" || func_name == "ascii_center") { | |
423 | using Options = arrow::compute::PadOptions; | |
424 | return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["width"]), | |
425 | cpp11::as_cpp<std::string>(options["padding"])); | |
426 | } | |
427 | ||
428 | if (func_name == "utf8_split_whitespace" || func_name == "ascii_split_whitespace") { | |
429 | using Options = arrow::compute::SplitOptions; | |
430 | int64_t max_splits = -1; | |
431 | if (!Rf_isNull(options["max_splits"])) { | |
432 | max_splits = cpp11::as_cpp<int64_t>(options["max_splits"]); | |
433 | } | |
434 | bool reverse = false; | |
435 | if (!Rf_isNull(options["reverse"])) { | |
436 | reverse = cpp11::as_cpp<bool>(options["reverse"]); | |
437 | } | |
438 | return std::make_shared<Options>(max_splits, reverse); | |
439 | } | |
440 | ||
441 | if (func_name == "utf8_trim" || func_name == "utf8_ltrim" || | |
442 | func_name == "utf8_rtrim" || func_name == "ascii_trim" || | |
443 | func_name == "ascii_ltrim" || func_name == "ascii_rtrim") { | |
444 | using Options = arrow::compute::TrimOptions; | |
445 | return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["characters"])); | |
446 | } | |
447 | ||
448 | if (func_name == "utf8_slice_codeunits") { | |
449 | using Options = arrow::compute::SliceOptions; | |
450 | ||
451 | int64_t step = 1; | |
452 | if (!Rf_isNull(options["step"])) { | |
453 | step = cpp11::as_cpp<int64_t>(options["step"]); | |
454 | } | |
455 | ||
456 | int64_t stop = std::numeric_limits<int32_t>::max(); | |
457 | if (!Rf_isNull(options["stop"])) { | |
458 | stop = cpp11::as_cpp<int64_t>(options["stop"]); | |
459 | } | |
460 | ||
461 | return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["start"]), stop, | |
462 | step); | |
463 | } | |
464 | ||
465 | if (func_name == "utf8_replace_slice" || func_name == "binary_replace_slice") { | |
466 | using Options = arrow::compute::ReplaceSliceOptions; | |
467 | ||
468 | return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["start"]), | |
469 | cpp11::as_cpp<int64_t>(options["stop"]), | |
470 | cpp11::as_cpp<std::string>(options["replacement"])); | |
471 | } | |
472 | ||
473 | if (func_name == "variance" || func_name == "stddev" || func_name == "hash_variance" || | |
474 | func_name == "hash_stddev") { | |
475 | using Options = arrow::compute::VarianceOptions; | |
476 | auto out = std::make_shared<Options>(); | |
477 | out->ddof = cpp11::as_cpp<int64_t>(options["ddof"]); | |
478 | if (!Rf_isNull(options["min_count"])) { | |
479 | out->min_count = cpp11::as_cpp<int64_t>(options["min_count"]); | |
480 | } | |
481 | if (!Rf_isNull(options["skip_nulls"])) { | |
482 | out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]); | |
483 | } | |
484 | return out; | |
485 | } | |
486 | ||
487 | if (func_name == "mode") { | |
488 | using Options = arrow::compute::ModeOptions; | |
489 | auto out = std::make_shared<Options>(Options::Defaults()); | |
490 | if (!Rf_isNull(options["n"])) { | |
491 | out->n = cpp11::as_cpp<int64_t>(options["n"]); | |
492 | } | |
493 | if (!Rf_isNull(options["min_count"])) { | |
494 | out->min_count = cpp11::as_cpp<uint32_t>(options["min_count"]); | |
495 | } | |
496 | if (!Rf_isNull(options["skip_nulls"])) { | |
497 | out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]); | |
498 | } | |
499 | return out; | |
500 | } | |
501 | ||
502 | if (func_name == "partition_nth_indices") { | |
503 | using Options = arrow::compute::PartitionNthOptions; | |
504 | return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["pivot"])); | |
505 | } | |
506 | ||
507 | if (func_name == "round") { | |
508 | using Options = arrow::compute::RoundOptions; | |
509 | auto out = std::make_shared<Options>(Options::Defaults()); | |
510 | if (!Rf_isNull(options["ndigits"])) { | |
511 | out->ndigits = cpp11::as_cpp<int64_t>(options["ndigits"]); | |
512 | } | |
513 | SEXP round_mode = options["round_mode"]; | |
514 | if (!Rf_isNull(round_mode)) { | |
515 | out->round_mode = cpp11::as_cpp<enum arrow::compute::RoundMode>(round_mode); | |
516 | } | |
517 | return out; | |
518 | } | |
519 | ||
520 | if (func_name == "round_to_multiple") { | |
521 | using Options = arrow::compute::RoundToMultipleOptions; | |
522 | auto out = std::make_shared<Options>(Options::Defaults()); | |
523 | if (!Rf_isNull(options["multiple"])) { | |
524 | out->multiple = std::make_shared<arrow::DoubleScalar>( | |
525 | cpp11::as_cpp<double>(options["multiple"])); | |
526 | } | |
527 | SEXP round_mode = options["round_mode"]; | |
528 | if (!Rf_isNull(round_mode)) { | |
529 | out->round_mode = cpp11::as_cpp<enum arrow::compute::RoundMode>(round_mode); | |
530 | } | |
531 | return out; | |
532 | } | |
533 | ||
534 | return nullptr; | |
535 | } | |
536 | ||
537 | std::shared_ptr<arrow::compute::CastOptions> make_cast_options(cpp11::list options) { | |
538 | using Options = arrow::compute::CastOptions; | |
539 | auto out = std::make_shared<Options>(true); | |
540 | SEXP to_type = options["to_type"]; | |
541 | if (!Rf_isNull(to_type) && cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(to_type)) { | |
542 | out->to_type = cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(to_type); | |
543 | } | |
544 | ||
545 | SEXP allow_float_truncate = options["allow_float_truncate"]; | |
546 | if (!Rf_isNull(allow_float_truncate) && cpp11::as_cpp<bool>(allow_float_truncate)) { | |
547 | out->allow_float_truncate = cpp11::as_cpp<bool>(allow_float_truncate); | |
548 | } | |
549 | ||
550 | SEXP allow_time_truncate = options["allow_time_truncate"]; | |
551 | if (!Rf_isNull(allow_time_truncate) && cpp11::as_cpp<bool>(allow_time_truncate)) { | |
552 | out->allow_time_truncate = cpp11::as_cpp<bool>(allow_time_truncate); | |
553 | } | |
554 | ||
555 | SEXP allow_int_overflow = options["allow_int_overflow"]; | |
556 | if (!Rf_isNull(allow_int_overflow) && cpp11::as_cpp<bool>(allow_int_overflow)) { | |
557 | out->allow_int_overflow = cpp11::as_cpp<bool>(allow_int_overflow); | |
558 | } | |
559 | return out; | |
560 | } | |
561 | ||
562 | // [[arrow::export]] | |
563 | SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list options) { | |
564 | auto opts = make_compute_options(func_name, options); | |
565 | auto datum_args = arrow::r::from_r_list<arrow::Datum>(args); | |
566 | auto out = ValueOrStop( | |
567 | arrow::compute::CallFunction(func_name, datum_args, opts.get(), gc_context())); | |
568 | return from_datum(std::move(out)); | |
569 | } | |
570 | ||
571 | // [[arrow::export]] | |
572 | std::vector<std::string> compute__GetFunctionNames() { | |
573 | return arrow::compute::GetFunctionRegistry()->GetFunctionNames(); | |
574 | } | |
575 | ||
576 | #endif |