1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 // Package tensor provides types that implement n-dimensional arrays.
24 "github.com/apache/arrow/go/v6/arrow"
25 "github.com/apache/arrow/go/v6/arrow/array"
26 "github.com/apache/arrow/go/v6/arrow/internal/debug"
29 // Interface represents an n-dimensional array of numerical data.
30 type Interface interface {
31 // Retain increases the reference count by 1.
32 // Retain may be called simultaneously from multiple goroutines.
35 // Release decreases the reference count by 1.
36 // Release may be called simultaneously from multiple goroutines.
37 // When the reference count goes to zero, the memory is freed.
40 // Len returns the number of elements in the tensor.
	// Shape returns the size of the tensor in each dimension.
46 // Strides returns the number of bytes to step in each dimension when traversing the tensor.
49 // NumDims returns the number of dimensions of the tensor.
52 // DimName returns the name of the i-th dimension.
	// DimNames returns the names of all dimensions.
58 DataType() arrow.DataType
61 // IsMutable returns whether the underlying data buffer is mutable.
68 type tensorBase struct {
71 bw int64 // bytes width
78 // Retain increases the reference count by 1.
79 // Retain may be called simultaneously from multiple goroutines.
80 func (tb *tensorBase) Retain() {
81 atomic.AddInt64(&tb.refCount, 1)
84 // Release decreases the reference count by 1.
85 // Release may be called simultaneously from multiple goroutines.
86 // When the reference count goes to zero, the memory is freed.
87 func (tb *tensorBase) Release() {
88 debug.Assert(atomic.LoadInt64(&tb.refCount) > 0, "too many releases")
90 if atomic.AddInt64(&tb.refCount, -1) == 0 {
96 func (tb *tensorBase) Len() int {
98 for _, v := range tb.shape {
104 func (tb *tensorBase) Shape() []int64 { return tb.shape }
105 func (tb *tensorBase) Strides() []int64 { return tb.strides }
106 func (tb *tensorBase) NumDims() int { return len(tb.shape) }
107 func (tb *tensorBase) DimName(i int) string { return tb.names[i] }
108 func (tb *tensorBase) DataType() arrow.DataType { return tb.dtype }
109 func (tb *tensorBase) Data() *array.Data { return tb.data }
110 func (tb *tensorBase) DimNames() []string { return tb.names }
112 // IsMutable returns whether the underlying data buffer is mutable.
113 func (tb *tensorBase) IsMutable() bool { return false } // FIXME(sbinet): implement it at the array.Data level
115 func (tb *tensorBase) IsContiguous() bool {
116 return tb.IsRowMajor() || tb.IsColMajor()
119 func (tb *tensorBase) IsRowMajor() bool {
120 strides := rowMajorStrides(tb.dtype, tb.shape)
121 return equalInt64s(strides, tb.strides)
124 func (tb *tensorBase) IsColMajor() bool {
125 strides := colMajorStrides(tb.dtype, tb.shape)
126 return equalInt64s(strides, tb.strides)
129 func (tb *tensorBase) offset(index []int64) int64 {
131 for i, v := range index {
132 offset += v * tb.strides[i]
134 return offset / tb.bw
137 // New returns a new n-dim array from the provided backing data and the shape and strides.
138 // If strides is nil, row-major strides will be inferred.
139 // If names is nil, a slice of empty strings will be created.
141 // New panics if the backing data is not a numerical type.
142 func New(data *array.Data, shape, strides []int64, names []string) Interface {
143 dt := data.DataType()
146 return NewInt8(data, shape, strides, names)
148 return NewInt16(data, shape, strides, names)
150 return NewInt32(data, shape, strides, names)
152 return NewInt64(data, shape, strides, names)
154 return NewUint8(data, shape, strides, names)
156 return NewUint16(data, shape, strides, names)
158 return NewUint32(data, shape, strides, names)
160 return NewUint64(data, shape, strides, names)
162 return NewFloat32(data, shape, strides, names)
164 return NewFloat64(data, shape, strides, names)
166 return NewDate32(data, shape, strides, names)
168 return NewDate64(data, shape, strides, names)
170 panic(fmt.Errorf("arrow/tensor: invalid data type %s", dt.Name()))
174 func newTensor(dtype arrow.DataType, data *array.Data, shape, strides []int64, names []string) *tensorBase {
178 bw: int64(dtype.(arrow.FixedWidthDataType).BitWidth()) / 8,
186 if len(tb.shape) > 0 && len(tb.strides) == 0 {
187 tb.strides = rowMajorStrides(dtype, shape)
192 func rowMajorStrides(dtype arrow.DataType, shape []int64) []int64 {
193 dt := dtype.(arrow.FixedWidthDataType)
194 rem := int64(dt.BitWidth() / 8)
195 for _, v := range shape {
200 strides := make([]int64, len(shape))
201 rem := int64(dt.BitWidth() / 8)
202 for i := range strides {
209 for _, v := range shape {
211 strides = append(strides, rem)
216 func colMajorStrides(dtype arrow.DataType, shape []int64) []int64 {
217 dt := dtype.(arrow.FixedWidthDataType)
218 total := int64(dt.BitWidth() / 8)
219 for _, v := range shape {
221 strides := make([]int64, len(shape))
222 for i := range strides {
230 for _, v := range shape {
231 strides = append(strides, total)
237 func equalInt64s(a, b []int64) bool {
238 if len(a) != len(b) {