]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, software | |
12 | // distributed under the License is distributed on an "AS IS" BASIS, | |
13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
14 | // See the License for the specific language governing permissions and | |
15 | // limitations under the License. | |
16 | ||
17 | package scalar | |
18 | ||
19 | import ( | |
20 | "bytes" | |
21 | "fmt" | |
22 | ||
23 | "github.com/apache/arrow/go/v6/arrow" | |
24 | "github.com/apache/arrow/go/v6/arrow/array" | |
25 | "github.com/apache/arrow/go/v6/arrow/internal/debug" | |
26 | "github.com/apache/arrow/go/v6/arrow/memory" | |
27 | "golang.org/x/xerrors" | |
28 | ) | |
29 | ||
30 | type ListScalar interface { | |
31 | Scalar | |
32 | GetList() array.Interface | |
33 | Release() | |
34 | Retain() | |
35 | } | |
36 | ||
// List is a scalar of a list-like type (list, fixed-size list or map, per the
// types handled in Validate). Value holds the list's elements as an array;
// whether it may be nil for a null scalar is governed by validateOptional
// (defined elsewhere in the package).
type List struct {
	scalar
	Value array.Interface
}
41 | ||
42 | func (l *List) Release() { l.Value.Release() } | |
43 | func (l *List) Retain() { l.Value.Retain() } | |
44 | func (l *List) value() interface{} { return l.Value } | |
45 | func (l *List) GetList() array.Interface { return l.Value } | |
46 | func (l *List) equals(rhs Scalar) bool { | |
47 | return array.ArrayEqual(l.Value, rhs.(ListScalar).GetList()) | |
48 | } | |
49 | func (l *List) Validate() (err error) { | |
50 | if err = l.scalar.Validate(); err != nil { | |
51 | return | |
52 | } | |
53 | if err = validateOptional(&l.scalar, l.Value, "value"); err != nil { | |
54 | return | |
55 | } | |
56 | ||
57 | if !l.Valid { | |
58 | return | |
59 | } | |
60 | ||
61 | var ( | |
62 | valueType arrow.DataType | |
63 | ) | |
64 | ||
65 | switch dt := l.Type.(type) { | |
66 | case *arrow.ListType: | |
67 | valueType = dt.Elem() | |
68 | case *arrow.FixedSizeListType: | |
69 | valueType = dt.Elem() | |
70 | case *arrow.MapType: | |
71 | valueType = dt.ValueType() | |
72 | } | |
73 | listType := l.Type | |
74 | ||
75 | if !arrow.TypeEqual(l.Value.DataType(), valueType) { | |
76 | err = xerrors.Errorf("%s scalar should have a value of type %s, got %s", | |
77 | listType, valueType, l.Value.DataType()) | |
78 | } | |
79 | return | |
80 | } | |
81 | ||
82 | func (l *List) ValidateFull() error { return l.Validate() } | |
// CastTo converts the list scalar to the type to.
//
// A null scalar becomes a null scalar of the target type. A cast to the
// scalar's own type returns the receiver unchanged. A non-null scalar can
// otherwise only be cast to a string, which is the fmt.Print rendering of
// the element array. Every other target yields an error.
func (l *List) CastTo(to arrow.DataType) (Scalar, error) {
	if !l.Valid {
		return MakeNullScalar(to), nil
	}
	if arrow.TypeEqual(l.Type, to) {
		return l, nil
	}
	if to.ID() != arrow.STRING {
		return nil, xerrors.Errorf("cannot convert non-nil list scalar to type %s", to)
	}

	var rendered bytes.Buffer
	fmt.Fprint(&rendered, l.Value)
	strBuf := memory.NewBufferBytes(rendered.Bytes())
	defer strBuf.Release()
	return NewStringScalarFromBuffer(strBuf), nil
}
102 | ||
103 | func (l *List) String() string { | |
104 | if !l.Valid { | |
105 | return "null" | |
106 | } | |
107 | val, err := l.CastTo(arrow.BinaryTypes.String) | |
108 | if err != nil { | |
109 | return "..." | |
110 | } | |
111 | return string(val.(*String).Value.Bytes()) | |
112 | } | |
113 | ||
114 | func NewListScalar(val array.Interface) *List { | |
115 | return &List{scalar{arrow.ListOf(val.DataType()), true}, array.MakeFromData(val.Data())} | |
116 | } | |
117 | ||
118 | func makeMapType(typ *arrow.StructType) *arrow.MapType { | |
119 | debug.Assert(len(typ.Fields()) == 2, "must pass struct with only 2 fields for MapScalar") | |
120 | return arrow.MapOf(typ.Field(0).Type, typ.Field(1).Type) | |
121 | } | |
122 | ||
// Map is a scalar of a map type. It reuses List's representation: Value is
// the array of map entries, which NewMapScalar expects to be a two-field
// (key, item) struct array.
type Map struct {
	*List
}
126 | ||
127 | func NewMapScalar(val array.Interface) *Map { | |
128 | return &Map{&List{scalar{makeMapType(val.DataType().(*arrow.StructType)), true}, array.MakeFromData(val.Data())}} | |
129 | } | |
130 | ||
// FixedSizeList is a scalar of a fixed-size list type. It embeds *List, and
// its Validate additionally requires Value's length to equal the list size
// declared by its type.
type FixedSizeList struct {
	*List
}
134 | ||
135 | func (f *FixedSizeList) Validate() (err error) { | |
136 | if err = f.List.Validate(); err != nil { | |
137 | return | |
138 | } | |
139 | ||
140 | if f.Valid { | |
141 | listType := f.Type.(*arrow.FixedSizeListType) | |
142 | if f.Value.Len() != int(listType.Len()) { | |
143 | return xerrors.Errorf("%s scalar should have a child value of length %d, got %d", | |
144 | f.Type, listType.Len(), f.Value.Len()) | |
145 | } | |
146 | } | |
147 | return | |
148 | } | |
149 | ||
150 | func (f *FixedSizeList) ValidateFull() error { return f.Validate() } | |
151 | ||
152 | func NewFixedSizeListScalar(val array.Interface) *FixedSizeList { | |
153 | return NewFixedSizeListScalarWithType(val, arrow.FixedSizeListOf(int32(val.Len()), val.DataType())) | |
154 | } | |
155 | ||
156 | func NewFixedSizeListScalarWithType(val array.Interface, typ arrow.DataType) *FixedSizeList { | |
157 | debug.Assert(val.Len() == int(typ.(*arrow.FixedSizeListType).Len()), "length of value for fixed size list scalar must match type") | |
158 | return &FixedSizeList{&List{scalar{typ, true}, array.MakeFromData(val.Data())}} | |
159 | } | |
160 | ||
// Vector is an ordered collection of scalars. Struct uses it to hold one
// child scalar per struct field.
type Vector []Scalar
162 | ||
// Struct is a scalar of a struct type. When valid, Value holds one child
// scalar per field of the struct type, in field order. When null, Value is
// expected to be empty (both are enforced by Validate).
type Struct struct {
	scalar
	Value Vector
}
167 | ||
168 | func (s *Struct) Field(name string) (Scalar, error) { | |
169 | idx, ok := s.Type.(*arrow.StructType).FieldIdx(name) | |
170 | if !ok { | |
171 | return nil, xerrors.Errorf("no field named %s found in struct scalar %s", name, s.Type) | |
172 | } | |
173 | ||
174 | return s.Value[idx], nil | |
175 | } | |
176 | ||
177 | func (s *Struct) value() interface{} { return s.Value } | |
178 | ||
179 | func (s *Struct) String() string { | |
180 | if !s.Valid { | |
181 | return "null" | |
182 | } | |
183 | val, err := s.CastTo(arrow.BinaryTypes.String) | |
184 | if err != nil { | |
185 | return "..." | |
186 | } | |
187 | return string(val.(*String).Value.Bytes()) | |
188 | } | |
189 | ||
// CastTo converts the struct scalar to the type to.
//
// A null scalar becomes a null scalar of the target type. A non-null scalar
// can only be cast to a string, rendered as
// "{name:type = value, name:type = value, ...}" in field order. Any other
// target yields an error.
func (s *Struct) CastTo(to arrow.DataType) (Scalar, error) {
	switch {
	case !s.Valid:
		return MakeNullScalar(to), nil
	case to.ID() != arrow.STRING:
		return nil, xerrors.Errorf("cannot cast non-null struct scalar to type %s", to)
	}

	var out bytes.Buffer
	structType := s.Type.(*arrow.StructType)
	out.WriteByte('{')
	for idx, child := range s.Value {
		if idx != 0 {
			out.WriteString(", ")
		}
		field := structType.Field(idx)
		fmt.Fprintf(&out, "%s:%s = %s", field.Name, field.Type, child.String())
	}
	out.WriteByte('}')

	strBuf := memory.NewBufferBytes(out.Bytes())
	defer strBuf.Release()
	return NewStringScalarFromBuffer(strBuf), nil
}
213 | ||
214 | func (s *Struct) equals(rhs Scalar) bool { | |
215 | right := rhs.(*Struct) | |
216 | if len(s.Value) != len(right.Value) { | |
217 | return false | |
218 | } | |
219 | ||
220 | for i := range s.Value { | |
221 | if !Equals(s.Value[i], right.Value[i]) { | |
222 | return false | |
223 | } | |
224 | } | |
225 | return true | |
226 | } | |
227 | ||
228 | func (s *Struct) Validate() (err error) { | |
229 | if err = s.scalar.Validate(); err != nil { | |
230 | return | |
231 | } | |
232 | ||
233 | if !s.Valid { | |
234 | if len(s.Value) != 0 { | |
235 | err = xerrors.Errorf("%s scalar is marked null but has child values", s.Type) | |
236 | } | |
237 | return | |
238 | } | |
239 | ||
240 | st := s.Type.(*arrow.StructType) | |
241 | num := len(st.Fields()) | |
242 | if len(s.Value) != num { | |
243 | return xerrors.Errorf("non-null %s scalar should have %d child values, got %d", s.Type, num, len(s.Value)) | |
244 | } | |
245 | ||
246 | for i, f := range st.Fields() { | |
247 | if s.Value[i] == nil { | |
248 | return xerrors.Errorf("non-null %s scalar has missing child value at index %d", s.Type, i) | |
249 | } | |
250 | ||
251 | err = s.Value[i].Validate() | |
252 | if err != nil { | |
253 | return xerrors.Errorf("%s scalar fails validation for child at index %d: %w", s.Type, i, err) | |
254 | } | |
255 | ||
256 | if !arrow.TypeEqual(s.Value[i].DataType(), f.Type) { | |
257 | return xerrors.Errorf("%s scalar should have a child value of type %s at index %d, got %s", s.Type, f.Type, i, s.Value[i].DataType()) | |
258 | } | |
259 | } | |
260 | return | |
261 | } | |
262 | ||
263 | func (s *Struct) ValidateFull() (err error) { | |
264 | if err = s.scalar.ValidateFull(); err != nil { | |
265 | return | |
266 | } | |
267 | ||
268 | if !s.Valid { | |
269 | if len(s.Value) != 0 { | |
270 | err = xerrors.Errorf("%s scalar is marked null but has child values", s.Type) | |
271 | } | |
272 | return | |
273 | } | |
274 | ||
275 | st := s.Type.(*arrow.StructType) | |
276 | num := len(st.Fields()) | |
277 | if len(s.Value) != num { | |
278 | return xerrors.Errorf("non-null %s scalar should have %d child values, got %d", s.Type, num, len(s.Value)) | |
279 | } | |
280 | ||
281 | for i, f := range st.Fields() { | |
282 | if s.Value[i] == nil { | |
283 | return xerrors.Errorf("non-null %s scalar has missing child value at index %d", s.Type, i) | |
284 | } | |
285 | ||
286 | err = s.Value[i].ValidateFull() | |
287 | if err != nil { | |
288 | return xerrors.Errorf("%s scalar fails validation for child at index %d: %w", s.Type, i, err) | |
289 | } | |
290 | ||
291 | if !arrow.TypeEqual(s.Value[i].DataType(), f.Type) { | |
292 | return xerrors.Errorf("%s scalar should have a child value of type %s at index %d, got %s", s.Type, f.Type, i, s.Value[i].DataType()) | |
293 | } | |
294 | } | |
295 | return | |
296 | } | |
297 | ||
298 | func NewStructScalar(val []Scalar, typ arrow.DataType) *Struct { | |
299 | return &Struct{scalar{typ, true}, val} | |
300 | } | |
301 | ||
302 | func NewStructScalarWithNames(val []Scalar, names []string) (*Struct, error) { | |
303 | if len(val) != len(names) { | |
304 | return nil, xerrors.New("mismatching number of field names and child scalars") | |
305 | } | |
306 | ||
307 | fields := make([]arrow.Field, len(names)) | |
308 | for i, n := range names { | |
309 | fields[i] = arrow.Field{Name: n, Type: val[i].DataType(), Nullable: true} | |
310 | } | |
311 | return NewStructScalar(val, arrow.StructOf(fields...)), nil | |
312 | } |