1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <arrow/api.h>
#include <arrow/compute/api.h>
#include <arrow/compute/exec/exec_plan.h>
#include <arrow/compute/exec/expression.h>
#include <arrow/compute/exec/options.h>
#include <arrow/util/async_generator.h>
#include <arrow/util/future.h>
30 // Demonstrate registering an Arrow compute function outside of the Arrow source tree
32 namespace cp
= ::arrow::compute
;
// Evaluates `expr` (an expression yielding arrow::Status) exactly once. If the
// status is not OK, prints its message to stderr and aborts the process.
// The do { ... } while (0) wrapper makes the macro behave as a single statement
// (safe inside unbraced if/else); the caller supplies the trailing semicolon.
#define ABORT_ON_FAILURE(expr)                     \
  do {                                             \
    arrow::Status status_ = (expr);                \
    if (!status_.ok()) {                           \
      std::cerr << status_.message() << std::endl; \
      std::abort();                                \
    }                                              \
  } while (0)
43 class ExampleFunctionOptionsType
: public cp::FunctionOptionsType
{
44 const char* type_name() const override
{ return "ExampleFunctionOptionsType"; }
45 std::string
Stringify(const cp::FunctionOptions
&) const override
{
46 return "ExampleFunctionOptionsType";
48 bool Compare(const cp::FunctionOptions
&, const cp::FunctionOptions
&) const override
{
51 std::unique_ptr
<cp::FunctionOptions
> Copy(const cp::FunctionOptions
&) const override
;
52 // optional: support for serialization
53 // Result<std::shared_ptr<Buffer>> Serialize(const FunctionOptions&) const override;
54 // Result<std::unique_ptr<FunctionOptions>> Deserialize(const Buffer&) const override;
57 cp::FunctionOptionsType
* GetExampleFunctionOptionsType() {
58 static ExampleFunctionOptionsType options_type
;
62 class ExampleFunctionOptions
: public cp::FunctionOptions
{
64 ExampleFunctionOptions() : cp::FunctionOptions(GetExampleFunctionOptionsType()) {}
67 std::unique_ptr
<cp::FunctionOptions
> ExampleFunctionOptionsType::Copy(
68 const cp::FunctionOptions
&) const {
69 return std::unique_ptr
<cp::FunctionOptions
>(new ExampleFunctionOptions());
72 arrow::Status
ExampleFunctionImpl(cp::KernelContext
* ctx
, const cp::ExecBatch
& batch
,
74 *out
->mutable_array() = *batch
[0].array();
75 return arrow::Status::OK();
78 class ExampleNodeOptions
: public cp::ExecNodeOptions
{};
80 // a basic ExecNode which ignores all input batches
81 class ExampleNode
: public cp::ExecNode
{
83 ExampleNode(ExecNode
* input
, const ExampleNodeOptions
&)
84 : ExecNode(/*plan=*/input
->plan(), /*inputs=*/{input
},
85 /*input_labels=*/{"ignored"},
86 /*output_schema=*/input
->output_schema(), /*num_outputs=*/1) {}
88 const char* kind_name() const override
{ return "ExampleNode"; }
90 arrow::Status
StartProducing() override
{
91 outputs_
[0]->InputFinished(this, 0);
92 return arrow::Status::OK();
95 void ResumeProducing(ExecNode
* output
) override
{}
96 void PauseProducing(ExecNode
* output
) override
{}
98 void StopProducing(ExecNode
* output
) override
{ inputs_
[0]->StopProducing(this); }
99 void StopProducing() override
{ inputs_
[0]->StopProducing(); }
101 void InputReceived(ExecNode
* input
, cp::ExecBatch batch
) override
{}
102 void ErrorReceived(ExecNode
* input
, arrow::Status error
) override
{}
103 void InputFinished(ExecNode
* input
, int total_batches
) override
{}
105 arrow::Future
<> finished() override
{ return inputs_
[0]->finished(); }
108 arrow::Result
<cp::ExecNode
*> ExampleExecNodeFactory(cp::ExecPlan
* plan
,
109 std::vector
<cp::ExecNode
*> inputs
,
110 const cp::ExecNodeOptions
& options
) {
111 const auto& example_options
=
112 arrow::internal::checked_cast
<const ExampleNodeOptions
&>(options
);
114 return plan
->EmplaceNode
<ExampleNode
>(inputs
[0], example_options
);
117 const cp::FunctionDoc func_doc
{
118 "Example function to demonstrate registering an out-of-tree function",
121 "ExampleFunctionOptions"};
123 int main(int argc
, char** argv
) {
124 const std::string name
= "compute_register_example";
125 auto func
= std::make_shared
<cp::ScalarFunction
>(name
, cp::Arity::Unary(), &func_doc
);
126 ABORT_ON_FAILURE(func
->AddKernel({cp::InputType::Array(arrow::int64())}, arrow::int64(),
127 ExampleFunctionImpl
));
129 auto registry
= cp::GetFunctionRegistry();
130 ABORT_ON_FAILURE(registry
->AddFunction(std::move(func
)));
132 arrow::Int64Builder
builder(arrow::default_memory_pool());
133 std::shared_ptr
<arrow::Array
> arr
;
134 ABORT_ON_FAILURE(builder
.Append(42));
135 ABORT_ON_FAILURE(builder
.Finish(&arr
));
136 auto options
= std::make_shared
<ExampleFunctionOptions
>();
137 auto maybe_result
= cp::CallFunction(name
, {arr
}, options
.get());
138 ABORT_ON_FAILURE(maybe_result
.status());
140 std::cout
<< maybe_result
->make_array()->ToString() << std::endl
;
142 // Expression serialization will raise NotImplemented if an expression includes
143 // FunctionOptions for which serialization is not supported.
144 auto expr
= cp::call(name
, {}, options
);
145 auto maybe_serialized
= cp::Serialize(expr
);
146 std::cerr
<< maybe_serialized
.status().ToString() << std::endl
;
148 auto exec_registry
= cp::default_exec_factory_registry();
150 exec_registry
->AddFactory("compute_register_example", ExampleExecNodeFactory
));
152 auto maybe_plan
= cp::ExecPlan::Make();
153 ABORT_ON_FAILURE(maybe_plan
.status());
154 auto plan
= maybe_plan
.ValueOrDie();
156 arrow::AsyncGenerator
<arrow::util::optional
<cp::ExecBatch
>> source_gen
, sink_gen
;
158 cp::Declaration::Sequence(
160 {"source", cp::SourceNodeOptions
{arrow::schema({}), source_gen
}},
161 {"compute_register_example", ExampleNodeOptions
{}},
162 {"sink", cp::SinkNodeOptions
{&sink_gen
}},
164 .AddToPlan(plan
.get())