1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
18 #include "gandiva/llvm_generator.h"
27 #include "gandiva/bitmap_accumulator.h"
28 #include "gandiva/decimal_ir.h"
29 #include "gandiva/dex.h"
30 #include "gandiva/expr_decomposer.h"
31 #include "gandiva/expression.h"
32 #include "gandiva/lvalue.h"
// Emit an IR trace through AddTrace() only when IR tracing is enabled on this
// generator. The body is wrapped in do { } while (0) so that `ADD_TRACE(...);`
// expands to exactly one statement and composes safely with an unbraced if/else
// at the call site (a bare `if (...) { }` followed by `;` would orphan the `else`).
#define ADD_TRACE(...)       \
  do {                       \
    if (enable_ir_traces_) { \
      AddTrace(__VA_ARGS__); \
    }                        \
  } while (0)
41 LLVMGenerator::LLVMGenerator() : enable_ir_traces_(false) {}
43 Status
LLVMGenerator::Make(std::shared_ptr
<Configuration
> config
,
44 std::unique_ptr
<LLVMGenerator
>* llvm_generator
) {
45 std::unique_ptr
<LLVMGenerator
> llvmgen_obj(new LLVMGenerator());
47 ARROW_RETURN_NOT_OK(Engine::Make(config
, &(llvmgen_obj
->engine_
)));
48 *llvm_generator
= std::move(llvmgen_obj
);
53 Status
LLVMGenerator::Add(const ExpressionPtr expr
, const FieldDescriptorPtr output
) {
54 int idx
= static_cast<int>(compiled_exprs_
.size());
55 // decompose the expression to separate out value and validities.
56 ExprDecomposer
decomposer(function_registry_
, annotator_
);
57 ValueValidityPairPtr value_validity
;
58 ARROW_RETURN_NOT_OK(decomposer
.Decompose(*expr
->root(), &value_validity
));
59 // Generate the IR function for the decomposed expression.
60 std::unique_ptr
<CompiledExpr
> compiled_expr(new CompiledExpr(value_validity
, output
));
61 llvm::Function
* ir_function
= nullptr;
62 ARROW_RETURN_NOT_OK(CodeGenExprValue(value_validity
->value_expr(),
63 annotator_
.buffer_count(), output
, idx
,
64 &ir_function
, selection_vector_mode_
));
65 compiled_expr
->SetIRFunction(selection_vector_mode_
, ir_function
);
67 compiled_exprs_
.push_back(std::move(compiled_expr
));
71 /// Build and optimise module for projection expression.
72 Status
LLVMGenerator::Build(const ExpressionVector
& exprs
, SelectionVector::Mode mode
) {
73 selection_vector_mode_
= mode
;
74 for (auto& expr
: exprs
) {
75 auto output
= annotator_
.AddOutputFieldDescriptor(expr
->result());
76 ARROW_RETURN_NOT_OK(Add(expr
, output
));
79 // Compile and inject into the process' memory the generated function.
80 ARROW_RETURN_NOT_OK(engine_
->FinalizeModule());
82 // setup the jit functions for each expression.
83 for (auto& compiled_expr
: compiled_exprs_
) {
84 auto ir_fn
= compiled_expr
->GetIRFunction(mode
);
85 auto jit_fn
= reinterpret_cast<EvalFunc
>(engine_
->CompiledFunction(ir_fn
));
86 compiled_expr
->SetJITFunction(selection_vector_mode_
, jit_fn
);
92 /// Execute the compiled module against the provided vectors.
93 Status
LLVMGenerator::Execute(const arrow::RecordBatch
& record_batch
,
94 const ArrayDataVector
& output_vector
) {
95 return Execute(record_batch
, nullptr, output_vector
);
98 /// Execute the compiled module against the provided vectors based on the type of
100 Status
LLVMGenerator::Execute(const arrow::RecordBatch
& record_batch
,
101 const SelectionVector
* selection_vector
,
102 const ArrayDataVector
& output_vector
) {
103 DCHECK_GT(record_batch
.num_rows(), 0);
105 auto eval_batch
= annotator_
.PrepareEvalBatch(record_batch
, output_vector
);
106 DCHECK_GT(eval_batch
->GetNumBuffers(), 0);
108 auto mode
= SelectionVector::MODE_NONE
;
109 if (selection_vector
!= nullptr) {
110 mode
= selection_vector
->GetMode();
112 if (mode
!= selection_vector_mode_
) {
113 return Status::Invalid("llvm expression built for selection vector mode ",
114 selection_vector_mode_
, " received vector with mode ", mode
);
117 for (auto& compiled_expr
: compiled_exprs_
) {
118 // generate data/offset vectors.
119 const uint8_t* selection_buffer
= nullptr;
120 auto num_output_rows
= record_batch
.num_rows();
121 if (selection_vector
!= nullptr) {
122 selection_buffer
= selection_vector
->GetBuffer().data();
123 num_output_rows
= selection_vector
->GetNumSlots();
126 EvalFunc jit_function
= compiled_expr
->GetJITFunction(mode
);
127 jit_function(eval_batch
->GetBufferArray(), eval_batch
->GetBufferOffsetArray(),
128 eval_batch
->GetLocalBitMapArray(), selection_buffer
,
129 (int64_t)eval_batch
->GetExecutionContext(), num_output_rows
);
131 // check for execution errors
133 eval_batch
->GetExecutionContext()->has_error(),
134 Status::ExecutionError(eval_batch
->GetExecutionContext()->get_error()));
136 // generate validity vectors.
137 ComputeBitMapsForExpr(*compiled_expr
, *eval_batch
, selection_vector
);
143 llvm::Value
* LLVMGenerator::LoadVectorAtIndex(llvm::Value
* arg_addrs
, int idx
,
144 const std::string
& name
) {
145 auto* idx_val
= types()->i32_constant(idx
);
146 auto* offset
= CreateGEP(ir_builder(), arg_addrs
, idx_val
, name
+ "_mem_addr");
147 return CreateLoad(ir_builder(), offset
, name
+ "_mem");
150 /// Get reference to validity array at specified index in the args list.
151 llvm::Value
* LLVMGenerator::GetValidityReference(llvm::Value
* arg_addrs
, int idx
,
153 const std::string
& name
= field
->name();
154 llvm::Value
* load
= LoadVectorAtIndex(arg_addrs
, idx
, name
);
155 return ir_builder()->CreateIntToPtr(load
, types()->i64_ptr_type(), name
+ "_varray");
158 /// Get reference to data array at specified index in the args list.
159 llvm::Value
* LLVMGenerator::GetDataBufferPtrReference(llvm::Value
* arg_addrs
, int idx
,
161 const std::string
& name
= field
->name();
162 llvm::Value
* load
= LoadVectorAtIndex(arg_addrs
, idx
, name
);
163 return ir_builder()->CreateIntToPtr(load
, types()->i8_ptr_type(), name
+ "_buf_ptr");
166 /// Get reference to data array at specified index in the args list.
167 llvm::Value
* LLVMGenerator::GetDataReference(llvm::Value
* arg_addrs
, int idx
,
169 const std::string
& name
= field
->name();
170 llvm::Value
* load
= LoadVectorAtIndex(arg_addrs
, idx
, name
);
171 llvm::Type
* base_type
= types()->DataVecType(field
->type());
173 if (base_type
->isPointerTy()) {
174 ret
= ir_builder()->CreateIntToPtr(load
, base_type
, name
+ "_darray");
176 llvm::Type
* pointer_type
= types()->ptr_type(base_type
);
177 ret
= ir_builder()->CreateIntToPtr(load
, pointer_type
, name
+ "_darray");
182 /// Get reference to offsets array at specified index in the args list.
183 llvm::Value
* LLVMGenerator::GetOffsetsReference(llvm::Value
* arg_addrs
, int idx
,
185 const std::string
& name
= field
->name();
186 llvm::Value
* load
= LoadVectorAtIndex(arg_addrs
, idx
, name
);
187 return ir_builder()->CreateIntToPtr(load
, types()->i32_ptr_type(), name
+ "_oarray");
190 /// Get reference to local bitmap array at specified index in the args list.
191 llvm::Value
* LLVMGenerator::GetLocalBitMapReference(llvm::Value
* arg_bitmaps
, int idx
) {
192 llvm::Value
* load
= LoadVectorAtIndex(arg_bitmaps
, idx
, "");
193 return ir_builder()->CreateIntToPtr(load
, types()->i64_ptr_type(),
194 std::to_string(idx
) + "_lbmap");
197 /// \brief Generate code for one expression.
199 // Sample IR code for "c1:int + c2:int"
201 // The C-code equivalent is :
202 // ------------------------------
203 // int expr_0(int64_t *addrs, int64_t *local_bitmaps,
204 // int64_t execution_context_ptr, int64_t nrecords) {
205 // int *outVec = (int *) addrs[5];
206 // int *c0Vec = (int *) addrs[1];
207 // int *c1Vec = (int *) addrs[3];
208 // for (int loop_var = 0; loop_var < nrecords; ++loop_var) {
209 // int c0 = c0Vec[loop_var];
210 // int c1 = c1Vec[loop_var];
211 // int out = c0 + c1;
212 // outVec[loop_var] = out;
219 // define i32 @expr_0(i64* %args, i64* %local_bitmaps, i64 %execution_context_ptr, , i64
220 // %nrecords) { entry:
221 // %outmemAddr = getelementptr i64, i64* %args, i32 5
222 // %outmem = load i64, i64* %outmemAddr
223 // %outVec = inttoptr i64 %outmem to i32*
224 // %c0memAddr = getelementptr i64, i64* %args, i32 1
225 // %c0mem = load i64, i64* %c0memAddr
226 // %c0Vec = inttoptr i64 %c0mem to i32*
227 // %c1memAddr = getelementptr i64, i64* %args, i32 3
228 // %c1mem = load i64, i64* %c1memAddr
229 // %c1Vec = inttoptr i64 %c1mem to i32*
231 // loop: ; preds = %loop, %entry
232 // %loop_var = phi i64 [ 0, %entry ], [ %"loop_var+1", %loop ]
233 // %"loop_var+1" = add i64 %loop_var, 1
234 // %0 = getelementptr i32, i32* %c0Vec, i32 %loop_var
235 // %c0 = load i32, i32* %0
236 // %1 = getelementptr i32, i32* %c1Vec, i32 %loop_var
237 // %c1 = load i32, i32* %1
238 // %add_int_int = call i32 @add_int_int(i32 %c0, i32 %c1)
239 // %2 = getelementptr i32, i32* %outVec, i32 %loop_var
240 // store i32 %add_int_int, i32* %2
241 // %"loop_var < nrec" = icmp slt i64 %"loop_var+1", %nrecords
242 // br i1 %"loop_var < nrec", label %loop, label %exit
243 // exit: ; preds = %loop
246 Status
LLVMGenerator::CodeGenExprValue(DexPtr value_expr
, int buffer_count
,
247 FieldDescriptorPtr output
, int suffix_idx
,
249 SelectionVector::Mode selection_vector_mode
) {
250 llvm::IRBuilder
<>* builder
= ir_builder();
251 // Create fn prototype :
252 // int expr_1 (long **addrs, long *offsets, long **bitmaps,
253 // long *context_ptr, long nrec)
254 std::vector
<llvm::Type
*> arguments
;
255 arguments
.push_back(types()->i64_ptr_type()); // addrs
256 arguments
.push_back(types()->i64_ptr_type()); // offsets
257 arguments
.push_back(types()->i64_ptr_type()); // bitmaps
258 switch (selection_vector_mode
) {
259 case SelectionVector::MODE_NONE
:
260 case SelectionVector::MODE_UINT16
:
261 arguments
.push_back(types()->ptr_type(types()->i16_type()));
263 case SelectionVector::MODE_UINT32
:
264 arguments
.push_back(types()->i32_ptr_type());
266 case SelectionVector::MODE_UINT64
:
267 arguments
.push_back(types()->i64_ptr_type());
269 arguments
.push_back(types()->i64_type()); // ctx_ptr
270 arguments
.push_back(types()->i64_type()); // nrec
271 llvm::FunctionType
* prototype
=
272 llvm::FunctionType::get(types()->i32_type(), arguments
, false /*isVarArg*/);
275 std::string func_name
= "expr_" + std::to_string(suffix_idx
) + "_" +
276 std::to_string(static_cast<int>(selection_vector_mode
));
277 engine_
->AddFunctionToCompile(func_name
);
278 *fn
= llvm::Function::Create(prototype
, llvm::GlobalValue::ExternalLinkage
, func_name
,
280 ARROW_RETURN_IF((*fn
== nullptr), Status::CodeGenError("Error creating function."));
282 // Name the arguments
283 llvm::Function::arg_iterator args
= (*fn
)->arg_begin();
284 llvm::Value
* arg_addrs
= &*args
;
285 arg_addrs
->setName("inputs_addr");
287 llvm::Value
* arg_addr_offsets
= &*args
;
288 arg_addr_offsets
->setName("inputs_addr_offsets");
290 llvm::Value
* arg_local_bitmaps
= &*args
;
291 arg_local_bitmaps
->setName("local_bitmaps");
293 llvm::Value
* arg_selection_vector
= &*args
;
294 arg_selection_vector
->setName("selection_vector");
296 llvm::Value
* arg_context_ptr
= &*args
;
297 arg_context_ptr
->setName("context_ptr");
299 llvm::Value
* arg_nrecords
= &*args
;
300 arg_nrecords
->setName("nrecords");
302 llvm::BasicBlock
* loop_entry
= llvm::BasicBlock::Create(*context(), "entry", *fn
);
303 llvm::BasicBlock
* loop_body
= llvm::BasicBlock::Create(*context(), "loop", *fn
);
304 llvm::BasicBlock
* loop_exit
= llvm::BasicBlock::Create(*context(), "exit", *fn
);
306 // Add reference to output vector (in entry block)
307 builder
->SetInsertPoint(loop_entry
);
308 llvm::Value
* output_ref
=
309 GetDataReference(arg_addrs
, output
->data_idx(), output
->field());
310 llvm::Value
* output_buffer_ptr_ref
= GetDataBufferPtrReference(
311 arg_addrs
, output
->data_buffer_ptr_idx(), output
->field());
312 llvm::Value
* output_offset_ref
=
313 GetOffsetsReference(arg_addrs
, output
->offsets_idx(), output
->field());
315 std::vector
<llvm::Value
*> slice_offsets
;
316 for (int idx
= 0; idx
< buffer_count
; idx
++) {
317 auto offsetAddr
= CreateGEP(builder
, arg_addr_offsets
, types()->i32_constant(idx
));
318 auto offset
= CreateLoad(builder
, offsetAddr
);
319 slice_offsets
.push_back(offset
);
323 builder
->SetInsertPoint(loop_body
);
325 // define loop_var : start with 0, +1 after each iter
326 llvm::PHINode
* loop_var
= builder
->CreatePHI(types()->i64_type(), 2, "loop_var");
328 llvm::Value
* position_var
= loop_var
;
329 if (selection_vector_mode
!= SelectionVector::MODE_NONE
) {
330 position_var
= builder
->CreateIntCast(
331 CreateLoad(builder
, CreateGEP(builder
, arg_selection_vector
, loop_var
),
332 "uncasted_position_var"),
333 types()->i64_type(), true, "position_var");
336 // The visitor can add code to both the entry/loop blocks.
337 Visitor
visitor(this, *fn
, loop_entry
, arg_addrs
, arg_local_bitmaps
, slice_offsets
,
338 arg_context_ptr
, position_var
);
339 value_expr
->Accept(visitor
);
340 LValuePtr output_value
= visitor
.result();
342 // The "current" block may have changed due to code generation in the visitor.
343 llvm::BasicBlock
* loop_body_tail
= builder
->GetInsertBlock();
345 // add jump to "loop block" at the end of the "setup block".
346 builder
->SetInsertPoint(loop_entry
);
347 builder
->CreateBr(loop_body
);
349 // save the value in the output vector.
350 builder
->SetInsertPoint(loop_body_tail
);
352 auto output_type_id
= output
->Type()->id();
353 if (output_type_id
== arrow::Type::BOOL
) {
354 SetPackedBitValue(output_ref
, loop_var
, output_value
->data());
355 } else if (arrow::is_primitive(output_type_id
) ||
356 output_type_id
== arrow::Type::DECIMAL
) {
357 llvm::Value
* slot_offset
= CreateGEP(builder
, output_ref
, loop_var
);
358 builder
->CreateStore(output_value
->data(), slot_offset
);
359 } else if (arrow::is_binary_like(output_type_id
)) {
360 // Var-len output. Make a function call to populate the data.
361 // if there is an error, the fn sets it in the context. And, will be returned at the
362 // end of this row batch.
363 AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(),
364 {arg_context_ptr
, output_buffer_ptr_ref
, output_offset_ref
, loop_var
,
365 output_value
->data(), output_value
->length()});
367 return Status::NotImplemented("output type ", output
->Type()->ToString(),
370 ADD_TRACE("saving result " + output
->Name() + " value %T", output_value
->data());
372 if (visitor
.has_arena_allocs()) {
373 // Reset allocations to avoid excessive memory usage. Once the result is copied to
374 // the output vector (store instruction above), any memory allocations in this
375 // iteration of the loop are no longer needed.
376 std::vector
<llvm::Value
*> reset_args
;
377 reset_args
.push_back(arg_context_ptr
);
378 AddFunctionCall("gdv_fn_context_arena_reset", types()->void_type(), reset_args
);
382 loop_var
->addIncoming(types()->i64_constant(0), loop_entry
);
383 llvm::Value
* loop_update
=
384 builder
->CreateAdd(loop_var
, types()->i64_constant(1), "loop_var+1");
385 loop_var
->addIncoming(loop_update
, loop_body_tail
);
387 llvm::Value
* loop_var_check
=
388 builder
->CreateICmpSLT(loop_update
, arg_nrecords
, "loop_var < nrec");
389 builder
->CreateCondBr(loop_var_check
, loop_body
, loop_exit
);
392 builder
->SetInsertPoint(loop_exit
);
393 builder
->CreateRet(types()->i32_constant(0));
397 /// Return value of a bit in bitMap.
398 llvm::Value
* LLVMGenerator::GetPackedBitValue(llvm::Value
* bitmap
,
399 llvm::Value
* position
) {
400 ADD_TRACE("fetch bit at position %T", position
);
402 llvm::Value
* bitmap8
= ir_builder()->CreateBitCast(
403 bitmap
, types()->ptr_type(types()->i8_type()), "bitMapCast");
404 return AddFunctionCall("bitMapGetBit", types()->i1_type(), {bitmap8
, position
});
407 /// Set the value of a bit in bitMap.
408 void LLVMGenerator::SetPackedBitValue(llvm::Value
* bitmap
, llvm::Value
* position
,
409 llvm::Value
* value
) {
410 ADD_TRACE("set bit at position %T", position
);
411 ADD_TRACE(" to value %T ", value
);
413 llvm::Value
* bitmap8
= ir_builder()->CreateBitCast(
414 bitmap
, types()->ptr_type(types()->i8_type()), "bitMapCast");
415 AddFunctionCall("bitMapSetBit", types()->void_type(), {bitmap8
, position
, value
});
418 /// Return value of a bit in validity bitMap (handles null bitmaps too).
419 llvm::Value
* LLVMGenerator::GetPackedValidityBitValue(llvm::Value
* bitmap
,
420 llvm::Value
* position
) {
421 ADD_TRACE("fetch validity bit at position %T", position
);
423 llvm::Value
* bitmap8
= ir_builder()->CreateBitCast(
424 bitmap
, types()->ptr_type(types()->i8_type()), "bitMapCast");
425 return AddFunctionCall("bitMapValidityGetBit", types()->i1_type(), {bitmap8
, position
});
428 /// Clear the bit in bitMap if value = false.
429 void LLVMGenerator::ClearPackedBitValueIfFalse(llvm::Value
* bitmap
, llvm::Value
* position
,
430 llvm::Value
* value
) {
431 ADD_TRACE("ClearIfFalse bit at position %T", position
);
432 ADD_TRACE(" value %T ", value
);
434 llvm::Value
* bitmap8
= ir_builder()->CreateBitCast(
435 bitmap
, types()->ptr_type(types()->i8_type()), "bitMapCast");
436 AddFunctionCall("bitMapClearBitIfFalse", types()->void_type(),
437 {bitmap8
, position
, value
});
440 /// Extract the bitmap addresses, and do an intersection.
441 void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr
& compiled_expr
,
442 const EvalBatch
& eval_batch
,
443 const SelectionVector
* selection_vector
) {
444 auto validities
= compiled_expr
.value_validity()->validity_exprs();
446 // Extract all the source bitmap addresses.
447 BitMapAccumulator
accumulator(eval_batch
);
448 for (auto& validity_dex
: validities
) {
449 validity_dex
->Accept(accumulator
);
452 // Extract the destination bitmap address.
453 int out_idx
= compiled_expr
.output()->validity_idx();
454 uint8_t* dst_bitmap
= eval_batch
.GetBuffer(out_idx
);
455 // Compute the destination bitmap.
456 if (selection_vector
== nullptr) {
457 accumulator
.ComputeResult(dst_bitmap
);
459 /// The output bitmap is an intersection of some input/local bitmaps. However, with a
460 /// selection vector, only the bits corresponding to the indices in the selection
461 /// vector need to set in the output bitmap. This is done in two steps :
463 /// 1. Do the intersection of input/local bitmaps to generate a temporary bitmap.
464 /// 2. copy just the relevant bits from the temporary bitmap to the output bitmap.
465 LocalBitMapsHolder
bit_map_holder(eval_batch
.num_records(), 1);
466 uint8_t* temp_bitmap
= bit_map_holder
.GetLocalBitMap(0);
467 accumulator
.ComputeResult(temp_bitmap
);
469 auto num_out_records
= selection_vector
->GetNumSlots();
470 // the memset isn't required, doing it just for valgrind.
471 memset(dst_bitmap
, 0, arrow::BitUtil::BytesForBits(num_out_records
));
472 for (auto i
= 0; i
< num_out_records
; ++i
) {
473 auto bit
= arrow::BitUtil::GetBit(temp_bitmap
, selection_vector
->GetIndex(i
));
474 arrow::BitUtil::SetBitTo(dst_bitmap
, i
, bit
);
479 llvm::Value
* LLVMGenerator::AddFunctionCall(const std::string
& full_name
,
480 llvm::Type
* ret_type
,
481 const std::vector
<llvm::Value
*>& args
) {
482 // find the llvm function.
483 llvm::Function
* fn
= module()->getFunction(full_name
);
484 DCHECK_NE(fn
, nullptr) << "missing function " << full_name
;
486 if (enable_ir_traces_
&& !full_name
.compare("printf") &&
487 !full_name
.compare("printff")) {
488 // Trace for debugging
489 ADD_TRACE("invoke native fn " + full_name
);
492 // build a call to the llvm function.
494 if (ret_type
->isVoidTy()) {
495 // void functions can't have a name for the call.
496 value
= ir_builder()->CreateCall(fn
, args
);
498 value
= ir_builder()->CreateCall(fn
, args
, full_name
);
499 DCHECK(value
->getType() == ret_type
);
505 std::shared_ptr
<DecimalLValue
> LLVMGenerator::BuildDecimalLValue(llvm::Value
* value
,
506 DataTypePtr arrow_type
) {
507 // only decimals of size 128-bit supported.
508 DCHECK(is_decimal_128(arrow_type
));
510 arrow::internal::checked_cast
<arrow::DecimalType
*>(arrow_type
.get());
511 return std::make_shared
<DecimalLValue
>(value
, nullptr,
512 types()->i32_constant(decimal_type
->precision()),
513 types()->i32_constant(decimal_type
->scale()));
// Emit an IR trace through the owning generator when its IR tracing is enabled.
// Wrapped in do { } while (0) so `ADD_VISITOR_TRACE(...);` is a single statement and
// is safe inside an unbraced if/else.
#define ADD_VISITOR_TRACE(...)                \
  do {                                        \
    if (generator_->enable_ir_traces_) {      \
      generator_->AddTrace(__VA_ARGS__);      \
    }                                         \
  } while (0)
521 // Visitor for generating the code for a decomposed expression.
522 LLVMGenerator::Visitor::Visitor(LLVMGenerator
* generator
, llvm::Function
* function
,
523 llvm::BasicBlock
* entry_block
, llvm::Value
* arg_addrs
,
524 llvm::Value
* arg_local_bitmaps
,
525 std::vector
<llvm::Value
*> slice_offsets
,
526 llvm::Value
* arg_context_ptr
, llvm::Value
* loop_var
)
527 : generator_(generator
),
529 entry_block_(entry_block
),
530 arg_addrs_(arg_addrs
),
531 arg_local_bitmaps_(arg_local_bitmaps
),
532 slice_offsets_(slice_offsets
),
533 arg_context_ptr_(arg_context_ptr
),
535 has_arena_allocs_(false) {
536 ADD_VISITOR_TRACE("Iteration %T", loop_var
);
539 void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex
& dex
) {
540 llvm::IRBuilder
<>* builder
= ir_builder();
541 llvm::Value
* slot_ref
= GetBufferReference(dex
.DataIdx(), kBufferTypeData
, dex
.Field());
542 llvm::Value
* slot_index
= builder
->CreateAdd(loop_var_
, GetSliceOffset(dex
.DataIdx()));
543 llvm::Value
* slot_value
;
544 std::shared_ptr
<LValue
> lvalue
;
546 switch (dex
.FieldType()->id()) {
547 case arrow::Type::BOOL
:
548 slot_value
= generator_
->GetPackedBitValue(slot_ref
, slot_index
);
549 lvalue
= std::make_shared
<LValue
>(slot_value
);
552 case arrow::Type::DECIMAL
: {
553 auto slot_offset
= CreateGEP(builder
, slot_ref
, slot_index
);
554 slot_value
= CreateLoad(builder
, slot_offset
, dex
.FieldName());
555 lvalue
= generator_
->BuildDecimalLValue(slot_value
, dex
.FieldType());
560 auto slot_offset
= CreateGEP(builder
, slot_ref
, slot_index
);
561 slot_value
= CreateLoad(builder
, slot_offset
, dex
.FieldName());
562 lvalue
= std::make_shared
<LValue
>(slot_value
);
566 ADD_VISITOR_TRACE("visit fixed-len data vector " + dex
.FieldName() + " value %T",
571 void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex
& dex
) {
572 llvm::IRBuilder
<>* builder
= ir_builder();
575 // compute len from the offsets array.
576 llvm::Value
* offsets_slot_ref
=
577 GetBufferReference(dex
.OffsetsIdx(), kBufferTypeOffsets
, dex
.Field());
578 llvm::Value
* offsets_slot_index
=
579 builder
->CreateAdd(loop_var_
, GetSliceOffset(dex
.OffsetsIdx()));
581 // => offset_start = offsets[loop_var]
582 slot
= CreateGEP(builder
, offsets_slot_ref
, offsets_slot_index
);
583 llvm::Value
* offset_start
= CreateLoad(builder
, slot
, "offset_start");
585 // => offset_end = offsets[loop_var + 1]
586 llvm::Value
* offsets_slot_index_next
= builder
->CreateAdd(
587 offsets_slot_index
, generator_
->types()->i64_constant(1), "loop_var+1");
588 slot
= CreateGEP(builder
, offsets_slot_ref
, offsets_slot_index_next
);
589 llvm::Value
* offset_end
= CreateLoad(builder
, slot
, "offset_end");
591 // => len_value = offset_end - offset_start
592 llvm::Value
* len_value
=
593 builder
->CreateSub(offset_end
, offset_start
, dex
.FieldName() + "Len");
595 // get the data from the data array, at offset 'offset_start'.
596 llvm::Value
* data_slot_ref
=
597 GetBufferReference(dex
.DataIdx(), kBufferTypeData
, dex
.Field());
598 llvm::Value
* data_value
= CreateGEP(builder
, data_slot_ref
, offset_start
);
599 ADD_VISITOR_TRACE("visit var-len data vector " + dex
.FieldName() + " len %T",
601 result_
.reset(new LValue(data_value
, len_value
));
604 void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex
& dex
) {
605 llvm::IRBuilder
<>* builder
= ir_builder();
606 llvm::Value
* slot_ref
=
607 GetBufferReference(dex
.ValidityIdx(), kBufferTypeValidity
, dex
.Field());
608 llvm::Value
* slot_index
=
609 builder
->CreateAdd(loop_var_
, GetSliceOffset(dex
.ValidityIdx()));
610 llvm::Value
* validity
= generator_
->GetPackedValidityBitValue(slot_ref
, slot_index
);
612 ADD_VISITOR_TRACE("visit validity vector " + dex
.FieldName() + " value %T", validity
);
613 result_
.reset(new LValue(validity
));
616 void LLVMGenerator::Visitor::Visit(const LocalBitMapValidityDex
& dex
) {
617 llvm::Value
* slot_ref
= GetLocalBitMapReference(dex
.local_bitmap_idx());
618 llvm::Value
* validity
= generator_
->GetPackedBitValue(slot_ref
, loop_var_
);
621 "visit local bitmap " + std::to_string(dex
.local_bitmap_idx()) + " value %T",
623 result_
.reset(new LValue(validity
));
626 void LLVMGenerator::Visitor::Visit(const TrueDex
& dex
) {
627 result_
.reset(new LValue(generator_
->types()->true_constant()));
630 void LLVMGenerator::Visitor::Visit(const FalseDex
& dex
) {
631 result_
.reset(new LValue(generator_
->types()->false_constant()));
634 void LLVMGenerator::Visitor::Visit(const LiteralDex
& dex
) {
635 LLVMTypes
* types
= generator_
->types();
636 llvm::Value
* value
= nullptr;
637 llvm::Value
* len
= nullptr;
639 switch (dex
.type()->id()) {
640 case arrow::Type::BOOL
:
641 value
= types
->i1_constant(arrow::util::get
<bool>(dex
.holder()));
644 case arrow::Type::UINT8
:
645 value
= types
->i8_constant(arrow::util::get
<uint8_t>(dex
.holder()));
648 case arrow::Type::UINT16
:
649 value
= types
->i16_constant(arrow::util::get
<uint16_t>(dex
.holder()));
652 case arrow::Type::UINT32
:
653 value
= types
->i32_constant(arrow::util::get
<uint32_t>(dex
.holder()));
656 case arrow::Type::UINT64
:
657 value
= types
->i64_constant(arrow::util::get
<uint64_t>(dex
.holder()));
660 case arrow::Type::INT8
:
661 value
= types
->i8_constant(arrow::util::get
<int8_t>(dex
.holder()));
664 case arrow::Type::INT16
:
665 value
= types
->i16_constant(arrow::util::get
<int16_t>(dex
.holder()));
668 case arrow::Type::FLOAT
:
669 value
= types
->float_constant(arrow::util::get
<float>(dex
.holder()));
672 case arrow::Type::DOUBLE
:
673 value
= types
->double_constant(arrow::util::get
<double>(dex
.holder()));
676 case arrow::Type::STRING
:
677 case arrow::Type::BINARY
: {
678 const std::string
& str
= arrow::util::get
<std::string
>(dex
.holder());
680 llvm::Constant
* str_int_cast
= types
->i64_constant((int64_t)str
.c_str());
681 value
= llvm::ConstantExpr::getIntToPtr(str_int_cast
, types
->i8_ptr_type());
682 len
= types
->i32_constant(static_cast<int32_t>(str
.length()));
686 case arrow::Type::INT32
:
687 case arrow::Type::DATE32
:
688 case arrow::Type::TIME32
:
689 case arrow::Type::INTERVAL_MONTHS
:
690 value
= types
->i32_constant(arrow::util::get
<int32_t>(dex
.holder()));
693 case arrow::Type::INT64
:
694 case arrow::Type::DATE64
:
695 case arrow::Type::TIME64
:
696 case arrow::Type::TIMESTAMP
:
697 case arrow::Type::INTERVAL_DAY_TIME
:
698 value
= types
->i64_constant(arrow::util::get
<int64_t>(dex
.holder()));
701 case arrow::Type::DECIMAL
: {
702 // build code for struct
703 auto scalar
= arrow::util::get
<DecimalScalar128
>(dex
.holder());
704 // ConstantInt doesn't have a get method that takes int128 or a pair of int64. so,
705 // passing the string representation instead.
707 llvm::ConstantInt::get(llvm::Type::getInt128Ty(*generator_
->context()),
708 Decimal128(scalar
.value()).ToIntegerString(), 10);
709 auto type
= arrow::decimal(scalar
.precision(), scalar
.scale());
710 auto lvalue
= generator_
->BuildDecimalLValue(int128_value
, type
);
711 // set it as the l-value and return.
719 ADD_VISITOR_TRACE("visit Literal %T", value
);
720 result_
.reset(new LValue(value
, len
));
723 void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex
& dex
) {
724 const std::string
& function_name
= dex
.func_descriptor()->name();
725 ADD_VISITOR_TRACE("visit NonNullableFunc base function " + function_name
);
727 const NativeFunction
* native_function
= dex
.native_function();
729 // build the function params (ignore validity).
730 auto params
= BuildParams(dex
.function_holder().get(), dex
.args(), false,
731 native_function
->NeedsContext());
733 auto arrow_return_type
= dex
.func_descriptor()->return_type();
734 if (native_function
->CanReturnErrors()) {
735 // slow path : if a function can return errors, skip invoking the function
736 // unless all of the input args are valid. Otherwise, it can cause spurious errors.
738 llvm::IRBuilder
<>* builder
= ir_builder();
739 LLVMTypes
* types
= generator_
->types();
740 auto arrow_type_id
= arrow_return_type
->id();
741 auto result_type
= types
->IRType(arrow_type_id
);
743 // Build combined validity of the args.
744 llvm::Value
* is_valid
= types
->true_constant();
745 for (auto& pair
: dex
.args()) {
746 auto arg_validity
= BuildCombinedValidity(pair
->validity_exprs());
747 is_valid
= builder
->CreateAnd(is_valid
, arg_validity
, "validityBitAnd");
751 auto then_lambda
= [&] {
752 ADD_VISITOR_TRACE("fn " + function_name
+
753 " can return errors : all args valid, invoke fn");
754 return BuildFunctionCall(native_function
, arrow_return_type
, ¶ms
);
758 auto else_lambda
= [&] {
759 ADD_VISITOR_TRACE("fn " + function_name
+
760 " can return errors : not all args valid, return dummy value");
761 llvm::Value
* else_value
= types
->NullConstant(result_type
);
762 llvm::Value
* else_value_len
= nullptr;
763 if (arrow::is_binary_like(arrow_type_id
)) {
764 else_value_len
= types
->i32_constant(0);
766 return std::make_shared
<LValue
>(else_value
, else_value_len
);
769 result_
= BuildIfElse(is_valid
, then_lambda
, else_lambda
, arrow_return_type
);
771 // fast path : invoke function without computing validities.
772 result_
= BuildFunctionCall(native_function
, arrow_return_type
, ¶ms
);
776 void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex
& dex
) {
777 ADD_VISITOR_TRACE("visit NullableNever base function " + dex
.func_descriptor()->name());
778 const NativeFunction
* native_function
= dex
.native_function();
780 // build function params along with validity.
781 auto params
= BuildParams(dex
.function_holder().get(), dex
.args(), true,
782 native_function
->NeedsContext());
784 auto arrow_return_type
= dex
.func_descriptor()->return_type();
785 result_
= BuildFunctionCall(native_function
, arrow_return_type
, ¶ms
);
788 void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex
& dex
) {
789 ADD_VISITOR_TRACE("visit NullableInternal base function " +
790 dex
.func_descriptor()->name());
791 llvm::IRBuilder
<>* builder
= ir_builder();
792 LLVMTypes
* types
= generator_
->types();
794 const NativeFunction
* native_function
= dex
.native_function();
796 // build function params along with validity.
797 auto params
= BuildParams(dex
.function_holder().get(), dex
.args(), true,
798 native_function
->NeedsContext());
800 // add an extra arg for validity (allocated on stack).
801 llvm::AllocaInst
* result_valid_ptr
=
802 new llvm::AllocaInst(types
->i8_type(), 0, "result_valid", entry_block_
);
803 params
.push_back(result_valid_ptr
);
805 auto arrow_return_type
= dex
.func_descriptor()->return_type();
806 result_
= BuildFunctionCall(native_function
, arrow_return_type
, ¶ms
);
808 // load the result validity and truncate to i1.
809 llvm::Value
* result_valid_i8
= CreateLoad(builder
, result_valid_ptr
);
810 llvm::Value
* result_valid
= builder
->CreateTrunc(result_valid_i8
, types
->i1_type());
812 // set validity bit in the local bitmap.
813 ClearLocalBitMapIfNotValid(dex
.local_bitmap_idx(), result_valid
);
816 void LLVMGenerator::Visitor::Visit(const IfDex
& dex
) {
817 ADD_VISITOR_TRACE("visit IfExpression");
818 llvm::IRBuilder
<>* builder
= ir_builder();
820 // Evaluate condition.
821 LValuePtr if_condition
= BuildValueAndValidity(dex
.condition_vv());
823 // Check if the result is valid, and there is match.
824 llvm::Value
* validAndMatched
=
825 builder
->CreateAnd(if_condition
->data(), if_condition
->validity(), "validAndMatch");
828 auto then_lambda
= [&] {
829 ADD_VISITOR_TRACE("branch to then block");
830 LValuePtr then_lvalue
= BuildValueAndValidity(dex
.then_vv());
831 ClearLocalBitMapIfNotValid(dex
.local_bitmap_idx(), then_lvalue
->validity());
832 ADD_VISITOR_TRACE("IfExpression result validity %T in matching then",
833 then_lvalue
->validity());
838 auto else_lambda
= [&] {
839 LValuePtr else_lvalue
;
840 if (dex
.is_terminal_else()) {
841 ADD_VISITOR_TRACE("branch to terminal else block");
843 else_lvalue
= BuildValueAndValidity(dex
.else_vv());
844 // update the local bitmap with the validity.
845 ClearLocalBitMapIfNotValid(dex
.local_bitmap_idx(), else_lvalue
->validity());
846 ADD_VISITOR_TRACE("IfExpression result validity %T in terminal else",
847 else_lvalue
->validity());
849 ADD_VISITOR_TRACE("branch to non-terminal else block");
851 // this is a non-terminal else. let the child (nested if/else) handle validity.
852 auto value_expr
= dex
.else_vv().value_expr();
853 value_expr
->Accept(*this);
854 else_lvalue
= result();
859 // build the if-else condition.
860 result_
= BuildIfElse(validAndMatched
, then_lambda
, else_lambda
, dex
.result_type());
861 if (arrow::is_binary_like(dex
.result_type()->id())) {
862 ADD_VISITOR_TRACE("IfElse result length %T", result_
->length());
864 ADD_VISITOR_TRACE("IfElse result value %T", result_
->data());
868 // if any arg is valid and false,
869 // short-circuit and return FALSE (value=false, valid=true)
870 // else if all args are valid and true
871 // return TRUE (value=true, valid=true)
873 // return NULL (value=true, valid=false)
875 void LLVMGenerator::Visitor::Visit(const BooleanAndDex
& dex
) {
876 ADD_VISITOR_TRACE("visit BooleanAndExpression");
877 llvm::IRBuilder
<>* builder
= ir_builder();
878 LLVMTypes
* types
= generator_
->types();
879 llvm::LLVMContext
* context
= generator_
->context();
881 // Create blocks for short-circuit.
882 llvm::BasicBlock
* short_circuit_bb
=
883 llvm::BasicBlock::Create(*context
, "short_circuit", function_
);
884 llvm::BasicBlock
* non_short_circuit_bb
=
885 llvm::BasicBlock::Create(*context
, "non_short_circuit", function_
);
886 llvm::BasicBlock
* merge_bb
= llvm::BasicBlock::Create(*context
, "merge", function_
);
888 llvm::Value
* all_exprs_valid
= types
->true_constant();
889 for (auto& pair
: dex
.args()) {
890 LValuePtr current
= BuildValueAndValidity(*pair
);
892 ADD_VISITOR_TRACE("BooleanAndExpression arg value %T", current
->data());
893 ADD_VISITOR_TRACE("BooleanAndExpression arg validity %T", current
->validity());
895 // short-circuit if valid and false
896 llvm::Value
* is_false
= builder
->CreateNot(current
->data());
897 llvm::Value
* valid_and_false
=
898 builder
->CreateAnd(is_false
, current
->validity(), "valid_and_false");
900 llvm::BasicBlock
* else_bb
= llvm::BasicBlock::Create(*context
, "else", function_
);
901 builder
->CreateCondBr(valid_and_false
, short_circuit_bb
, else_bb
);
903 // Emit the else block.
904 builder
->SetInsertPoint(else_bb
);
905 // remember if any nulls were encountered.
907 builder
->CreateAnd(all_exprs_valid
, current
->validity(), "validityBitAnd");
908 // continue to evaluate the next pair in list.
910 builder
->CreateBr(non_short_circuit_bb
);
912 // Short-circuit case (at least one of the expressions is valid and false).
913 // No need to set validity bit (valid by default).
914 builder
->SetInsertPoint(short_circuit_bb
);
915 ADD_VISITOR_TRACE("BooleanAndExpression result value false");
916 ADD_VISITOR_TRACE("BooleanAndExpression result validity true");
917 builder
->CreateBr(merge_bb
);
919 // non short-circuit case (All expressions are either true or null).
920 // result valid if all of the exprs are non-null.
921 builder
->SetInsertPoint(non_short_circuit_bb
);
922 ClearLocalBitMapIfNotValid(dex
.local_bitmap_idx(), all_exprs_valid
);
923 ADD_VISITOR_TRACE("BooleanAndExpression result value true");
924 ADD_VISITOR_TRACE("BooleanAndExpression result validity %T", all_exprs_valid
);
925 builder
->CreateBr(merge_bb
);
927 builder
->SetInsertPoint(merge_bb
);
928 llvm::PHINode
* result_value
= builder
->CreatePHI(types
->i1_type(), 2, "res_value");
929 result_value
->addIncoming(types
->false_constant(), short_circuit_bb
);
930 result_value
->addIncoming(types
->true_constant(), non_short_circuit_bb
);
931 result_
.reset(new LValue(result_value
));
935 // if any arg is valid and true,
936 // short-circuit and return TRUE (value=true, valid=true)
937 // else if all args are valid and false
938 // return FALSE (value=false, valid=true)
940 // return NULL (value=false, valid=false)
942 void LLVMGenerator::Visitor::Visit(const BooleanOrDex
& dex
) {
943 ADD_VISITOR_TRACE("visit BooleanOrExpression");
944 llvm::IRBuilder
<>* builder
= ir_builder();
945 LLVMTypes
* types
= generator_
->types();
946 llvm::LLVMContext
* context
= generator_
->context();
948 // Create blocks for short-circuit.
949 llvm::BasicBlock
* short_circuit_bb
=
950 llvm::BasicBlock::Create(*context
, "short_circuit", function_
);
951 llvm::BasicBlock
* non_short_circuit_bb
=
952 llvm::BasicBlock::Create(*context
, "non_short_circuit", function_
);
953 llvm::BasicBlock
* merge_bb
= llvm::BasicBlock::Create(*context
, "merge", function_
);
955 llvm::Value
* all_exprs_valid
= types
->true_constant();
956 for (auto& pair
: dex
.args()) {
957 LValuePtr current
= BuildValueAndValidity(*pair
);
959 ADD_VISITOR_TRACE("BooleanOrExpression arg value %T", current
->data());
960 ADD_VISITOR_TRACE("BooleanOrExpression arg validity %T", current
->validity());
962 // short-circuit if valid and true.
963 llvm::Value
* valid_and_true
=
964 builder
->CreateAnd(current
->data(), current
->validity(), "valid_and_true");
966 llvm::BasicBlock
* else_bb
= llvm::BasicBlock::Create(*context
, "else", function_
);
967 builder
->CreateCondBr(valid_and_true
, short_circuit_bb
, else_bb
);
969 // Emit the else block.
970 builder
->SetInsertPoint(else_bb
);
971 // remember if any nulls were encountered.
973 builder
->CreateAnd(all_exprs_valid
, current
->validity(), "validityBitAnd");
974 // continue to evaluate the next pair in list.
976 builder
->CreateBr(non_short_circuit_bb
);
978 // Short-circuit case (at least one of the expressions is valid and true).
979 // No need to set validity bit (valid by default).
980 builder
->SetInsertPoint(short_circuit_bb
);
981 ADD_VISITOR_TRACE("BooleanOrExpression result value true");
982 ADD_VISITOR_TRACE("BooleanOrExpression result validity true");
983 builder
->CreateBr(merge_bb
);
985 // non short-circuit case (All expressions are either false or null).
986 // result valid if all of the exprs are non-null.
987 builder
->SetInsertPoint(non_short_circuit_bb
);
988 ClearLocalBitMapIfNotValid(dex
.local_bitmap_idx(), all_exprs_valid
);
989 ADD_VISITOR_TRACE("BooleanOrExpression result value false");
990 ADD_VISITOR_TRACE("BooleanOrExpression result validity %T", all_exprs_valid
);
991 builder
->CreateBr(merge_bb
);
993 builder
->SetInsertPoint(merge_bb
);
994 llvm::PHINode
* result_value
= builder
->CreatePHI(types
->i1_type(), 2, "res_value");
995 result_value
->addIncoming(types
->true_constant(), short_circuit_bb
);
996 result_value
->addIncoming(types
->false_constant(), non_short_circuit_bb
);
997 result_
.reset(new LValue(result_value
));
1000 template <typename Type
>
1001 void LLVMGenerator::Visitor::VisitInExpression(const InExprDexBase
<Type
>& dex
) {
1002 ADD_VISITOR_TRACE("visit In Expression");
1003 LLVMTypes
* types
= generator_
->types();
1004 std::vector
<llvm::Value
*> params
;
1006 const InExprDex
<Type
>& dex_instance
= dynamic_cast<const InExprDex
<Type
>&>(dex
);
1007 /* add the holder at the beginning */
1008 llvm::Constant
* ptr_int_cast
=
1009 types
->i64_constant((int64_t)(dex_instance
.in_holder().get()));
1010 params
.push_back(ptr_int_cast
);
1012 /* eval expr result */
1013 for (auto& pair
: dex
.args()) {
1014 DexPtr value_expr
= pair
->value_expr();
1015 value_expr
->Accept(*this);
1016 LValue
& result_ref
= *result();
1017 params
.push_back(result_ref
.data());
1019 /* length if the result is a string */
1020 if (result_ref
.length() != nullptr) {
1021 params
.push_back(result_ref
.length());
1024 /* push the validity of eval expr result */
1025 llvm::Value
* validity_expr
= BuildCombinedValidity(pair
->validity_exprs());
1026 params
.push_back(validity_expr
);
1029 llvm::Type
* ret_type
= types
->IRType(arrow::Type::type::BOOL
);
1033 value
= generator_
->AddFunctionCall(dex
.runtime_function(), ret_type
, params
);
1035 result_
.reset(new LValue(value
));
1039 void LLVMGenerator::Visitor::VisitInExpression
<gandiva::DecimalScalar128
>(
1040 const InExprDexBase
<gandiva::DecimalScalar128
>& dex
) {
1041 ADD_VISITOR_TRACE("visit In Expression");
1042 LLVMTypes
* types
= generator_
->types();
1043 std::vector
<llvm::Value
*> params
;
1044 DecimalIR
decimalIR(generator_
->engine_
.get());
1046 const InExprDex
<gandiva::DecimalScalar128
>& dex_instance
=
1047 dynamic_cast<const InExprDex
<gandiva::DecimalScalar128
>&>(dex
);
1048 /* add the holder at the beginning */
1049 llvm::Constant
* ptr_int_cast
=
1050 types
->i64_constant((int64_t)(dex_instance
.in_holder().get()));
1051 params
.push_back(ptr_int_cast
);
1053 /* eval expr result */
1054 for (auto& pair
: dex
.args()) {
1055 DexPtr value_expr
= pair
->value_expr();
1056 value_expr
->Accept(*this);
1057 LValue
& result_ref
= *result();
1058 params
.push_back(result_ref
.data());
1060 llvm::Constant
* precision
= types
->i32_constant(dex
.get_precision());
1061 llvm::Constant
* scale
= types
->i32_constant(dex
.get_scale());
1062 params
.push_back(precision
);
1063 params
.push_back(scale
);
1065 /* push the validity of eval expr result */
1066 llvm::Value
* validity_expr
= BuildCombinedValidity(pair
->validity_exprs());
1067 params
.push_back(validity_expr
);
1070 llvm::Type
* ret_type
= types
->IRType(arrow::Type::type::BOOL
);
1074 value
= decimalIR
.CallDecimalFunction(dex
.runtime_function(), ret_type
, params
);
1076 result_
.reset(new LValue(value
));
1079 void LLVMGenerator::Visitor::Visit(const InExprDexBase
<int32_t>& dex
) {
1080 VisitInExpression
<int32_t>(dex
);
1083 void LLVMGenerator::Visitor::Visit(const InExprDexBase
<int64_t>& dex
) {
1084 VisitInExpression
<int64_t>(dex
);
1087 void LLVMGenerator::Visitor::Visit(const InExprDexBase
<float>& dex
) {
1088 VisitInExpression
<float>(dex
);
1090 void LLVMGenerator::Visitor::Visit(const InExprDexBase
<double>& dex
) {
1091 VisitInExpression
<double>(dex
);
1094 void LLVMGenerator::Visitor::Visit(const InExprDexBase
<gandiva::DecimalScalar128
>& dex
) {
1095 VisitInExpression
<gandiva::DecimalScalar128
>(dex
);
1098 void LLVMGenerator::Visitor::Visit(const InExprDexBase
<std::string
>& dex
) {
1099 VisitInExpression
<std::string
>(dex
);
1102 LValuePtr
LLVMGenerator::Visitor::BuildIfElse(llvm::Value
* condition
,
1103 std::function
<LValuePtr()> then_func
,
1104 std::function
<LValuePtr()> else_func
,
1105 DataTypePtr result_type
) {
1106 llvm::IRBuilder
<>* builder
= ir_builder();
1107 llvm::LLVMContext
* context
= generator_
->context();
1108 LLVMTypes
* types
= generator_
->types();
1110 // Create blocks for the then, else and merge cases.
1111 llvm::BasicBlock
* then_bb
= llvm::BasicBlock::Create(*context
, "then", function_
);
1112 llvm::BasicBlock
* else_bb
= llvm::BasicBlock::Create(*context
, "else", function_
);
1113 llvm::BasicBlock
* merge_bb
= llvm::BasicBlock::Create(*context
, "merge", function_
);
1115 builder
->CreateCondBr(condition
, then_bb
, else_bb
);
1117 // Emit the then block.
1118 builder
->SetInsertPoint(then_bb
);
1119 LValuePtr then_lvalue
= then_func();
1120 builder
->CreateBr(merge_bb
);
1122 // refresh then_bb for phi (could have changed due to code generation of then_vv).
1123 then_bb
= builder
->GetInsertBlock();
1125 // Emit the else block.
1126 builder
->SetInsertPoint(else_bb
);
1127 LValuePtr else_lvalue
= else_func();
1128 builder
->CreateBr(merge_bb
);
1130 // refresh else_bb for phi (could have changed due to code generation of else_vv).
1131 else_bb
= builder
->GetInsertBlock();
1133 // Emit the merge block.
1134 builder
->SetInsertPoint(merge_bb
);
1135 auto llvm_type
= types
->IRType(result_type
->id());
1136 llvm::PHINode
* result_value
= builder
->CreatePHI(llvm_type
, 2, "res_value");
1137 result_value
->addIncoming(then_lvalue
->data(), then_bb
);
1138 result_value
->addIncoming(else_lvalue
->data(), else_bb
);
1141 switch (result_type
->id()) {
1142 case arrow::Type::STRING
:
1143 case arrow::Type::BINARY
: {
1144 llvm::PHINode
* result_length
;
1145 result_length
= builder
->CreatePHI(types
->i32_type(), 2, "res_length");
1146 result_length
->addIncoming(then_lvalue
->length(), then_bb
);
1147 result_length
->addIncoming(else_lvalue
->length(), else_bb
);
1148 ret
= std::make_shared
<LValue
>(result_value
, result_length
);
1152 case arrow::Type::DECIMAL
:
1153 ret
= generator_
->BuildDecimalLValue(result_value
, result_type
);
1157 ret
= std::make_shared
<LValue
>(result_value
);
1163 LValuePtr
LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair
& pair
) {
1164 // generate code for value
1165 auto value_expr
= pair
.value_expr();
1166 value_expr
->Accept(*this);
1167 auto value
= result()->data();
1168 auto length
= result()->length();
1170 // generate code for validity
1171 auto validity
= BuildCombinedValidity(pair
.validity_exprs());
1173 return std::make_shared
<LValue
>(value
, length
, validity
);
1176 LValuePtr
LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction
* func
,
1177 DataTypePtr arrow_return_type
,
1178 std::vector
<llvm::Value
*>* params
) {
1179 auto types
= generator_
->types();
1180 auto arrow_return_type_id
= arrow_return_type
->id();
1181 auto llvm_return_type
= types
->IRType(arrow_return_type_id
);
1182 DecimalIR
decimalIR(generator_
->engine_
.get());
1184 if (arrow_return_type_id
== arrow::Type::DECIMAL
) {
1185 // For decimal fns, the output precision/scale are passed along as parameters.
1187 // convert from this :
1188 // out = add_decimal(v1, p1, s1, v2, p2, s2)
1190 // out = add_decimal(v1, p1, s1, v2, p2, s2, out_p, out_s)
1192 // Append the out_precision and out_scale
1193 auto ret_lvalue
= generator_
->BuildDecimalLValue(nullptr, arrow_return_type
);
1194 params
->push_back(ret_lvalue
->precision());
1195 params
->push_back(ret_lvalue
->scale());
1197 // Make the function call
1198 auto out
= decimalIR
.CallDecimalFunction(func
->pc_name(), llvm_return_type
, *params
);
1199 ret_lvalue
->set_data(out
);
1200 return std::move(ret_lvalue
);
1202 bool isDecimalFunction
= false;
1203 for (auto& arg
: *params
) {
1204 if (arg
->getType() == types
->i128_type()) {
1205 isDecimalFunction
= true;
1208 // add extra arg for return length for variable len return types (allocated on stack).
1209 llvm::AllocaInst
* result_len_ptr
= nullptr;
1210 if (arrow::is_binary_like(arrow_return_type_id
)) {
1211 result_len_ptr
= new llvm::AllocaInst(generator_
->types()->i32_type(), 0,
1212 "result_len", entry_block_
);
1213 params
->push_back(result_len_ptr
);
1214 has_arena_allocs_
= true;
1217 // Make the function call
1218 llvm::IRBuilder
<>* builder
= ir_builder();
1221 ? decimalIR
.CallDecimalFunction(func
->pc_name(), llvm_return_type
, *params
)
1222 : generator_
->AddFunctionCall(func
->pc_name(), llvm_return_type
, *params
);
1224 (result_len_ptr
== nullptr) ? nullptr : CreateLoad(builder
, result_len_ptr
);
1225 return std::make_shared
<LValue
>(value
, value_len
);
1229 std::vector
<llvm::Value
*> LLVMGenerator::Visitor::BuildParams(
1230 FunctionHolder
* holder
, const ValueValidityPairVector
& args
, bool with_validity
,
1231 bool with_context
) {
1232 LLVMTypes
* types
= generator_
->types();
1233 std::vector
<llvm::Value
*> params
;
1235 // add context if required.
1237 params
.push_back(arg_context_ptr_
);
1240 // if the function has holder, add the holder pointer.
1241 if (holder
!= nullptr) {
1242 auto ptr
= types
->i64_constant((int64_t)holder
);
1243 params
.push_back(ptr
);
1246 // build the function params, along with the validities.
1247 for (auto& pair
: args
) {
1249 DexPtr value_expr
= pair
->value_expr();
1250 value_expr
->Accept(*this);
1251 LValue
& result_ref
= *result();
1253 // append all the parameters corresponding to this LValue.
1254 result_ref
.AppendFunctionParams(¶ms
);
1257 if (with_validity
) {
1258 llvm::Value
* validity_expr
= BuildCombinedValidity(pair
->validity_exprs());
1259 params
.push_back(validity_expr
);
1266 // Bitwise-AND of a vector of bits to get the combined validity.
1267 llvm::Value
* LLVMGenerator::Visitor::BuildCombinedValidity(const DexVector
& validities
) {
1268 llvm::IRBuilder
<>* builder
= ir_builder();
1269 LLVMTypes
* types
= generator_
->types();
1271 llvm::Value
* isValid
= types
->true_constant();
1272 for (auto& dex
: validities
) {
1274 isValid
= builder
->CreateAnd(isValid
, result()->data(), "validityBitAnd");
1276 ADD_VISITOR_TRACE("combined validity is %T", isValid
);
1280 llvm::Value
* LLVMGenerator::Visitor::GetBufferReference(int idx
, BufferType buffer_type
,
1282 llvm::IRBuilder
<>* builder
= ir_builder();
1284 // Switch to the entry block to create a reference.
1285 llvm::BasicBlock
* saved_block
= builder
->GetInsertBlock();
1286 builder
->SetInsertPoint(entry_block_
);
1288 llvm::Value
* slot_ref
= nullptr;
1289 switch (buffer_type
) {
1290 case kBufferTypeValidity
:
1291 slot_ref
= generator_
->GetValidityReference(arg_addrs_
, idx
, field
);
1294 case kBufferTypeData
:
1295 slot_ref
= generator_
->GetDataReference(arg_addrs_
, idx
, field
);
1298 case kBufferTypeOffsets
:
1299 slot_ref
= generator_
->GetOffsetsReference(arg_addrs_
, idx
, field
);
1303 // Revert to the saved block.
1304 builder
->SetInsertPoint(saved_block
);
1308 llvm::Value
* LLVMGenerator::Visitor::GetSliceOffset(int idx
) {
1309 return slice_offsets_
[idx
];
1312 llvm::Value
* LLVMGenerator::Visitor::GetLocalBitMapReference(int idx
) {
1313 llvm::IRBuilder
<>* builder
= ir_builder();
1315 // Switch to the entry block to create a reference.
1316 llvm::BasicBlock
* saved_block
= builder
->GetInsertBlock();
1317 builder
->SetInsertPoint(entry_block_
);
1319 llvm::Value
* slot_ref
= generator_
->GetLocalBitMapReference(arg_local_bitmaps_
, idx
);
1321 // Revert to the saved block.
1322 builder
->SetInsertPoint(saved_block
);
1326 /// The local bitmap is pre-filled with 1s. Clear only if invalid.
1327 void LLVMGenerator::Visitor::ClearLocalBitMapIfNotValid(int local_bitmap_idx
,
1328 llvm::Value
* is_valid
) {
1329 llvm::Value
* slot_ref
= GetLocalBitMapReference(local_bitmap_idx
);
1330 generator_
->ClearPackedBitValueIfFalse(slot_ref
, loop_var_
, is_valid
);
1333 // Hooks for tracing/printfs.
1335 // replace %T with the type-specific format specifier.
1336 // For some reason, float/double literals are getting lost when printing with the generic
1337 // printf. so, use a wrapper instead.
1338 std::string
LLVMGenerator::ReplaceFormatInTrace(const std::string
& in_msg
,
1340 std::string
* print_fn
) {
1341 std::string msg
= in_msg
;
1342 std::size_t pos
= msg
.find("%T");
1343 if (pos
== std::string::npos
) {
1348 llvm::Type
* type
= value
->getType();
1349 const char* fmt
= "";
1350 if (type
->isIntegerTy(1) || type
->isIntegerTy(8) || type
->isIntegerTy(16) ||
1351 type
->isIntegerTy(32)) {
1353 } else if (type
->isIntegerTy(64)) {
1356 } else if (type
->isFloatTy()) {
1359 *print_fn
= "print_float";
1360 } else if (type
->isDoubleTy()) {
1363 *print_fn
= "print_double";
1364 } else if (type
->isPointerTy()) {
1370 msg
.replace(pos
, 2, fmt
);
1374 void LLVMGenerator::AddTrace(const std::string
& msg
, llvm::Value
* value
) {
1375 if (!enable_ir_traces_
) {
1379 std::string dmsg
= "IR_TRACE:: " + msg
+ "\n";
1380 std::string print_fn_name
= "printf";
1381 if (value
!= nullptr) {
1382 dmsg
= ReplaceFormatInTrace(dmsg
, value
, &print_fn_name
);
1384 trace_strings_
.push_back(dmsg
);
1386 // cast this to an llvm pointer.
1387 const char* str
= trace_strings_
.back().c_str();
1388 llvm::Constant
* str_int_cast
= types()->i64_constant((int64_t)str
);
1389 llvm::Constant
* str_ptr_cast
=
1390 llvm::ConstantExpr::getIntToPtr(str_int_cast
, types()->i8_ptr_type());
1392 std::vector
<llvm::Value
*> args
;
1393 args
.push_back(str_ptr_cast
);
1394 if (value
!= nullptr) {
1395 args
.push_back(value
);
1397 AddFunctionCall(print_fn_name
, types()->i32_type(), args
);
1400 } // namespace gandiva