]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include "./arrow_types.h" | |
19 | ||
20 | #if defined(ARROW_R_WITH_ARROW) | |
21 | ||
22 | #include <arrow/compute/api.h> | |
23 | #include <arrow/compute/exec/exec_plan.h> | |
24 | #include <arrow/compute/exec/expression.h> | |
25 | #include <arrow/compute/exec/options.h> | |
26 | #include <arrow/table.h> | |
27 | #include <arrow/util/async_generator.h> | |
28 | #include <arrow/util/future.h> | |
29 | #include <arrow/util/optional.h> | |
30 | #include <arrow/util/thread_pool.h> | |
31 | ||
32 | #include <iostream> | |
33 | ||
34 | namespace compute = ::arrow::compute; | |
35 | ||
36 | std::shared_ptr<compute::FunctionOptions> make_compute_options(std::string func_name, | |
37 | cpp11::list options); | |
38 | ||
39 | // [[arrow::export]] | |
40 | std::shared_ptr<compute::ExecPlan> ExecPlan_create(bool use_threads) { | |
41 | static compute::ExecContext threaded_context{gc_memory_pool(), | |
42 | arrow::internal::GetCpuThreadPool()}; | |
43 | auto plan = ValueOrStop( | |
44 | compute::ExecPlan::Make(use_threads ? &threaded_context : gc_context())); | |
45 | return plan; | |
46 | } | |
47 | ||
48 | std::shared_ptr<compute::ExecNode> MakeExecNodeOrStop( | |
49 | const std::string& factory_name, compute::ExecPlan* plan, | |
50 | std::vector<compute::ExecNode*> inputs, const compute::ExecNodeOptions& options) { | |
51 | return std::shared_ptr<compute::ExecNode>( | |
52 | ValueOrStop(compute::MakeExecNode(factory_name, plan, std::move(inputs), options)), | |
53 | [](...) { | |
54 | // empty destructor: ExecNode lifetime is managed by an ExecPlan | |
55 | }); | |
56 | } | |
57 | ||
58 | // [[arrow::export]] | |
59 | std::shared_ptr<arrow::RecordBatchReader> ExecPlan_run( | |
60 | const std::shared_ptr<compute::ExecPlan>& plan, | |
61 | const std::shared_ptr<compute::ExecNode>& final_node, cpp11::list sort_options, | |
62 | int64_t head = -1) { | |
63 | // For now, don't require R to construct SinkNodes. | |
64 | // Instead, just pass the node we should collect as an argument. | |
65 | arrow::AsyncGenerator<arrow::util::optional<compute::ExecBatch>> sink_gen; | |
66 | ||
67 | // Sorting uses a different sink node; there is no general sort yet | |
68 | if (sort_options.size() > 0) { | |
69 | if (head >= 0) { | |
70 | // Use the SelectK node to take only what we need | |
71 | MakeExecNodeOrStop( | |
72 | "select_k_sink", plan.get(), {final_node.get()}, | |
73 | compute::SelectKSinkNodeOptions{ | |
74 | arrow::compute::SelectKOptions( | |
75 | head, std::dynamic_pointer_cast<compute::SortOptions>( | |
76 | make_compute_options("sort_indices", sort_options)) | |
77 | ->sort_keys), | |
78 | &sink_gen}); | |
79 | } else { | |
80 | MakeExecNodeOrStop("order_by_sink", plan.get(), {final_node.get()}, | |
81 | compute::OrderBySinkNodeOptions{ | |
82 | *std::dynamic_pointer_cast<compute::SortOptions>( | |
83 | make_compute_options("sort_indices", sort_options)), | |
84 | &sink_gen}); | |
85 | } | |
86 | } else { | |
87 | MakeExecNodeOrStop("sink", plan.get(), {final_node.get()}, | |
88 | compute::SinkNodeOptions{&sink_gen}); | |
89 | } | |
90 | ||
91 | StopIfNotOk(plan->Validate()); | |
92 | StopIfNotOk(plan->StartProducing()); | |
93 | ||
94 | // If the generator is destroyed before being completely drained, inform plan | |
95 | std::shared_ptr<void> stop_producing{nullptr, [plan](...) { | |
96 | bool not_finished_yet = | |
97 | plan->finished().TryAddCallback([&plan] { | |
98 | return [plan](const arrow::Status&) {}; | |
99 | }); | |
100 | ||
101 | if (not_finished_yet) { | |
102 | plan->StopProducing(); | |
103 | } | |
104 | }}; | |
105 | ||
106 | return compute::MakeGeneratorReader( | |
107 | final_node->output_schema(), | |
108 | [stop_producing, plan, sink_gen] { return sink_gen(); }, gc_memory_pool()); | |
109 | } | |
110 | ||
111 | // [[arrow::export]] | |
112 | void ExecPlan_StopProducing(const std::shared_ptr<compute::ExecPlan>& plan) { | |
113 | plan->StopProducing(); | |
114 | } | |
115 | ||
116 | #if defined(ARROW_R_WITH_DATASET) | |
117 | ||
118 | #include <arrow/dataset/plan.h> | |
119 | #include <arrow/dataset/scanner.h> | |
120 | ||
121 | // [[dataset::export]] | |
122 | std::shared_ptr<arrow::Schema> ExecNode_output_schema( | |
123 | const std::shared_ptr<compute::ExecNode>& node) { | |
124 | return node->output_schema(); | |
125 | } | |
126 | ||
127 | // [[dataset::export]] | |
128 | std::shared_ptr<compute::ExecNode> ExecNode_Scan( | |
129 | const std::shared_ptr<compute::ExecPlan>& plan, | |
130 | const std::shared_ptr<arrow::dataset::Dataset>& dataset, | |
131 | const std::shared_ptr<compute::Expression>& filter, | |
132 | std::vector<std::string> materialized_field_names) { | |
133 | arrow::dataset::internal::Initialize(); | |
134 | ||
135 | // TODO: pass in FragmentScanOptions | |
136 | auto options = std::make_shared<arrow::dataset::ScanOptions>(); | |
137 | ||
138 | options->use_async = true; | |
139 | options->use_threads = arrow::r::GetBoolOption("arrow.use_threads", true); | |
140 | ||
141 | options->dataset_schema = dataset->schema(); | |
142 | ||
143 | // ScanNode needs the filter to do predicate pushdown and skip partitions | |
144 | options->filter = ValueOrStop(filter->Bind(*dataset->schema())); | |
145 | ||
146 | // ScanNode needs to know which fields to materialize (and which are unnecessary) | |
147 | std::vector<compute::Expression> exprs; | |
148 | for (const auto& name : materialized_field_names) { | |
149 | exprs.push_back(compute::field_ref(name)); | |
150 | } | |
151 | ||
152 | options->projection = | |
153 | ValueOrStop(call("make_struct", std::move(exprs), | |
154 | compute::MakeStructOptions{std::move(materialized_field_names)}) | |
155 | .Bind(*dataset->schema())); | |
156 | ||
157 | return MakeExecNodeOrStop("scan", plan.get(), {}, | |
158 | arrow::dataset::ScanNodeOptions{dataset, options}); | |
159 | } | |
160 | ||
161 | #endif | |
162 | ||
163 | // [[dataset::export]] | |
164 | std::shared_ptr<compute::ExecNode> ExecNode_Filter( | |
165 | const std::shared_ptr<compute::ExecNode>& input, | |
166 | const std::shared_ptr<compute::Expression>& filter) { | |
167 | return MakeExecNodeOrStop("filter", input->plan(), {input.get()}, | |
168 | compute::FilterNodeOptions{*filter}); | |
169 | } | |
170 | ||
171 | // [[dataset::export]] | |
172 | std::shared_ptr<compute::ExecNode> ExecNode_Project( | |
173 | const std::shared_ptr<compute::ExecNode>& input, | |
174 | const std::vector<std::shared_ptr<compute::Expression>>& exprs, | |
175 | std::vector<std::string> names) { | |
176 | // We have shared_ptrs of expressions but need the Expressions | |
177 | std::vector<compute::Expression> expressions; | |
178 | for (auto expr : exprs) { | |
179 | expressions.push_back(*expr); | |
180 | } | |
181 | return MakeExecNodeOrStop( | |
182 | "project", input->plan(), {input.get()}, | |
183 | compute::ProjectNodeOptions{std::move(expressions), std::move(names)}); | |
184 | } | |
185 | ||
186 | // [[dataset::export]] | |
187 | std::shared_ptr<compute::ExecNode> ExecNode_Aggregate( | |
188 | const std::shared_ptr<compute::ExecNode>& input, cpp11::list options, | |
189 | std::vector<std::string> target_names, std::vector<std::string> out_field_names, | |
190 | std::vector<std::string> key_names) { | |
191 | std::vector<arrow::compute::internal::Aggregate> aggregates; | |
192 | std::vector<std::shared_ptr<arrow::compute::FunctionOptions>> keep_alives; | |
193 | ||
194 | for (cpp11::list name_opts : options) { | |
195 | auto name = cpp11::as_cpp<std::string>(name_opts[0]); | |
196 | auto opts = make_compute_options(name, name_opts[1]); | |
197 | ||
198 | aggregates.push_back( | |
199 | arrow::compute::internal::Aggregate{std::move(name), opts.get()}); | |
200 | keep_alives.push_back(std::move(opts)); | |
201 | } | |
202 | ||
203 | std::vector<arrow::FieldRef> targets, keys; | |
204 | for (auto&& name : target_names) { | |
205 | targets.emplace_back(std::move(name)); | |
206 | } | |
207 | for (auto&& name : key_names) { | |
208 | keys.emplace_back(std::move(name)); | |
209 | } | |
210 | return MakeExecNodeOrStop( | |
211 | "aggregate", input->plan(), {input.get()}, | |
212 | compute::AggregateNodeOptions{std::move(aggregates), std::move(targets), | |
213 | std::move(out_field_names), std::move(keys)}); | |
214 | } | |
215 | ||
216 | // [[dataset::export]] | |
217 | std::shared_ptr<compute::ExecNode> ExecNode_Join( | |
218 | const std::shared_ptr<compute::ExecNode>& input, int type, | |
219 | const std::shared_ptr<compute::ExecNode>& right_data, | |
220 | std::vector<std::string> left_keys, std::vector<std::string> right_keys, | |
221 | std::vector<std::string> left_output, std::vector<std::string> right_output) { | |
222 | std::vector<arrow::FieldRef> left_refs, right_refs, left_out_refs, right_out_refs; | |
223 | for (auto&& name : left_keys) { | |
224 | left_refs.emplace_back(std::move(name)); | |
225 | } | |
226 | for (auto&& name : right_keys) { | |
227 | right_refs.emplace_back(std::move(name)); | |
228 | } | |
229 | for (auto&& name : left_output) { | |
230 | left_out_refs.emplace_back(std::move(name)); | |
231 | } | |
232 | if (type != 0 && type != 2) { | |
233 | // Don't include out_refs in semi/anti join | |
234 | for (auto&& name : right_output) { | |
235 | right_out_refs.emplace_back(std::move(name)); | |
236 | } | |
237 | } | |
238 | ||
239 | // TODO: we should be able to use this enum directly | |
240 | compute::JoinType join_type; | |
241 | if (type == 0) { | |
242 | join_type = compute::JoinType::LEFT_SEMI; | |
243 | } else if (type == 1) { | |
244 | // Not readily called from R bc dplyr::semi_join is LEFT_SEMI | |
245 | join_type = compute::JoinType::RIGHT_SEMI; | |
246 | } else if (type == 2) { | |
247 | join_type = compute::JoinType::LEFT_ANTI; | |
248 | } else if (type == 3) { | |
249 | // Not readily called from R bc dplyr::semi_join is LEFT_SEMI | |
250 | join_type = compute::JoinType::RIGHT_ANTI; | |
251 | } else if (type == 4) { | |
252 | join_type = compute::JoinType::INNER; | |
253 | } else if (type == 5) { | |
254 | join_type = compute::JoinType::LEFT_OUTER; | |
255 | } else if (type == 6) { | |
256 | join_type = compute::JoinType::RIGHT_OUTER; | |
257 | } else if (type == 7) { | |
258 | join_type = compute::JoinType::FULL_OUTER; | |
259 | } else { | |
260 | cpp11::stop("todo"); | |
261 | } | |
262 | ||
263 | return MakeExecNodeOrStop( | |
264 | "hashjoin", input->plan(), {input.get(), right_data.get()}, | |
265 | compute::HashJoinNodeOptions{join_type, std::move(left_refs), std::move(right_refs), | |
266 | std::move(left_out_refs), std::move(right_out_refs)}); | |
267 | } | |
268 | ||
269 | // [[arrow::export]] | |
270 | std::shared_ptr<compute::ExecNode> ExecNode_ReadFromRecordBatchReader( | |
271 | const std::shared_ptr<compute::ExecPlan>& plan, | |
272 | const std::shared_ptr<arrow::RecordBatchReader>& reader) { | |
273 | arrow::compute::SourceNodeOptions options{ | |
274 | /*output_schema=*/reader->schema(), | |
275 | /*generator=*/ValueOrStop( | |
276 | compute::MakeReaderGenerator(reader, arrow::internal::GetCpuThreadPool()))}; | |
277 | ||
278 | return MakeExecNodeOrStop("source", plan.get(), {}, options); | |
279 | } | |
280 | ||
281 | #endif |