]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/go/arrow/tensor/tensor.go
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / go / arrow / tensor / tensor.go
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 // Package tensor provides types that implement n-dimensional arrays.
18 package tensor
19
20 import (
21 "fmt"
22 "sync/atomic"
23
24 "github.com/apache/arrow/go/v6/arrow"
25 "github.com/apache/arrow/go/v6/arrow/array"
26 "github.com/apache/arrow/go/v6/arrow/internal/debug"
27 )
28
29 // Interface represents an n-dimensional array of numerical data.
30 type Interface interface {
31 // Retain increases the reference count by 1.
32 // Retain may be called simultaneously from multiple goroutines.
33 Retain()
34
35 // Release decreases the reference count by 1.
36 // Release may be called simultaneously from multiple goroutines.
37 // When the reference count goes to zero, the memory is freed.
38 Release()
39
40 // Len returns the number of elements in the tensor.
41 Len() int
42
43 // Shape returns the size - in each dimension - of the tensor.
44 Shape() []int64
45
46 // Strides returns the number of bytes to step in each dimension when traversing the tensor.
47 Strides() []int64
48
49 // NumDims returns the number of dimensions of the tensor.
50 NumDims() int
51
52 // DimName returns the name of the i-th dimension.
53 DimName(i int) string
54
55 // DimNames returns the names for all dimensions
56 DimNames() []string
57
58 DataType() arrow.DataType
59 Data() *array.Data
60
61 // IsMutable returns whether the underlying data buffer is mutable.
62 IsMutable() bool
63 IsContiguous() bool
64 IsRowMajor() bool
65 IsColMajor() bool
66 }
67
68 type tensorBase struct {
69 refCount int64
70 dtype arrow.DataType
71 bw int64 // bytes width
72 data *array.Data
73 shape []int64
74 strides []int64
75 names []string
76 }
77
78 // Retain increases the reference count by 1.
79 // Retain may be called simultaneously from multiple goroutines.
80 func (tb *tensorBase) Retain() {
81 atomic.AddInt64(&tb.refCount, 1)
82 }
83
84 // Release decreases the reference count by 1.
85 // Release may be called simultaneously from multiple goroutines.
86 // When the reference count goes to zero, the memory is freed.
87 func (tb *tensorBase) Release() {
88 debug.Assert(atomic.LoadInt64(&tb.refCount) > 0, "too many releases")
89
90 if atomic.AddInt64(&tb.refCount, -1) == 0 {
91 tb.data.Release()
92 tb.data = nil
93 }
94 }
95
96 func (tb *tensorBase) Len() int {
97 o := int64(1)
98 for _, v := range tb.shape {
99 o *= v
100 }
101 return int(o)
102 }
103
104 func (tb *tensorBase) Shape() []int64 { return tb.shape }
105 func (tb *tensorBase) Strides() []int64 { return tb.strides }
106 func (tb *tensorBase) NumDims() int { return len(tb.shape) }
107 func (tb *tensorBase) DimName(i int) string { return tb.names[i] }
108 func (tb *tensorBase) DataType() arrow.DataType { return tb.dtype }
109 func (tb *tensorBase) Data() *array.Data { return tb.data }
110 func (tb *tensorBase) DimNames() []string { return tb.names }
111
112 // IsMutable returns whether the underlying data buffer is mutable.
113 func (tb *tensorBase) IsMutable() bool { return false } // FIXME(sbinet): implement it at the array.Data level
114
115 func (tb *tensorBase) IsContiguous() bool {
116 return tb.IsRowMajor() || tb.IsColMajor()
117 }
118
119 func (tb *tensorBase) IsRowMajor() bool {
120 strides := rowMajorStrides(tb.dtype, tb.shape)
121 return equalInt64s(strides, tb.strides)
122 }
123
124 func (tb *tensorBase) IsColMajor() bool {
125 strides := colMajorStrides(tb.dtype, tb.shape)
126 return equalInt64s(strides, tb.strides)
127 }
128
129 func (tb *tensorBase) offset(index []int64) int64 {
130 var offset int64
131 for i, v := range index {
132 offset += v * tb.strides[i]
133 }
134 return offset / tb.bw
135 }
136
137 // New returns a new n-dim array from the provided backing data and the shape and strides.
138 // If strides is nil, row-major strides will be inferred.
139 // If names is nil, a slice of empty strings will be created.
140 //
141 // New panics if the backing data is not a numerical type.
142 func New(data *array.Data, shape, strides []int64, names []string) Interface {
143 dt := data.DataType()
144 switch dt.ID() {
145 case arrow.INT8:
146 return NewInt8(data, shape, strides, names)
147 case arrow.INT16:
148 return NewInt16(data, shape, strides, names)
149 case arrow.INT32:
150 return NewInt32(data, shape, strides, names)
151 case arrow.INT64:
152 return NewInt64(data, shape, strides, names)
153 case arrow.UINT8:
154 return NewUint8(data, shape, strides, names)
155 case arrow.UINT16:
156 return NewUint16(data, shape, strides, names)
157 case arrow.UINT32:
158 return NewUint32(data, shape, strides, names)
159 case arrow.UINT64:
160 return NewUint64(data, shape, strides, names)
161 case arrow.FLOAT32:
162 return NewFloat32(data, shape, strides, names)
163 case arrow.FLOAT64:
164 return NewFloat64(data, shape, strides, names)
165 case arrow.DATE32:
166 return NewDate32(data, shape, strides, names)
167 case arrow.DATE64:
168 return NewDate64(data, shape, strides, names)
169 default:
170 panic(fmt.Errorf("arrow/tensor: invalid data type %s", dt.Name()))
171 }
172 }
173
174 func newTensor(dtype arrow.DataType, data *array.Data, shape, strides []int64, names []string) *tensorBase {
175 tb := tensorBase{
176 refCount: 1,
177 dtype: dtype,
178 bw: int64(dtype.(arrow.FixedWidthDataType).BitWidth()) / 8,
179 data: data,
180 shape: shape,
181 strides: strides,
182 names: names,
183 }
184 tb.data.Retain()
185
186 if len(tb.shape) > 0 && len(tb.strides) == 0 {
187 tb.strides = rowMajorStrides(dtype, shape)
188 }
189 return &tb
190 }
191
192 func rowMajorStrides(dtype arrow.DataType, shape []int64) []int64 {
193 dt := dtype.(arrow.FixedWidthDataType)
194 rem := int64(dt.BitWidth() / 8)
195 for _, v := range shape {
196 rem *= v
197 }
198
199 if rem == 0 {
200 strides := make([]int64, len(shape))
201 rem := int64(dt.BitWidth() / 8)
202 for i := range strides {
203 strides[i] = rem
204 }
205 return strides
206 }
207
208 var strides []int64
209 for _, v := range shape {
210 rem /= v
211 strides = append(strides, rem)
212 }
213 return strides
214 }
215
216 func colMajorStrides(dtype arrow.DataType, shape []int64) []int64 {
217 dt := dtype.(arrow.FixedWidthDataType)
218 total := int64(dt.BitWidth() / 8)
219 for _, v := range shape {
220 if v == 0 {
221 strides := make([]int64, len(shape))
222 for i := range strides {
223 strides[i] = total
224 }
225 return strides
226 }
227 }
228
229 var strides []int64
230 for _, v := range shape {
231 strides = append(strides, total)
232 total *= v
233 }
234 return strides
235 }
236
// equalInt64s reports whether a and b have the same length and hold the
// same values in the same order. Nil and empty slices compare equal.
func equalInt64s(a, b []int64) bool {
	if len(a) != len(b) {
		return false
	}
	for i, v := range a {
		if v != b[i] {
			return false
		}
	}
	return true
}