]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one or more | |
3 | * contributor license agreements. See the NOTICE file distributed with | |
4 | * this work for additional information regarding copyright ownership. | |
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 | |
6 | * (the "License"); you may not use this file except in compliance with | |
7 | * the License. You may obtain a copy of the License at | |
8 | * | |
9 | * http://www.apache.org/licenses/LICENSE-2.0 | |
10 | * | |
11 | * Unless required by applicable law or agreed to in writing, software | |
12 | * distributed under the License is distributed on an "AS IS" BASIS, | |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
14 | * See the License for the specific language governing permissions and | |
15 | * limitations under the License. | |
16 | */ | |
17 | ||
18 | package org.apache.arrow.vector.dictionary; | |
19 | ||
20 | import org.apache.arrow.memory.BufferAllocator; | |
21 | import org.apache.arrow.memory.util.hash.ArrowBufHasher; | |
22 | import org.apache.arrow.memory.util.hash.SimpleHasher; | |
23 | import org.apache.arrow.util.Preconditions; | |
24 | import org.apache.arrow.vector.BaseIntVector; | |
25 | import org.apache.arrow.vector.FieldVector; | |
26 | import org.apache.arrow.vector.ValueVector; | |
27 | import org.apache.arrow.vector.types.pojo.ArrowType; | |
28 | import org.apache.arrow.vector.types.pojo.Field; | |
29 | import org.apache.arrow.vector.types.pojo.FieldType; | |
30 | import org.apache.arrow.vector.util.TransferPair; | |
31 | ||
32 | /** | |
33 | * Encoder/decoder for Dictionary encoded {@link ValueVector}. Dictionary encoding produces an | |
34 | * integer {@link ValueVector}. Each entry in the Vector is index into the dictionary which can hold | |
35 | * values of any type. | |
36 | */ | |
37 | public class DictionaryEncoder { | |
38 | ||
39 | private final DictionaryHashTable hashTable; | |
40 | private final Dictionary dictionary; | |
41 | private final BufferAllocator allocator; | |
42 | ||
43 | /** | |
44 | * Construct an instance. | |
45 | */ | |
46 | public DictionaryEncoder(Dictionary dictionary, BufferAllocator allocator) { | |
47 | this (dictionary, allocator, SimpleHasher.INSTANCE); | |
48 | } | |
49 | ||
50 | /** | |
51 | * Construct an instance. | |
52 | */ | |
53 | public DictionaryEncoder(Dictionary dictionary, BufferAllocator allocator, ArrowBufHasher hasher) { | |
54 | this.dictionary = dictionary; | |
55 | this.allocator = allocator; | |
56 | hashTable = new DictionaryHashTable(dictionary.getVector(), hasher); | |
57 | } | |
58 | ||
59 | /** | |
60 | * Dictionary encodes a vector with a provided dictionary. The dictionary must contain all values in the vector. | |
61 | * | |
62 | * @param vector vector to encode | |
63 | * @param dictionary dictionary used for encoding | |
64 | * @return dictionary encoded vector | |
65 | */ | |
66 | public static ValueVector encode(ValueVector vector, Dictionary dictionary) { | |
67 | DictionaryEncoder encoder = new DictionaryEncoder(dictionary, vector.getAllocator()); | |
68 | return encoder.encode(vector); | |
69 | } | |
70 | ||
71 | /** | |
72 | * Decodes a dictionary encoded array using the provided dictionary. | |
73 | * | |
74 | * @param indices dictionary encoded values, must be int type | |
75 | * @param dictionary dictionary used to decode the values | |
76 | * @return vector with values restored from dictionary | |
77 | */ | |
78 | public static ValueVector decode(ValueVector indices, Dictionary dictionary) { | |
79 | DictionaryEncoder encoder = new DictionaryEncoder(dictionary, indices.getAllocator()); | |
80 | return encoder.decode(indices); | |
81 | } | |
82 | ||
83 | /** | |
84 | * Get the indexType according to the dictionary vector valueCount. | |
85 | * @param valueCount dictionary vector valueCount. | |
86 | * @return index type. | |
87 | */ | |
88 | public static ArrowType.Int getIndexType(int valueCount) { | |
89 | Preconditions.checkArgument(valueCount >= 0); | |
90 | if (valueCount <= Byte.MAX_VALUE) { | |
91 | return new ArrowType.Int(8, true); | |
92 | } else if (valueCount <= Character.MAX_VALUE) { | |
93 | return new ArrowType.Int(16, true); | |
94 | } else if (valueCount <= Integer.MAX_VALUE) { | |
95 | return new ArrowType.Int(32, true); | |
96 | } else { | |
97 | return new ArrowType.Int(64, true); | |
98 | } | |
99 | } | |
100 | ||
101 | /** | |
102 | * Populates indices between start and end with the encoded values of vector. | |
103 | * @param vector the vector to encode | |
104 | * @param indices the index vector | |
105 | * @param encoding the hash table for encoding | |
106 | * @param start the start index | |
107 | * @param end the end index | |
108 | */ | |
109 | static void buildIndexVector( | |
110 | ValueVector vector, | |
111 | BaseIntVector indices, | |
112 | DictionaryHashTable encoding, | |
113 | int start, | |
114 | int end) { | |
115 | ||
116 | for (int i = start; i < end; i++) { | |
117 | if (!vector.isNull(i)) { | |
118 | // if it's null leave it null | |
119 | // note: this may fail if value was not included in the dictionary | |
120 | int encoded = encoding.getIndex(i, vector); | |
121 | if (encoded == -1) { | |
122 | throw new IllegalArgumentException("Dictionary encoding not defined for value:" + vector.getObject(i)); | |
123 | } | |
124 | indices.setWithPossibleTruncate(i, encoded); | |
125 | } | |
126 | } | |
127 | } | |
128 | ||
129 | /** | |
130 | * Retrieve values to target vector from index vector. | |
131 | * @param indices the index vector | |
132 | * @param transfer the {@link TransferPair} to copy dictionary data into target vector. | |
133 | * @param dictionaryCount the value count of dictionary vector. | |
134 | * @param start the start index | |
135 | * @param end the end index | |
136 | */ | |
137 | static void retrieveIndexVector( | |
138 | BaseIntVector indices, | |
139 | TransferPair transfer, | |
140 | int dictionaryCount, | |
141 | int start, | |
142 | int end) { | |
143 | for (int i = start; i < end; i++) { | |
144 | if (!indices.isNull(i)) { | |
145 | int indexAsInt = (int) indices.getValueAsLong(i); | |
146 | if (indexAsInt > dictionaryCount) { | |
147 | throw new IllegalArgumentException("Provided dictionary does not contain value for index " + indexAsInt); | |
148 | } | |
149 | transfer.copyValueSafe(indexAsInt, i); | |
150 | } | |
151 | } | |
152 | } | |
153 | ||
154 | /** | |
155 | * Encodes a vector with the built hash table in this encoder. | |
156 | */ | |
157 | public ValueVector encode(ValueVector vector) { | |
158 | ||
159 | Field valueField = vector.getField(); | |
160 | FieldType indexFieldType = new FieldType(valueField.isNullable(), dictionary.getEncoding().getIndexType(), | |
161 | dictionary.getEncoding(), valueField.getMetadata()); | |
162 | Field indexField = new Field(valueField.getName(), indexFieldType, null); | |
163 | ||
164 | // vector to hold our indices (dictionary encoded values) | |
165 | FieldVector createdVector = indexField.createVector(allocator); | |
166 | if (! (createdVector instanceof BaseIntVector)) { | |
167 | throw new IllegalArgumentException("Dictionary encoding does not have a valid int type:" + | |
168 | createdVector.getClass()); | |
169 | } | |
170 | ||
171 | BaseIntVector indices = (BaseIntVector) createdVector; | |
172 | indices.allocateNew(); | |
173 | ||
174 | buildIndexVector(vector, indices, hashTable, 0, vector.getValueCount()); | |
175 | indices.setValueCount(vector.getValueCount()); | |
176 | return indices; | |
177 | } | |
178 | ||
179 | /** | |
180 | * Decodes a vector with the built hash table in this encoder. | |
181 | */ | |
182 | public ValueVector decode(ValueVector indices) { | |
183 | int count = indices.getValueCount(); | |
184 | ValueVector dictionaryVector = dictionary.getVector(); | |
185 | int dictionaryCount = dictionaryVector.getValueCount(); | |
186 | // copy the dictionary values into the decoded vector | |
187 | TransferPair transfer = dictionaryVector.getTransferPair(allocator); | |
188 | transfer.getTo().allocateNewSafe(); | |
189 | ||
190 | BaseIntVector baseIntVector = (BaseIntVector) indices; | |
191 | retrieveIndexVector(baseIntVector, transfer, dictionaryCount, 0, count); | |
192 | ValueVector decoded = transfer.getTo(); | |
193 | decoded.setValueCount(count); | |
194 | return decoded; | |
195 | } | |
196 | } |