]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/gandiva/llvm_generator.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / gandiva / llvm_generator.cc
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "gandiva/llvm_generator.h"
19
20 #include <fstream>
21 #include <iostream>
22 #include <sstream>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "gandiva/bitmap_accumulator.h"
28 #include "gandiva/decimal_ir.h"
29 #include "gandiva/dex.h"
30 #include "gandiva/expr_decomposer.h"
31 #include "gandiva/expression.h"
32 #include "gandiva/lvalue.h"
33
34 namespace gandiva {
35
36 #define ADD_TRACE(...) \
37 if (enable_ir_traces_) { \
38 AddTrace(__VA_ARGS__); \
39 }
40
41 LLVMGenerator::LLVMGenerator() : enable_ir_traces_(false) {}
42
43 Status LLVMGenerator::Make(std::shared_ptr<Configuration> config,
44 std::unique_ptr<LLVMGenerator>* llvm_generator) {
45 std::unique_ptr<LLVMGenerator> llvmgen_obj(new LLVMGenerator());
46
47 ARROW_RETURN_NOT_OK(Engine::Make(config, &(llvmgen_obj->engine_)));
48 *llvm_generator = std::move(llvmgen_obj);
49
50 return Status::OK();
51 }
52
53 Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr output) {
54 int idx = static_cast<int>(compiled_exprs_.size());
55 // decompose the expression to separate out value and validities.
56 ExprDecomposer decomposer(function_registry_, annotator_);
57 ValueValidityPairPtr value_validity;
58 ARROW_RETURN_NOT_OK(decomposer.Decompose(*expr->root(), &value_validity));
59 // Generate the IR function for the decomposed expression.
60 std::unique_ptr<CompiledExpr> compiled_expr(new CompiledExpr(value_validity, output));
61 llvm::Function* ir_function = nullptr;
62 ARROW_RETURN_NOT_OK(CodeGenExprValue(value_validity->value_expr(),
63 annotator_.buffer_count(), output, idx,
64 &ir_function, selection_vector_mode_));
65 compiled_expr->SetIRFunction(selection_vector_mode_, ir_function);
66
67 compiled_exprs_.push_back(std::move(compiled_expr));
68 return Status::OK();
69 }
70
71 /// Build and optimise module for projection expression.
72 Status LLVMGenerator::Build(const ExpressionVector& exprs, SelectionVector::Mode mode) {
73 selection_vector_mode_ = mode;
74 for (auto& expr : exprs) {
75 auto output = annotator_.AddOutputFieldDescriptor(expr->result());
76 ARROW_RETURN_NOT_OK(Add(expr, output));
77 }
78
79 // Compile and inject into the process' memory the generated function.
80 ARROW_RETURN_NOT_OK(engine_->FinalizeModule());
81
82 // setup the jit functions for each expression.
83 for (auto& compiled_expr : compiled_exprs_) {
84 auto ir_fn = compiled_expr->GetIRFunction(mode);
85 auto jit_fn = reinterpret_cast<EvalFunc>(engine_->CompiledFunction(ir_fn));
86 compiled_expr->SetJITFunction(selection_vector_mode_, jit_fn);
87 }
88
89 return Status::OK();
90 }
91
92 /// Execute the compiled module against the provided vectors.
93 Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
94 const ArrayDataVector& output_vector) {
95 return Execute(record_batch, nullptr, output_vector);
96 }
97
98 /// Execute the compiled module against the provided vectors based on the type of
99 /// selection vector.
100 Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
101 const SelectionVector* selection_vector,
102 const ArrayDataVector& output_vector) {
103 DCHECK_GT(record_batch.num_rows(), 0);
104
105 auto eval_batch = annotator_.PrepareEvalBatch(record_batch, output_vector);
106 DCHECK_GT(eval_batch->GetNumBuffers(), 0);
107
108 auto mode = SelectionVector::MODE_NONE;
109 if (selection_vector != nullptr) {
110 mode = selection_vector->GetMode();
111 }
112 if (mode != selection_vector_mode_) {
113 return Status::Invalid("llvm expression built for selection vector mode ",
114 selection_vector_mode_, " received vector with mode ", mode);
115 }
116
117 for (auto& compiled_expr : compiled_exprs_) {
118 // generate data/offset vectors.
119 const uint8_t* selection_buffer = nullptr;
120 auto num_output_rows = record_batch.num_rows();
121 if (selection_vector != nullptr) {
122 selection_buffer = selection_vector->GetBuffer().data();
123 num_output_rows = selection_vector->GetNumSlots();
124 }
125
126 EvalFunc jit_function = compiled_expr->GetJITFunction(mode);
127 jit_function(eval_batch->GetBufferArray(), eval_batch->GetBufferOffsetArray(),
128 eval_batch->GetLocalBitMapArray(), selection_buffer,
129 (int64_t)eval_batch->GetExecutionContext(), num_output_rows);
130
131 // check for execution errors
132 ARROW_RETURN_IF(
133 eval_batch->GetExecutionContext()->has_error(),
134 Status::ExecutionError(eval_batch->GetExecutionContext()->get_error()));
135
136 // generate validity vectors.
137 ComputeBitMapsForExpr(*compiled_expr, *eval_batch, selection_vector);
138 }
139
140 return Status::OK();
141 }
142
143 llvm::Value* LLVMGenerator::LoadVectorAtIndex(llvm::Value* arg_addrs, int idx,
144 const std::string& name) {
145 auto* idx_val = types()->i32_constant(idx);
146 auto* offset = CreateGEP(ir_builder(), arg_addrs, idx_val, name + "_mem_addr");
147 return CreateLoad(ir_builder(), offset, name + "_mem");
148 }
149
150 /// Get reference to validity array at specified index in the args list.
151 llvm::Value* LLVMGenerator::GetValidityReference(llvm::Value* arg_addrs, int idx,
152 FieldPtr field) {
153 const std::string& name = field->name();
154 llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
155 return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(), name + "_varray");
156 }
157
158 /// Get reference to data array at specified index in the args list.
159 llvm::Value* LLVMGenerator::GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx,
160 FieldPtr field) {
161 const std::string& name = field->name();
162 llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
163 return ir_builder()->CreateIntToPtr(load, types()->i8_ptr_type(), name + "_buf_ptr");
164 }
165
166 /// Get reference to data array at specified index in the args list.
167 llvm::Value* LLVMGenerator::GetDataReference(llvm::Value* arg_addrs, int idx,
168 FieldPtr field) {
169 const std::string& name = field->name();
170 llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
171 llvm::Type* base_type = types()->DataVecType(field->type());
172 llvm::Value* ret;
173 if (base_type->isPointerTy()) {
174 ret = ir_builder()->CreateIntToPtr(load, base_type, name + "_darray");
175 } else {
176 llvm::Type* pointer_type = types()->ptr_type(base_type);
177 ret = ir_builder()->CreateIntToPtr(load, pointer_type, name + "_darray");
178 }
179 return ret;
180 }
181
182 /// Get reference to offsets array at specified index in the args list.
183 llvm::Value* LLVMGenerator::GetOffsetsReference(llvm::Value* arg_addrs, int idx,
184 FieldPtr field) {
185 const std::string& name = field->name();
186 llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
187 return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_oarray");
188 }
189
190 /// Get reference to local bitmap array at specified index in the args list.
191 llvm::Value* LLVMGenerator::GetLocalBitMapReference(llvm::Value* arg_bitmaps, int idx) {
192 llvm::Value* load = LoadVectorAtIndex(arg_bitmaps, idx, "");
193 return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(),
194 std::to_string(idx) + "_lbmap");
195 }
196
197 /// \brief Generate code for one expression.
198
199 // Sample IR code for "c1:int + c2:int"
200 //
201 // The C-code equivalent is :
202 // ------------------------------
203 // int expr_0(int64_t *addrs, int64_t *local_bitmaps,
204 // int64_t execution_context_ptr, int64_t nrecords) {
205 // int *outVec = (int *) addrs[5];
206 // int *c0Vec = (int *) addrs[1];
207 // int *c1Vec = (int *) addrs[3];
208 // for (int loop_var = 0; loop_var < nrecords; ++loop_var) {
209 // int c0 = c0Vec[loop_var];
210 // int c1 = c1Vec[loop_var];
211 // int out = c0 + c1;
212 // outVec[loop_var] = out;
213 // }
214 // }
215 //
216 // IR Code
217 // --------
218 //
219 // define i32 @expr_0(i64* %args, i64* %local_bitmaps, i64 %execution_context_ptr, , i64
220 // %nrecords) { entry:
221 // %outmemAddr = getelementptr i64, i64* %args, i32 5
222 // %outmem = load i64, i64* %outmemAddr
223 // %outVec = inttoptr i64 %outmem to i32*
224 // %c0memAddr = getelementptr i64, i64* %args, i32 1
225 // %c0mem = load i64, i64* %c0memAddr
226 // %c0Vec = inttoptr i64 %c0mem to i32*
227 // %c1memAddr = getelementptr i64, i64* %args, i32 3
228 // %c1mem = load i64, i64* %c1memAddr
229 // %c1Vec = inttoptr i64 %c1mem to i32*
230 // br label %loop
231 // loop: ; preds = %loop, %entry
232 // %loop_var = phi i64 [ 0, %entry ], [ %"loop_var+1", %loop ]
233 // %"loop_var+1" = add i64 %loop_var, 1
234 // %0 = getelementptr i32, i32* %c0Vec, i32 %loop_var
235 // %c0 = load i32, i32* %0
236 // %1 = getelementptr i32, i32* %c1Vec, i32 %loop_var
237 // %c1 = load i32, i32* %1
238 // %add_int_int = call i32 @add_int_int(i32 %c0, i32 %c1)
239 // %2 = getelementptr i32, i32* %outVec, i32 %loop_var
240 // store i32 %add_int_int, i32* %2
241 // %"loop_var < nrec" = icmp slt i64 %"loop_var+1", %nrecords
242 // br i1 %"loop_var < nrec", label %loop, label %exit
243 // exit: ; preds = %loop
244 // ret i32 0
245 // }
246 Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count,
247 FieldDescriptorPtr output, int suffix_idx,
248 llvm::Function** fn,
249 SelectionVector::Mode selection_vector_mode) {
250 llvm::IRBuilder<>* builder = ir_builder();
251 // Create fn prototype :
252 // int expr_1 (long **addrs, long *offsets, long **bitmaps,
253 // long *context_ptr, long nrec)
254 std::vector<llvm::Type*> arguments;
255 arguments.push_back(types()->i64_ptr_type()); // addrs
256 arguments.push_back(types()->i64_ptr_type()); // offsets
257 arguments.push_back(types()->i64_ptr_type()); // bitmaps
258 switch (selection_vector_mode) {
259 case SelectionVector::MODE_NONE:
260 case SelectionVector::MODE_UINT16:
261 arguments.push_back(types()->ptr_type(types()->i16_type()));
262 break;
263 case SelectionVector::MODE_UINT32:
264 arguments.push_back(types()->i32_ptr_type());
265 break;
266 case SelectionVector::MODE_UINT64:
267 arguments.push_back(types()->i64_ptr_type());
268 }
269 arguments.push_back(types()->i64_type()); // ctx_ptr
270 arguments.push_back(types()->i64_type()); // nrec
271 llvm::FunctionType* prototype =
272 llvm::FunctionType::get(types()->i32_type(), arguments, false /*isVarArg*/);
273
274 // Create fn
275 std::string func_name = "expr_" + std::to_string(suffix_idx) + "_" +
276 std::to_string(static_cast<int>(selection_vector_mode));
277 engine_->AddFunctionToCompile(func_name);
278 *fn = llvm::Function::Create(prototype, llvm::GlobalValue::ExternalLinkage, func_name,
279 module());
280 ARROW_RETURN_IF((*fn == nullptr), Status::CodeGenError("Error creating function."));
281
282 // Name the arguments
283 llvm::Function::arg_iterator args = (*fn)->arg_begin();
284 llvm::Value* arg_addrs = &*args;
285 arg_addrs->setName("inputs_addr");
286 ++args;
287 llvm::Value* arg_addr_offsets = &*args;
288 arg_addr_offsets->setName("inputs_addr_offsets");
289 ++args;
290 llvm::Value* arg_local_bitmaps = &*args;
291 arg_local_bitmaps->setName("local_bitmaps");
292 ++args;
293 llvm::Value* arg_selection_vector = &*args;
294 arg_selection_vector->setName("selection_vector");
295 ++args;
296 llvm::Value* arg_context_ptr = &*args;
297 arg_context_ptr->setName("context_ptr");
298 ++args;
299 llvm::Value* arg_nrecords = &*args;
300 arg_nrecords->setName("nrecords");
301
302 llvm::BasicBlock* loop_entry = llvm::BasicBlock::Create(*context(), "entry", *fn);
303 llvm::BasicBlock* loop_body = llvm::BasicBlock::Create(*context(), "loop", *fn);
304 llvm::BasicBlock* loop_exit = llvm::BasicBlock::Create(*context(), "exit", *fn);
305
306 // Add reference to output vector (in entry block)
307 builder->SetInsertPoint(loop_entry);
308 llvm::Value* output_ref =
309 GetDataReference(arg_addrs, output->data_idx(), output->field());
310 llvm::Value* output_buffer_ptr_ref = GetDataBufferPtrReference(
311 arg_addrs, output->data_buffer_ptr_idx(), output->field());
312 llvm::Value* output_offset_ref =
313 GetOffsetsReference(arg_addrs, output->offsets_idx(), output->field());
314
315 std::vector<llvm::Value*> slice_offsets;
316 for (int idx = 0; idx < buffer_count; idx++) {
317 auto offsetAddr = CreateGEP(builder, arg_addr_offsets, types()->i32_constant(idx));
318 auto offset = CreateLoad(builder, offsetAddr);
319 slice_offsets.push_back(offset);
320 }
321
322 // Loop body
323 builder->SetInsertPoint(loop_body);
324
325 // define loop_var : start with 0, +1 after each iter
326 llvm::PHINode* loop_var = builder->CreatePHI(types()->i64_type(), 2, "loop_var");
327
328 llvm::Value* position_var = loop_var;
329 if (selection_vector_mode != SelectionVector::MODE_NONE) {
330 position_var = builder->CreateIntCast(
331 CreateLoad(builder, CreateGEP(builder, arg_selection_vector, loop_var),
332 "uncasted_position_var"),
333 types()->i64_type(), true, "position_var");
334 }
335
336 // The visitor can add code to both the entry/loop blocks.
337 Visitor visitor(this, *fn, loop_entry, arg_addrs, arg_local_bitmaps, slice_offsets,
338 arg_context_ptr, position_var);
339 value_expr->Accept(visitor);
340 LValuePtr output_value = visitor.result();
341
342 // The "current" block may have changed due to code generation in the visitor.
343 llvm::BasicBlock* loop_body_tail = builder->GetInsertBlock();
344
345 // add jump to "loop block" at the end of the "setup block".
346 builder->SetInsertPoint(loop_entry);
347 builder->CreateBr(loop_body);
348
349 // save the value in the output vector.
350 builder->SetInsertPoint(loop_body_tail);
351
352 auto output_type_id = output->Type()->id();
353 if (output_type_id == arrow::Type::BOOL) {
354 SetPackedBitValue(output_ref, loop_var, output_value->data());
355 } else if (arrow::is_primitive(output_type_id) ||
356 output_type_id == arrow::Type::DECIMAL) {
357 llvm::Value* slot_offset = CreateGEP(builder, output_ref, loop_var);
358 builder->CreateStore(output_value->data(), slot_offset);
359 } else if (arrow::is_binary_like(output_type_id)) {
360 // Var-len output. Make a function call to populate the data.
361 // if there is an error, the fn sets it in the context. And, will be returned at the
362 // end of this row batch.
363 AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(),
364 {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, loop_var,
365 output_value->data(), output_value->length()});
366 } else {
367 return Status::NotImplemented("output type ", output->Type()->ToString(),
368 " not supported");
369 }
370 ADD_TRACE("saving result " + output->Name() + " value %T", output_value->data());
371
372 if (visitor.has_arena_allocs()) {
373 // Reset allocations to avoid excessive memory usage. Once the result is copied to
374 // the output vector (store instruction above), any memory allocations in this
375 // iteration of the loop are no longer needed.
376 std::vector<llvm::Value*> reset_args;
377 reset_args.push_back(arg_context_ptr);
378 AddFunctionCall("gdv_fn_context_arena_reset", types()->void_type(), reset_args);
379 }
380
381 // check loop_var
382 loop_var->addIncoming(types()->i64_constant(0), loop_entry);
383 llvm::Value* loop_update =
384 builder->CreateAdd(loop_var, types()->i64_constant(1), "loop_var+1");
385 loop_var->addIncoming(loop_update, loop_body_tail);
386
387 llvm::Value* loop_var_check =
388 builder->CreateICmpSLT(loop_update, arg_nrecords, "loop_var < nrec");
389 builder->CreateCondBr(loop_var_check, loop_body, loop_exit);
390
391 // Loop exit
392 builder->SetInsertPoint(loop_exit);
393 builder->CreateRet(types()->i32_constant(0));
394 return Status::OK();
395 }
396
397 /// Return value of a bit in bitMap.
398 llvm::Value* LLVMGenerator::GetPackedBitValue(llvm::Value* bitmap,
399 llvm::Value* position) {
400 ADD_TRACE("fetch bit at position %T", position);
401
402 llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
403 bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
404 return AddFunctionCall("bitMapGetBit", types()->i1_type(), {bitmap8, position});
405 }
406
407 /// Set the value of a bit in bitMap.
408 void LLVMGenerator::SetPackedBitValue(llvm::Value* bitmap, llvm::Value* position,
409 llvm::Value* value) {
410 ADD_TRACE("set bit at position %T", position);
411 ADD_TRACE(" to value %T ", value);
412
413 llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
414 bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
415 AddFunctionCall("bitMapSetBit", types()->void_type(), {bitmap8, position, value});
416 }
417
418 /// Return value of a bit in validity bitMap (handles null bitmaps too).
419 llvm::Value* LLVMGenerator::GetPackedValidityBitValue(llvm::Value* bitmap,
420 llvm::Value* position) {
421 ADD_TRACE("fetch validity bit at position %T", position);
422
423 llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
424 bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
425 return AddFunctionCall("bitMapValidityGetBit", types()->i1_type(), {bitmap8, position});
426 }
427
428 /// Clear the bit in bitMap if value = false.
429 void LLVMGenerator::ClearPackedBitValueIfFalse(llvm::Value* bitmap, llvm::Value* position,
430 llvm::Value* value) {
431 ADD_TRACE("ClearIfFalse bit at position %T", position);
432 ADD_TRACE(" value %T ", value);
433
434 llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
435 bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
436 AddFunctionCall("bitMapClearBitIfFalse", types()->void_type(),
437 {bitmap8, position, value});
438 }
439
440 /// Extract the bitmap addresses, and do an intersection.
441 void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr& compiled_expr,
442 const EvalBatch& eval_batch,
443 const SelectionVector* selection_vector) {
444 auto validities = compiled_expr.value_validity()->validity_exprs();
445
446 // Extract all the source bitmap addresses.
447 BitMapAccumulator accumulator(eval_batch);
448 for (auto& validity_dex : validities) {
449 validity_dex->Accept(accumulator);
450 }
451
452 // Extract the destination bitmap address.
453 int out_idx = compiled_expr.output()->validity_idx();
454 uint8_t* dst_bitmap = eval_batch.GetBuffer(out_idx);
455 // Compute the destination bitmap.
456 if (selection_vector == nullptr) {
457 accumulator.ComputeResult(dst_bitmap);
458 } else {
459 /// The output bitmap is an intersection of some input/local bitmaps. However, with a
460 /// selection vector, only the bits corresponding to the indices in the selection
461 /// vector need to set in the output bitmap. This is done in two steps :
462 ///
463 /// 1. Do the intersection of input/local bitmaps to generate a temporary bitmap.
464 /// 2. copy just the relevant bits from the temporary bitmap to the output bitmap.
465 LocalBitMapsHolder bit_map_holder(eval_batch.num_records(), 1);
466 uint8_t* temp_bitmap = bit_map_holder.GetLocalBitMap(0);
467 accumulator.ComputeResult(temp_bitmap);
468
469 auto num_out_records = selection_vector->GetNumSlots();
470 // the memset isn't required, doing it just for valgrind.
471 memset(dst_bitmap, 0, arrow::BitUtil::BytesForBits(num_out_records));
472 for (auto i = 0; i < num_out_records; ++i) {
473 auto bit = arrow::BitUtil::GetBit(temp_bitmap, selection_vector->GetIndex(i));
474 arrow::BitUtil::SetBitTo(dst_bitmap, i, bit);
475 }
476 }
477 }
478
479 llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name,
480 llvm::Type* ret_type,
481 const std::vector<llvm::Value*>& args) {
482 // find the llvm function.
483 llvm::Function* fn = module()->getFunction(full_name);
484 DCHECK_NE(fn, nullptr) << "missing function " << full_name;
485
486 if (enable_ir_traces_ && !full_name.compare("printf") &&
487 !full_name.compare("printff")) {
488 // Trace for debugging
489 ADD_TRACE("invoke native fn " + full_name);
490 }
491
492 // build a call to the llvm function.
493 llvm::Value* value;
494 if (ret_type->isVoidTy()) {
495 // void functions can't have a name for the call.
496 value = ir_builder()->CreateCall(fn, args);
497 } else {
498 value = ir_builder()->CreateCall(fn, args, full_name);
499 DCHECK(value->getType() == ret_type);
500 }
501
502 return value;
503 }
504
505 std::shared_ptr<DecimalLValue> LLVMGenerator::BuildDecimalLValue(llvm::Value* value,
506 DataTypePtr arrow_type) {
507 // only decimals of size 128-bit supported.
508 DCHECK(is_decimal_128(arrow_type));
509 auto decimal_type =
510 arrow::internal::checked_cast<arrow::DecimalType*>(arrow_type.get());
511 return std::make_shared<DecimalLValue>(value, nullptr,
512 types()->i32_constant(decimal_type->precision()),
513 types()->i32_constant(decimal_type->scale()));
514 }
515
516 #define ADD_VISITOR_TRACE(...) \
517 if (generator_->enable_ir_traces_) { \
518 generator_->AddTrace(__VA_ARGS__); \
519 }
520
521 // Visitor for generating the code for a decomposed expression.
522 LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* function,
523 llvm::BasicBlock* entry_block, llvm::Value* arg_addrs,
524 llvm::Value* arg_local_bitmaps,
525 std::vector<llvm::Value*> slice_offsets,
526 llvm::Value* arg_context_ptr, llvm::Value* loop_var)
527 : generator_(generator),
528 function_(function),
529 entry_block_(entry_block),
530 arg_addrs_(arg_addrs),
531 arg_local_bitmaps_(arg_local_bitmaps),
532 slice_offsets_(slice_offsets),
533 arg_context_ptr_(arg_context_ptr),
534 loop_var_(loop_var),
535 has_arena_allocs_(false) {
536 ADD_VISITOR_TRACE("Iteration %T", loop_var);
537 }
538
539 void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) {
540 llvm::IRBuilder<>* builder = ir_builder();
541 llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field());
542 llvm::Value* slot_index = builder->CreateAdd(loop_var_, GetSliceOffset(dex.DataIdx()));
543 llvm::Value* slot_value;
544 std::shared_ptr<LValue> lvalue;
545
546 switch (dex.FieldType()->id()) {
547 case arrow::Type::BOOL:
548 slot_value = generator_->GetPackedBitValue(slot_ref, slot_index);
549 lvalue = std::make_shared<LValue>(slot_value);
550 break;
551
552 case arrow::Type::DECIMAL: {
553 auto slot_offset = CreateGEP(builder, slot_ref, slot_index);
554 slot_value = CreateLoad(builder, slot_offset, dex.FieldName());
555 lvalue = generator_->BuildDecimalLValue(slot_value, dex.FieldType());
556 break;
557 }
558
559 default: {
560 auto slot_offset = CreateGEP(builder, slot_ref, slot_index);
561 slot_value = CreateLoad(builder, slot_offset, dex.FieldName());
562 lvalue = std::make_shared<LValue>(slot_value);
563 break;
564 }
565 }
566 ADD_VISITOR_TRACE("visit fixed-len data vector " + dex.FieldName() + " value %T",
567 slot_value);
568 result_ = lvalue;
569 }
570
571 void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) {
572 llvm::IRBuilder<>* builder = ir_builder();
573 llvm::Value* slot;
574
575 // compute len from the offsets array.
576 llvm::Value* offsets_slot_ref =
577 GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field());
578 llvm::Value* offsets_slot_index =
579 builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx()));
580
581 // => offset_start = offsets[loop_var]
582 slot = CreateGEP(builder, offsets_slot_ref, offsets_slot_index);
583 llvm::Value* offset_start = CreateLoad(builder, slot, "offset_start");
584
585 // => offset_end = offsets[loop_var + 1]
586 llvm::Value* offsets_slot_index_next = builder->CreateAdd(
587 offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1");
588 slot = CreateGEP(builder, offsets_slot_ref, offsets_slot_index_next);
589 llvm::Value* offset_end = CreateLoad(builder, slot, "offset_end");
590
591 // => len_value = offset_end - offset_start
592 llvm::Value* len_value =
593 builder->CreateSub(offset_end, offset_start, dex.FieldName() + "Len");
594
595 // get the data from the data array, at offset 'offset_start'.
596 llvm::Value* data_slot_ref =
597 GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field());
598 llvm::Value* data_value = CreateGEP(builder, data_slot_ref, offset_start);
599 ADD_VISITOR_TRACE("visit var-len data vector " + dex.FieldName() + " len %T",
600 len_value);
601 result_.reset(new LValue(data_value, len_value));
602 }
603
604 void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex& dex) {
605 llvm::IRBuilder<>* builder = ir_builder();
606 llvm::Value* slot_ref =
607 GetBufferReference(dex.ValidityIdx(), kBufferTypeValidity, dex.Field());
608 llvm::Value* slot_index =
609 builder->CreateAdd(loop_var_, GetSliceOffset(dex.ValidityIdx()));
610 llvm::Value* validity = generator_->GetPackedValidityBitValue(slot_ref, slot_index);
611
612 ADD_VISITOR_TRACE("visit validity vector " + dex.FieldName() + " value %T", validity);
613 result_.reset(new LValue(validity));
614 }
615
616 void LLVMGenerator::Visitor::Visit(const LocalBitMapValidityDex& dex) {
617 llvm::Value* slot_ref = GetLocalBitMapReference(dex.local_bitmap_idx());
618 llvm::Value* validity = generator_->GetPackedBitValue(slot_ref, loop_var_);
619
620 ADD_VISITOR_TRACE(
621 "visit local bitmap " + std::to_string(dex.local_bitmap_idx()) + " value %T",
622 validity);
623 result_.reset(new LValue(validity));
624 }
625
626 void LLVMGenerator::Visitor::Visit(const TrueDex& dex) {
627 result_.reset(new LValue(generator_->types()->true_constant()));
628 }
629
630 void LLVMGenerator::Visitor::Visit(const FalseDex& dex) {
631 result_.reset(new LValue(generator_->types()->false_constant()));
632 }
633
634 void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) {
635 LLVMTypes* types = generator_->types();
636 llvm::Value* value = nullptr;
637 llvm::Value* len = nullptr;
638
639 switch (dex.type()->id()) {
640 case arrow::Type::BOOL:
641 value = types->i1_constant(arrow::util::get<bool>(dex.holder()));
642 break;
643
644 case arrow::Type::UINT8:
645 value = types->i8_constant(arrow::util::get<uint8_t>(dex.holder()));
646 break;
647
648 case arrow::Type::UINT16:
649 value = types->i16_constant(arrow::util::get<uint16_t>(dex.holder()));
650 break;
651
652 case arrow::Type::UINT32:
653 value = types->i32_constant(arrow::util::get<uint32_t>(dex.holder()));
654 break;
655
656 case arrow::Type::UINT64:
657 value = types->i64_constant(arrow::util::get<uint64_t>(dex.holder()));
658 break;
659
660 case arrow::Type::INT8:
661 value = types->i8_constant(arrow::util::get<int8_t>(dex.holder()));
662 break;
663
664 case arrow::Type::INT16:
665 value = types->i16_constant(arrow::util::get<int16_t>(dex.holder()));
666 break;
667
668 case arrow::Type::FLOAT:
669 value = types->float_constant(arrow::util::get<float>(dex.holder()));
670 break;
671
672 case arrow::Type::DOUBLE:
673 value = types->double_constant(arrow::util::get<double>(dex.holder()));
674 break;
675
676 case arrow::Type::STRING:
677 case arrow::Type::BINARY: {
678 const std::string& str = arrow::util::get<std::string>(dex.holder());
679
680 llvm::Constant* str_int_cast = types->i64_constant((int64_t)str.c_str());
681 value = llvm::ConstantExpr::getIntToPtr(str_int_cast, types->i8_ptr_type());
682 len = types->i32_constant(static_cast<int32_t>(str.length()));
683 break;
684 }
685
686 case arrow::Type::INT32:
687 case arrow::Type::DATE32:
688 case arrow::Type::TIME32:
689 case arrow::Type::INTERVAL_MONTHS:
690 value = types->i32_constant(arrow::util::get<int32_t>(dex.holder()));
691 break;
692
693 case arrow::Type::INT64:
694 case arrow::Type::DATE64:
695 case arrow::Type::TIME64:
696 case arrow::Type::TIMESTAMP:
697 case arrow::Type::INTERVAL_DAY_TIME:
698 value = types->i64_constant(arrow::util::get<int64_t>(dex.holder()));
699 break;
700
701 case arrow::Type::DECIMAL: {
702 // build code for struct
703 auto scalar = arrow::util::get<DecimalScalar128>(dex.holder());
704 // ConstantInt doesn't have a get method that takes int128 or a pair of int64. so,
705 // passing the string representation instead.
706 auto int128_value =
707 llvm::ConstantInt::get(llvm::Type::getInt128Ty(*generator_->context()),
708 Decimal128(scalar.value()).ToIntegerString(), 10);
709 auto type = arrow::decimal(scalar.precision(), scalar.scale());
710 auto lvalue = generator_->BuildDecimalLValue(int128_value, type);
711 // set it as the l-value and return.
712 result_ = lvalue;
713 return;
714 }
715
716 default:
717 DCHECK(0);
718 }
719 ADD_VISITOR_TRACE("visit Literal %T", value);
720 result_.reset(new LValue(value, len));
721 }
722
723 void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) {
724 const std::string& function_name = dex.func_descriptor()->name();
725 ADD_VISITOR_TRACE("visit NonNullableFunc base function " + function_name);
726
727 const NativeFunction* native_function = dex.native_function();
728
729 // build the function params (ignore validity).
730 auto params = BuildParams(dex.function_holder().get(), dex.args(), false,
731 native_function->NeedsContext());
732
733 auto arrow_return_type = dex.func_descriptor()->return_type();
734 if (native_function->CanReturnErrors()) {
735 // slow path : if a function can return errors, skip invoking the function
736 // unless all of the input args are valid. Otherwise, it can cause spurious errors.
737
738 llvm::IRBuilder<>* builder = ir_builder();
739 LLVMTypes* types = generator_->types();
740 auto arrow_type_id = arrow_return_type->id();
741 auto result_type = types->IRType(arrow_type_id);
742
743 // Build combined validity of the args.
744 llvm::Value* is_valid = types->true_constant();
745 for (auto& pair : dex.args()) {
746 auto arg_validity = BuildCombinedValidity(pair->validity_exprs());
747 is_valid = builder->CreateAnd(is_valid, arg_validity, "validityBitAnd");
748 }
749
750 // then block
751 auto then_lambda = [&] {
752 ADD_VISITOR_TRACE("fn " + function_name +
753 " can return errors : all args valid, invoke fn");
754 return BuildFunctionCall(native_function, arrow_return_type, &params);
755 };
756
757 // else block
758 auto else_lambda = [&] {
759 ADD_VISITOR_TRACE("fn " + function_name +
760 " can return errors : not all args valid, return dummy value");
761 llvm::Value* else_value = types->NullConstant(result_type);
762 llvm::Value* else_value_len = nullptr;
763 if (arrow::is_binary_like(arrow_type_id)) {
764 else_value_len = types->i32_constant(0);
765 }
766 return std::make_shared<LValue>(else_value, else_value_len);
767 };
768
769 result_ = BuildIfElse(is_valid, then_lambda, else_lambda, arrow_return_type);
770 } else {
771 // fast path : invoke function without computing validities.
772 result_ = BuildFunctionCall(native_function, arrow_return_type, &params);
773 }
774 }
775
776 void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex& dex) {
777 ADD_VISITOR_TRACE("visit NullableNever base function " + dex.func_descriptor()->name());
778 const NativeFunction* native_function = dex.native_function();
779
780 // build function params along with validity.
781 auto params = BuildParams(dex.function_holder().get(), dex.args(), true,
782 native_function->NeedsContext());
783
784 auto arrow_return_type = dex.func_descriptor()->return_type();
785 result_ = BuildFunctionCall(native_function, arrow_return_type, &params);
786 }
787
788 void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) {
789 ADD_VISITOR_TRACE("visit NullableInternal base function " +
790 dex.func_descriptor()->name());
791 llvm::IRBuilder<>* builder = ir_builder();
792 LLVMTypes* types = generator_->types();
793
794 const NativeFunction* native_function = dex.native_function();
795
796 // build function params along with validity.
797 auto params = BuildParams(dex.function_holder().get(), dex.args(), true,
798 native_function->NeedsContext());
799
800 // add an extra arg for validity (allocated on stack).
801 llvm::AllocaInst* result_valid_ptr =
802 new llvm::AllocaInst(types->i8_type(), 0, "result_valid", entry_block_);
803 params.push_back(result_valid_ptr);
804
805 auto arrow_return_type = dex.func_descriptor()->return_type();
806 result_ = BuildFunctionCall(native_function, arrow_return_type, &params);
807
808 // load the result validity and truncate to i1.
809 llvm::Value* result_valid_i8 = CreateLoad(builder, result_valid_ptr);
810 llvm::Value* result_valid = builder->CreateTrunc(result_valid_i8, types->i1_type());
811
812 // set validity bit in the local bitmap.
813 ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), result_valid);
814 }
815
816 void LLVMGenerator::Visitor::Visit(const IfDex& dex) {
817 ADD_VISITOR_TRACE("visit IfExpression");
818 llvm::IRBuilder<>* builder = ir_builder();
819
820 // Evaluate condition.
821 LValuePtr if_condition = BuildValueAndValidity(dex.condition_vv());
822
823 // Check if the result is valid, and there is match.
824 llvm::Value* validAndMatched =
825 builder->CreateAnd(if_condition->data(), if_condition->validity(), "validAndMatch");
826
827 // then block
828 auto then_lambda = [&] {
829 ADD_VISITOR_TRACE("branch to then block");
830 LValuePtr then_lvalue = BuildValueAndValidity(dex.then_vv());
831 ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), then_lvalue->validity());
832 ADD_VISITOR_TRACE("IfExpression result validity %T in matching then",
833 then_lvalue->validity());
834 return then_lvalue;
835 };
836
837 // else block
838 auto else_lambda = [&] {
839 LValuePtr else_lvalue;
840 if (dex.is_terminal_else()) {
841 ADD_VISITOR_TRACE("branch to terminal else block");
842
843 else_lvalue = BuildValueAndValidity(dex.else_vv());
844 // update the local bitmap with the validity.
845 ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), else_lvalue->validity());
846 ADD_VISITOR_TRACE("IfExpression result validity %T in terminal else",
847 else_lvalue->validity());
848 } else {
849 ADD_VISITOR_TRACE("branch to non-terminal else block");
850
851 // this is a non-terminal else. let the child (nested if/else) handle validity.
852 auto value_expr = dex.else_vv().value_expr();
853 value_expr->Accept(*this);
854 else_lvalue = result();
855 }
856 return else_lvalue;
857 };
858
859 // build the if-else condition.
860 result_ = BuildIfElse(validAndMatched, then_lambda, else_lambda, dex.result_type());
861 if (arrow::is_binary_like(dex.result_type()->id())) {
862 ADD_VISITOR_TRACE("IfElse result length %T", result_->length());
863 }
864 ADD_VISITOR_TRACE("IfElse result value %T", result_->data());
865 }
866
867 // Boolean AND
868 // if any arg is valid and false,
869 // short-circuit and return FALSE (value=false, valid=true)
870 // else if all args are valid and true
871 // return TRUE (value=true, valid=true)
872 // else
873 // return NULL (value=true, valid=false)
874
875 void LLVMGenerator::Visitor::Visit(const BooleanAndDex& dex) {
876 ADD_VISITOR_TRACE("visit BooleanAndExpression");
877 llvm::IRBuilder<>* builder = ir_builder();
878 LLVMTypes* types = generator_->types();
879 llvm::LLVMContext* context = generator_->context();
880
881 // Create blocks for short-circuit.
882 llvm::BasicBlock* short_circuit_bb =
883 llvm::BasicBlock::Create(*context, "short_circuit", function_);
884 llvm::BasicBlock* non_short_circuit_bb =
885 llvm::BasicBlock::Create(*context, "non_short_circuit", function_);
886 llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
887
888 llvm::Value* all_exprs_valid = types->true_constant();
889 for (auto& pair : dex.args()) {
890 LValuePtr current = BuildValueAndValidity(*pair);
891
892 ADD_VISITOR_TRACE("BooleanAndExpression arg value %T", current->data());
893 ADD_VISITOR_TRACE("BooleanAndExpression arg validity %T", current->validity());
894
895 // short-circuit if valid and false
896 llvm::Value* is_false = builder->CreateNot(current->data());
897 llvm::Value* valid_and_false =
898 builder->CreateAnd(is_false, current->validity(), "valid_and_false");
899
900 llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
901 builder->CreateCondBr(valid_and_false, short_circuit_bb, else_bb);
902
903 // Emit the else block.
904 builder->SetInsertPoint(else_bb);
905 // remember if any nulls were encountered.
906 all_exprs_valid =
907 builder->CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd");
908 // continue to evaluate the next pair in list.
909 }
910 builder->CreateBr(non_short_circuit_bb);
911
912 // Short-circuit case (at least one of the expressions is valid and false).
913 // No need to set validity bit (valid by default).
914 builder->SetInsertPoint(short_circuit_bb);
915 ADD_VISITOR_TRACE("BooleanAndExpression result value false");
916 ADD_VISITOR_TRACE("BooleanAndExpression result validity true");
917 builder->CreateBr(merge_bb);
918
919 // non short-circuit case (All expressions are either true or null).
920 // result valid if all of the exprs are non-null.
921 builder->SetInsertPoint(non_short_circuit_bb);
922 ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid);
923 ADD_VISITOR_TRACE("BooleanAndExpression result value true");
924 ADD_VISITOR_TRACE("BooleanAndExpression result validity %T", all_exprs_valid);
925 builder->CreateBr(merge_bb);
926
927 builder->SetInsertPoint(merge_bb);
928 llvm::PHINode* result_value = builder->CreatePHI(types->i1_type(), 2, "res_value");
929 result_value->addIncoming(types->false_constant(), short_circuit_bb);
930 result_value->addIncoming(types->true_constant(), non_short_circuit_bb);
931 result_.reset(new LValue(result_value));
932 }
933
934 // Boolean OR
935 // if any arg is valid and true,
936 // short-circuit and return TRUE (value=true, valid=true)
937 // else if all args are valid and false
938 // return FALSE (value=false, valid=true)
939 // else
940 // return NULL (value=false, valid=false)
941
942 void LLVMGenerator::Visitor::Visit(const BooleanOrDex& dex) {
943 ADD_VISITOR_TRACE("visit BooleanOrExpression");
944 llvm::IRBuilder<>* builder = ir_builder();
945 LLVMTypes* types = generator_->types();
946 llvm::LLVMContext* context = generator_->context();
947
948 // Create blocks for short-circuit.
949 llvm::BasicBlock* short_circuit_bb =
950 llvm::BasicBlock::Create(*context, "short_circuit", function_);
951 llvm::BasicBlock* non_short_circuit_bb =
952 llvm::BasicBlock::Create(*context, "non_short_circuit", function_);
953 llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
954
955 llvm::Value* all_exprs_valid = types->true_constant();
956 for (auto& pair : dex.args()) {
957 LValuePtr current = BuildValueAndValidity(*pair);
958
959 ADD_VISITOR_TRACE("BooleanOrExpression arg value %T", current->data());
960 ADD_VISITOR_TRACE("BooleanOrExpression arg validity %T", current->validity());
961
962 // short-circuit if valid and true.
963 llvm::Value* valid_and_true =
964 builder->CreateAnd(current->data(), current->validity(), "valid_and_true");
965
966 llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
967 builder->CreateCondBr(valid_and_true, short_circuit_bb, else_bb);
968
969 // Emit the else block.
970 builder->SetInsertPoint(else_bb);
971 // remember if any nulls were encountered.
972 all_exprs_valid =
973 builder->CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd");
974 // continue to evaluate the next pair in list.
975 }
976 builder->CreateBr(non_short_circuit_bb);
977
978 // Short-circuit case (at least one of the expressions is valid and true).
979 // No need to set validity bit (valid by default).
980 builder->SetInsertPoint(short_circuit_bb);
981 ADD_VISITOR_TRACE("BooleanOrExpression result value true");
982 ADD_VISITOR_TRACE("BooleanOrExpression result validity true");
983 builder->CreateBr(merge_bb);
984
985 // non short-circuit case (All expressions are either false or null).
986 // result valid if all of the exprs are non-null.
987 builder->SetInsertPoint(non_short_circuit_bb);
988 ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid);
989 ADD_VISITOR_TRACE("BooleanOrExpression result value false");
990 ADD_VISITOR_TRACE("BooleanOrExpression result validity %T", all_exprs_valid);
991 builder->CreateBr(merge_bb);
992
993 builder->SetInsertPoint(merge_bb);
994 llvm::PHINode* result_value = builder->CreatePHI(types->i1_type(), 2, "res_value");
995 result_value->addIncoming(types->true_constant(), short_circuit_bb);
996 result_value->addIncoming(types->false_constant(), non_short_circuit_bb);
997 result_.reset(new LValue(result_value));
998 }
999
1000 template <typename Type>
1001 void LLVMGenerator::Visitor::VisitInExpression(const InExprDexBase<Type>& dex) {
1002 ADD_VISITOR_TRACE("visit In Expression");
1003 LLVMTypes* types = generator_->types();
1004 std::vector<llvm::Value*> params;
1005
1006 const InExprDex<Type>& dex_instance = dynamic_cast<const InExprDex<Type>&>(dex);
1007 /* add the holder at the beginning */
1008 llvm::Constant* ptr_int_cast =
1009 types->i64_constant((int64_t)(dex_instance.in_holder().get()));
1010 params.push_back(ptr_int_cast);
1011
1012 /* eval expr result */
1013 for (auto& pair : dex.args()) {
1014 DexPtr value_expr = pair->value_expr();
1015 value_expr->Accept(*this);
1016 LValue& result_ref = *result();
1017 params.push_back(result_ref.data());
1018
1019 /* length if the result is a string */
1020 if (result_ref.length() != nullptr) {
1021 params.push_back(result_ref.length());
1022 }
1023
1024 /* push the validity of eval expr result */
1025 llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs());
1026 params.push_back(validity_expr);
1027 }
1028
1029 llvm::Type* ret_type = types->IRType(arrow::Type::type::BOOL);
1030
1031 llvm::Value* value;
1032
1033 value = generator_->AddFunctionCall(dex.runtime_function(), ret_type, params);
1034
1035 result_.reset(new LValue(value));
1036 }
1037
1038 template <>
1039 void LLVMGenerator::Visitor::VisitInExpression<gandiva::DecimalScalar128>(
1040 const InExprDexBase<gandiva::DecimalScalar128>& dex) {
1041 ADD_VISITOR_TRACE("visit In Expression");
1042 LLVMTypes* types = generator_->types();
1043 std::vector<llvm::Value*> params;
1044 DecimalIR decimalIR(generator_->engine_.get());
1045
1046 const InExprDex<gandiva::DecimalScalar128>& dex_instance =
1047 dynamic_cast<const InExprDex<gandiva::DecimalScalar128>&>(dex);
1048 /* add the holder at the beginning */
1049 llvm::Constant* ptr_int_cast =
1050 types->i64_constant((int64_t)(dex_instance.in_holder().get()));
1051 params.push_back(ptr_int_cast);
1052
1053 /* eval expr result */
1054 for (auto& pair : dex.args()) {
1055 DexPtr value_expr = pair->value_expr();
1056 value_expr->Accept(*this);
1057 LValue& result_ref = *result();
1058 params.push_back(result_ref.data());
1059
1060 llvm::Constant* precision = types->i32_constant(dex.get_precision());
1061 llvm::Constant* scale = types->i32_constant(dex.get_scale());
1062 params.push_back(precision);
1063 params.push_back(scale);
1064
1065 /* push the validity of eval expr result */
1066 llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs());
1067 params.push_back(validity_expr);
1068 }
1069
1070 llvm::Type* ret_type = types->IRType(arrow::Type::type::BOOL);
1071
1072 llvm::Value* value;
1073
1074 value = decimalIR.CallDecimalFunction(dex.runtime_function(), ret_type, params);
1075
1076 result_.reset(new LValue(value));
1077 }
1078
1079 void LLVMGenerator::Visitor::Visit(const InExprDexBase<int32_t>& dex) {
1080 VisitInExpression<int32_t>(dex);
1081 }
1082
1083 void LLVMGenerator::Visitor::Visit(const InExprDexBase<int64_t>& dex) {
1084 VisitInExpression<int64_t>(dex);
1085 }
1086
1087 void LLVMGenerator::Visitor::Visit(const InExprDexBase<float>& dex) {
1088 VisitInExpression<float>(dex);
1089 }
1090 void LLVMGenerator::Visitor::Visit(const InExprDexBase<double>& dex) {
1091 VisitInExpression<double>(dex);
1092 }
1093
1094 void LLVMGenerator::Visitor::Visit(const InExprDexBase<gandiva::DecimalScalar128>& dex) {
1095 VisitInExpression<gandiva::DecimalScalar128>(dex);
1096 }
1097
1098 void LLVMGenerator::Visitor::Visit(const InExprDexBase<std::string>& dex) {
1099 VisitInExpression<std::string>(dex);
1100 }
1101
1102 LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition,
1103 std::function<LValuePtr()> then_func,
1104 std::function<LValuePtr()> else_func,
1105 DataTypePtr result_type) {
1106 llvm::IRBuilder<>* builder = ir_builder();
1107 llvm::LLVMContext* context = generator_->context();
1108 LLVMTypes* types = generator_->types();
1109
1110 // Create blocks for the then, else and merge cases.
1111 llvm::BasicBlock* then_bb = llvm::BasicBlock::Create(*context, "then", function_);
1112 llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
1113 llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
1114
1115 builder->CreateCondBr(condition, then_bb, else_bb);
1116
1117 // Emit the then block.
1118 builder->SetInsertPoint(then_bb);
1119 LValuePtr then_lvalue = then_func();
1120 builder->CreateBr(merge_bb);
1121
1122 // refresh then_bb for phi (could have changed due to code generation of then_vv).
1123 then_bb = builder->GetInsertBlock();
1124
1125 // Emit the else block.
1126 builder->SetInsertPoint(else_bb);
1127 LValuePtr else_lvalue = else_func();
1128 builder->CreateBr(merge_bb);
1129
1130 // refresh else_bb for phi (could have changed due to code generation of else_vv).
1131 else_bb = builder->GetInsertBlock();
1132
1133 // Emit the merge block.
1134 builder->SetInsertPoint(merge_bb);
1135 auto llvm_type = types->IRType(result_type->id());
1136 llvm::PHINode* result_value = builder->CreatePHI(llvm_type, 2, "res_value");
1137 result_value->addIncoming(then_lvalue->data(), then_bb);
1138 result_value->addIncoming(else_lvalue->data(), else_bb);
1139
1140 LValuePtr ret;
1141 switch (result_type->id()) {
1142 case arrow::Type::STRING:
1143 case arrow::Type::BINARY: {
1144 llvm::PHINode* result_length;
1145 result_length = builder->CreatePHI(types->i32_type(), 2, "res_length");
1146 result_length->addIncoming(then_lvalue->length(), then_bb);
1147 result_length->addIncoming(else_lvalue->length(), else_bb);
1148 ret = std::make_shared<LValue>(result_value, result_length);
1149 break;
1150 }
1151
1152 case arrow::Type::DECIMAL:
1153 ret = generator_->BuildDecimalLValue(result_value, result_type);
1154 break;
1155
1156 default:
1157 ret = std::make_shared<LValue>(result_value);
1158 break;
1159 }
1160 return ret;
1161 }
1162
1163 LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair& pair) {
1164 // generate code for value
1165 auto value_expr = pair.value_expr();
1166 value_expr->Accept(*this);
1167 auto value = result()->data();
1168 auto length = result()->length();
1169
1170 // generate code for validity
1171 auto validity = BuildCombinedValidity(pair.validity_exprs());
1172
1173 return std::make_shared<LValue>(value, length, validity);
1174 }
1175
1176 LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func,
1177 DataTypePtr arrow_return_type,
1178 std::vector<llvm::Value*>* params) {
1179 auto types = generator_->types();
1180 auto arrow_return_type_id = arrow_return_type->id();
1181 auto llvm_return_type = types->IRType(arrow_return_type_id);
1182 DecimalIR decimalIR(generator_->engine_.get());
1183
1184 if (arrow_return_type_id == arrow::Type::DECIMAL) {
1185 // For decimal fns, the output precision/scale are passed along as parameters.
1186 //
1187 // convert from this :
1188 // out = add_decimal(v1, p1, s1, v2, p2, s2)
1189 // to:
1190 // out = add_decimal(v1, p1, s1, v2, p2, s2, out_p, out_s)
1191
1192 // Append the out_precision and out_scale
1193 auto ret_lvalue = generator_->BuildDecimalLValue(nullptr, arrow_return_type);
1194 params->push_back(ret_lvalue->precision());
1195 params->push_back(ret_lvalue->scale());
1196
1197 // Make the function call
1198 auto out = decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params);
1199 ret_lvalue->set_data(out);
1200 return std::move(ret_lvalue);
1201 } else {
1202 bool isDecimalFunction = false;
1203 for (auto& arg : *params) {
1204 if (arg->getType() == types->i128_type()) {
1205 isDecimalFunction = true;
1206 }
1207 }
1208 // add extra arg for return length for variable len return types (allocated on stack).
1209 llvm::AllocaInst* result_len_ptr = nullptr;
1210 if (arrow::is_binary_like(arrow_return_type_id)) {
1211 result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0,
1212 "result_len", entry_block_);
1213 params->push_back(result_len_ptr);
1214 has_arena_allocs_ = true;
1215 }
1216
1217 // Make the function call
1218 llvm::IRBuilder<>* builder = ir_builder();
1219 auto value =
1220 isDecimalFunction
1221 ? decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params)
1222 : generator_->AddFunctionCall(func->pc_name(), llvm_return_type, *params);
1223 auto value_len =
1224 (result_len_ptr == nullptr) ? nullptr : CreateLoad(builder, result_len_ptr);
1225 return std::make_shared<LValue>(value, value_len);
1226 }
1227 }
1228
1229 std::vector<llvm::Value*> LLVMGenerator::Visitor::BuildParams(
1230 FunctionHolder* holder, const ValueValidityPairVector& args, bool with_validity,
1231 bool with_context) {
1232 LLVMTypes* types = generator_->types();
1233 std::vector<llvm::Value*> params;
1234
1235 // add context if required.
1236 if (with_context) {
1237 params.push_back(arg_context_ptr_);
1238 }
1239
1240 // if the function has holder, add the holder pointer.
1241 if (holder != nullptr) {
1242 auto ptr = types->i64_constant((int64_t)holder);
1243 params.push_back(ptr);
1244 }
1245
1246 // build the function params, along with the validities.
1247 for (auto& pair : args) {
1248 // build value.
1249 DexPtr value_expr = pair->value_expr();
1250 value_expr->Accept(*this);
1251 LValue& result_ref = *result();
1252
1253 // append all the parameters corresponding to this LValue.
1254 result_ref.AppendFunctionParams(&params);
1255
1256 // build validity.
1257 if (with_validity) {
1258 llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs());
1259 params.push_back(validity_expr);
1260 }
1261 }
1262
1263 return params;
1264 }
1265
1266 // Bitwise-AND of a vector of bits to get the combined validity.
1267 llvm::Value* LLVMGenerator::Visitor::BuildCombinedValidity(const DexVector& validities) {
1268 llvm::IRBuilder<>* builder = ir_builder();
1269 LLVMTypes* types = generator_->types();
1270
1271 llvm::Value* isValid = types->true_constant();
1272 for (auto& dex : validities) {
1273 dex->Accept(*this);
1274 isValid = builder->CreateAnd(isValid, result()->data(), "validityBitAnd");
1275 }
1276 ADD_VISITOR_TRACE("combined validity is %T", isValid);
1277 return isValid;
1278 }
1279
1280 llvm::Value* LLVMGenerator::Visitor::GetBufferReference(int idx, BufferType buffer_type,
1281 FieldPtr field) {
1282 llvm::IRBuilder<>* builder = ir_builder();
1283
1284 // Switch to the entry block to create a reference.
1285 llvm::BasicBlock* saved_block = builder->GetInsertBlock();
1286 builder->SetInsertPoint(entry_block_);
1287
1288 llvm::Value* slot_ref = nullptr;
1289 switch (buffer_type) {
1290 case kBufferTypeValidity:
1291 slot_ref = generator_->GetValidityReference(arg_addrs_, idx, field);
1292 break;
1293
1294 case kBufferTypeData:
1295 slot_ref = generator_->GetDataReference(arg_addrs_, idx, field);
1296 break;
1297
1298 case kBufferTypeOffsets:
1299 slot_ref = generator_->GetOffsetsReference(arg_addrs_, idx, field);
1300 break;
1301 }
1302
1303 // Revert to the saved block.
1304 builder->SetInsertPoint(saved_block);
1305 return slot_ref;
1306 }
1307
1308 llvm::Value* LLVMGenerator::Visitor::GetSliceOffset(int idx) {
1309 return slice_offsets_[idx];
1310 }
1311
1312 llvm::Value* LLVMGenerator::Visitor::GetLocalBitMapReference(int idx) {
1313 llvm::IRBuilder<>* builder = ir_builder();
1314
1315 // Switch to the entry block to create a reference.
1316 llvm::BasicBlock* saved_block = builder->GetInsertBlock();
1317 builder->SetInsertPoint(entry_block_);
1318
1319 llvm::Value* slot_ref = generator_->GetLocalBitMapReference(arg_local_bitmaps_, idx);
1320
1321 // Revert to the saved block.
1322 builder->SetInsertPoint(saved_block);
1323 return slot_ref;
1324 }
1325
1326 /// The local bitmap is pre-filled with 1s. Clear only if invalid.
1327 void LLVMGenerator::Visitor::ClearLocalBitMapIfNotValid(int local_bitmap_idx,
1328 llvm::Value* is_valid) {
1329 llvm::Value* slot_ref = GetLocalBitMapReference(local_bitmap_idx);
1330 generator_->ClearPackedBitValueIfFalse(slot_ref, loop_var_, is_valid);
1331 }
1332
1333 // Hooks for tracing/printfs.
1334 //
1335 // replace %T with the type-specific format specifier.
1336 // For some reason, float/double literals are getting lost when printing with the generic
1337 // printf. so, use a wrapper instead.
1338 std::string LLVMGenerator::ReplaceFormatInTrace(const std::string& in_msg,
1339 llvm::Value* value,
1340 std::string* print_fn) {
1341 std::string msg = in_msg;
1342 std::size_t pos = msg.find("%T");
1343 if (pos == std::string::npos) {
1344 DCHECK(0);
1345 return msg;
1346 }
1347
1348 llvm::Type* type = value->getType();
1349 const char* fmt = "";
1350 if (type->isIntegerTy(1) || type->isIntegerTy(8) || type->isIntegerTy(16) ||
1351 type->isIntegerTy(32)) {
1352 fmt = "%d";
1353 } else if (type->isIntegerTy(64)) {
1354 // bigint
1355 fmt = "%lld";
1356 } else if (type->isFloatTy()) {
1357 // float
1358 fmt = "%f";
1359 *print_fn = "print_float";
1360 } else if (type->isDoubleTy()) {
1361 // float
1362 fmt = "%lf";
1363 *print_fn = "print_double";
1364 } else if (type->isPointerTy()) {
1365 // string
1366 fmt = "%s";
1367 } else {
1368 DCHECK(0);
1369 }
1370 msg.replace(pos, 2, fmt);
1371 return msg;
1372 }
1373
1374 void LLVMGenerator::AddTrace(const std::string& msg, llvm::Value* value) {
1375 if (!enable_ir_traces_) {
1376 return;
1377 }
1378
1379 std::string dmsg = "IR_TRACE:: " + msg + "\n";
1380 std::string print_fn_name = "printf";
1381 if (value != nullptr) {
1382 dmsg = ReplaceFormatInTrace(dmsg, value, &print_fn_name);
1383 }
1384 trace_strings_.push_back(dmsg);
1385
1386 // cast this to an llvm pointer.
1387 const char* str = trace_strings_.back().c_str();
1388 llvm::Constant* str_int_cast = types()->i64_constant((int64_t)str);
1389 llvm::Constant* str_ptr_cast =
1390 llvm::ConstantExpr::getIntToPtr(str_int_cast, types()->i8_ptr_type());
1391
1392 std::vector<llvm::Value*> args;
1393 args.push_back(str_ptr_cast);
1394 if (value != nullptr) {
1395 args.push_back(value);
1396 }
1397 AddFunctionCall(print_fn_name, types()->i32_type(), args);
1398 }
1399
1400 } // namespace gandiva