1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
24 #include "arrow/compute/exec/options.h"
25 #include "arrow/compute/exec/schema_util.h"
26 #include "arrow/compute/exec/task_util.h"
27 #include "arrow/result.h"
28 #include "arrow/status.h"
29 #include "arrow/type.h"
34 class ARROW_EXPORT HashJoinSchema
{
36 Status
Init(JoinType join_type
, const Schema
& left_schema
,
37 const std::vector
<FieldRef
>& left_keys
, const Schema
& right_schema
,
38 const std::vector
<FieldRef
>& right_keys
,
39 const std::string
& left_field_name_prefix
,
40 const std::string
& right_field_name_prefix
);
42 Status
Init(JoinType join_type
, const Schema
& left_schema
,
43 const std::vector
<FieldRef
>& left_keys
,
44 const std::vector
<FieldRef
>& left_output
, const Schema
& right_schema
,
45 const std::vector
<FieldRef
>& right_keys
,
46 const std::vector
<FieldRef
>& right_output
,
47 const std::string
& left_field_name_prefix
,
48 const std::string
& right_field_name_prefix
);
50 static Status
ValidateSchemas(JoinType join_type
, const Schema
& left_schema
,
51 const std::vector
<FieldRef
>& left_keys
,
52 const std::vector
<FieldRef
>& left_output
,
53 const Schema
& right_schema
,
54 const std::vector
<FieldRef
>& right_keys
,
55 const std::vector
<FieldRef
>& right_output
,
56 const std::string
& left_field_name_prefix
,
57 const std::string
& right_field_name_prefix
);
59 std::shared_ptr
<Schema
> MakeOutputSchema(const std::string
& left_field_name_prefix
,
60 const std::string
& right_field_name_prefix
);
// Static accessor that returns
// SchemaProjectionMaps<HashJoinProjection>::kMissingField as an int.
// NOTE(review): presumably the sentinel index for a field that is absent from
// a projection - confirm in schema_util.h.
62 static int kMissingField() {
63 return SchemaProjectionMaps
<HashJoinProjection
>::kMissingField
;
66 SchemaProjectionMaps
<HashJoinProjection
> proj_maps
[2];
69 static bool IsTypeSupported(const DataType
& type
);
70 static Result
<std::vector
<FieldRef
>> VectorDiff(const Schema
& schema
,
71 const std::vector
<FieldRef
>& a
,
72 const std::vector
<FieldRef
>& b
);
77 using OutputBatchCallback
= std::function
<void(ExecBatch
)>;
78 using FinishedCallback
= std::function
<void(int64_t)>;
80 virtual ~HashJoinImpl() = default;
81 virtual Status
Init(ExecContext
* ctx
, JoinType join_type
, bool use_sync_execution
,
82 size_t num_threads
, HashJoinSchema
* schema_mgr
,
83 std::vector
<JoinKeyCmp
> key_cmp
,
84 OutputBatchCallback output_batch_callback
,
85 FinishedCallback finished_callback
,
86 TaskScheduler::ScheduleImpl schedule_task_callback
) = 0;
87 virtual Status
InputReceived(size_t thread_index
, int side
, ExecBatch batch
) = 0;
88 virtual Status
InputFinished(size_t thread_index
, int side
) = 0;
89 virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback
) = 0;
91 static Result
<std::unique_ptr
<HashJoinImpl
>> MakeBasic();
94 } // namespace compute