]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include "arrow/compute/exec/exec_plan.h" | |
19 | ||
20 | #include <sstream> | |
21 | ||
22 | #include "arrow/compute/api_vector.h" | |
23 | #include "arrow/compute/exec.h" | |
24 | #include "arrow/compute/exec/expression.h" | |
25 | #include "arrow/compute/exec/options.h" | |
26 | #include "arrow/compute/exec/util.h" | |
27 | #include "arrow/datum.h" | |
28 | #include "arrow/result.h" | |
29 | #include "arrow/util/checked_cast.h" | |
30 | #include "arrow/util/future.h" | |
31 | #include "arrow/util/logging.h" | |
32 | ||
33 | namespace arrow { | |
34 | ||
35 | using internal::checked_cast; | |
36 | ||
37 | namespace compute { | |
38 | namespace { | |
39 | ||
40 | class ProjectNode : public MapNode { | |
41 | public: | |
42 | ProjectNode(ExecPlan* plan, std::vector<ExecNode*> inputs, | |
43 | std::shared_ptr<Schema> output_schema, std::vector<Expression> exprs, | |
44 | bool async_mode) | |
45 | : MapNode(plan, std::move(inputs), std::move(output_schema), async_mode), | |
46 | exprs_(std::move(exprs)) {} | |
47 | ||
48 | static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs, | |
49 | const ExecNodeOptions& options) { | |
50 | RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "ProjectNode")); | |
51 | ||
52 | const auto& project_options = checked_cast<const ProjectNodeOptions&>(options); | |
53 | auto exprs = project_options.expressions; | |
54 | auto names = project_options.names; | |
55 | ||
56 | if (names.size() == 0) { | |
57 | names.resize(exprs.size()); | |
58 | for (size_t i = 0; i < exprs.size(); ++i) { | |
59 | names[i] = exprs[i].ToString(); | |
60 | } | |
61 | } | |
62 | ||
63 | FieldVector fields(exprs.size()); | |
64 | int i = 0; | |
65 | for (auto& expr : exprs) { | |
66 | if (!expr.IsBound()) { | |
67 | ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(*inputs[0]->output_schema())); | |
68 | } | |
69 | fields[i] = field(std::move(names[i]), expr.type()); | |
70 | ++i; | |
71 | } | |
72 | return plan->EmplaceNode<ProjectNode>(plan, std::move(inputs), | |
73 | schema(std::move(fields)), std::move(exprs), | |
74 | project_options.async_mode); | |
75 | } | |
76 | ||
77 | const char* kind_name() const override { return "ProjectNode"; } | |
78 | ||
79 | Result<ExecBatch> DoProject(const ExecBatch& target) { | |
80 | std::vector<Datum> values{exprs_.size()}; | |
81 | for (size_t i = 0; i < exprs_.size(); ++i) { | |
82 | ARROW_ASSIGN_OR_RAISE(Expression simplified_expr, | |
83 | SimplifyWithGuarantee(exprs_[i], target.guarantee)); | |
84 | ||
85 | ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(simplified_expr, target, | |
86 | plan()->exec_context())); | |
87 | } | |
88 | return ExecBatch{std::move(values), target.length}; | |
89 | } | |
90 | ||
91 | void InputReceived(ExecNode* input, ExecBatch batch) override { | |
92 | DCHECK_EQ(input, inputs_[0]); | |
93 | auto func = [this](ExecBatch batch) { return DoProject(std::move(batch)); }; | |
94 | this->SubmitTask(std::move(func), std::move(batch)); | |
95 | } | |
96 | ||
97 | protected: | |
98 | std::string ToStringExtra() const override { | |
99 | std::stringstream ss; | |
100 | ss << "projection=["; | |
101 | for (int i = 0; static_cast<size_t>(i) < exprs_.size(); i++) { | |
102 | if (i > 0) ss << ", "; | |
103 | auto repr = exprs_[i].ToString(); | |
104 | if (repr != output_schema_->field(i)->name()) { | |
105 | ss << '"' << output_schema_->field(i)->name() << "\": "; | |
106 | } | |
107 | ss << repr; | |
108 | } | |
109 | ss << ']'; | |
110 | return ss.str(); | |
111 | } | |
112 | ||
113 | private: | |
114 | std::vector<Expression> exprs_; | |
115 | }; | |
116 | ||
117 | } // namespace | |
118 | ||
119 | namespace internal { | |
120 | ||
121 | void RegisterProjectNode(ExecFactoryRegistry* registry) { | |
122 | DCHECK_OK(registry->AddFactory("project", ProjectNode::Make)); | |
123 | } | |
124 | ||
125 | } // namespace internal | |
126 | } // namespace compute | |
127 | } // namespace arrow |