]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/arrow/go/arrow/array/compare.go
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / go / arrow / array / compare.go
diff --git a/ceph/src/arrow/go/arrow/array/compare.go b/ceph/src/arrow/go/arrow/array/compare.go
new file mode 100644 (file)
index 0000000..ae6795d
--- /dev/null
@@ -0,0 +1,630 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package array
+
+import (
+       "math"
+
+       "github.com/apache/arrow/go/v6/arrow"
+       "github.com/apache/arrow/go/v6/arrow/float16"
+       "golang.org/x/xerrors"
+)
+
+// RecordEqual reports whether the two provided records are equal.
+func RecordEqual(left, right Record) bool {
+       switch {
+       case left.NumCols() != right.NumCols():
+               return false
+       case left.NumRows() != right.NumRows():
+               return false
+       }
+
+       for i := range left.Columns() {
+               lc := left.Column(i)
+               rc := right.Column(i)
+               if !ArrayEqual(lc, rc) {
+                       return false
+               }
+       }
+       return true
+}
+
+// RecordApproxEqual reports whether the two provided records are approximately equal.
+// For non-floating point columns, it is equivalent to RecordEqual.
+func RecordApproxEqual(left, right Record, opts ...EqualOption) bool {
+       switch {
+       case left.NumCols() != right.NumCols():
+               return false
+       case left.NumRows() != right.NumRows():
+               return false
+       }
+
+       opt := newEqualOption(opts...)
+
+       for i := range left.Columns() {
+               lc := left.Column(i)
+               rc := right.Column(i)
+               if !arrayApproxEqual(lc, rc, opt) {
+                       return false
+               }
+       }
+       return true
+}
+
+// helper function to evaluate a function on two chunked object having possibly different
+// chunk layouts. the function passed in will be called for each corresponding slice of the
+// two chunked arrays and if the function returns false it will end the loop early.
+func chunkedBinaryApply(left, right *Chunked, fn func(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool) {
+       var (
+               pos               int64
+               length            int64 = int64(left.length)
+               leftIdx, rightIdx int
+               leftPos, rightPos int64
+       )
+
+       for pos < length {
+               var cleft, cright Interface
+               for {
+                       cleft, cright = left.Chunk(leftIdx), right.Chunk(rightIdx)
+                       if leftPos == int64(cleft.Len()) {
+                               leftPos = 0
+                               leftIdx++
+                               continue
+                       }
+                       if rightPos == int64(cright.Len()) {
+                               rightPos = 0
+                               rightIdx++
+                               continue
+                       }
+                       break
+               }
+
+               sz := int64(min(cleft.Len()-int(leftPos), cright.Len()-int(rightPos)))
+               pos += sz
+               if !fn(cleft, leftPos, leftPos+sz, cright, rightPos, rightPos+sz) {
+                       return
+               }
+
+               leftPos += sz
+               rightPos += sz
+       }
+}
+
+// ChunkedEqual reports whether two chunked arrays are equal regardless of their chunkings
+func ChunkedEqual(left, right *Chunked) bool {
+       switch {
+       case left == right:
+               return true
+       case left.length != right.length:
+               return false
+       case left.nulls != right.nulls:
+               return false
+       case !arrow.TypeEqual(left.dtype, right.dtype):
+               return false
+       }
+
+       var isequal bool
+       chunkedBinaryApply(left, right, func(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool {
+               isequal = ArraySliceEqual(left, lbeg, lend, right, rbeg, rend)
+               return isequal
+       })
+
+       return isequal
+}
+
+// ChunkedApproxEqual reports whether two chunked arrays are approximately equal regardless of their chunkings
+// for non-floating point arrays, this is equivalent to ChunkedEqual
+func ChunkedApproxEqual(left, right *Chunked, opts ...EqualOption) bool {
+       switch {
+       case left == right:
+               return true
+       case left.length != right.length:
+               return false
+       case left.nulls != right.nulls:
+               return false
+       case !arrow.TypeEqual(left.dtype, right.dtype):
+               return false
+       }
+
+       var isequal bool
+       chunkedBinaryApply(left, right, func(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool {
+               isequal = ArraySliceApproxEqual(left, lbeg, lend, right, rbeg, rend, opts...)
+               return isequal
+       })
+
+       return isequal
+}
+
+// TableEqual returns if the two tables have the same data in the same schema
+func TableEqual(left, right Table) bool {
+       switch {
+       case left.NumCols() != right.NumCols():
+               return false
+       case left.NumRows() != right.NumRows():
+               return false
+       }
+
+       for i := 0; int64(i) < left.NumCols(); i++ {
+               lc := left.Column(i)
+               rc := right.Column(i)
+               if !lc.field.Equal(rc.field) {
+                       return false
+               }
+
+               if !ChunkedEqual(lc.data, rc.data) {
+                       return false
+               }
+       }
+       return true
+}
+
+// TableEqual returns if the two tables have the approximately equal data in the same schema
+func TableApproxEqual(left, right Table, opts ...EqualOption) bool {
+       switch {
+       case left.NumCols() != right.NumCols():
+               return false
+       case left.NumRows() != right.NumRows():
+               return false
+       }
+
+       for i := 0; int64(i) < left.NumCols(); i++ {
+               lc := left.Column(i)
+               rc := right.Column(i)
+               if !lc.field.Equal(rc.field) {
+                       return false
+               }
+
+               if !ChunkedApproxEqual(lc.data, rc.data, opts...) {
+                       return false
+               }
+       }
+       return true
+}
+
+// ArrayEqual reports whether the two provided arrays are equal.
+func ArrayEqual(left, right Interface) bool {
+       switch {
+       case !baseArrayEqual(left, right):
+               return false
+       case left.Len() == 0:
+               return true
+       case left.NullN() == left.Len():
+               return true
+       }
+
+       // at this point, we know both arrays have same type, same length, same number of nulls
+       // and nulls at the same place.
+       // compare the values.
+
+       switch l := left.(type) {
+       case *Null:
+               return true
+       case *Boolean:
+               r := right.(*Boolean)
+               return arrayEqualBoolean(l, r)
+       case *FixedSizeBinary:
+               r := right.(*FixedSizeBinary)
+               return arrayEqualFixedSizeBinary(l, r)
+       case *Binary:
+               r := right.(*Binary)
+               return arrayEqualBinary(l, r)
+       case *String:
+               r := right.(*String)
+               return arrayEqualString(l, r)
+       case *Int8:
+               r := right.(*Int8)
+               return arrayEqualInt8(l, r)
+       case *Int16:
+               r := right.(*Int16)
+               return arrayEqualInt16(l, r)
+       case *Int32:
+               r := right.(*Int32)
+               return arrayEqualInt32(l, r)
+       case *Int64:
+               r := right.(*Int64)
+               return arrayEqualInt64(l, r)
+       case *Uint8:
+               r := right.(*Uint8)
+               return arrayEqualUint8(l, r)
+       case *Uint16:
+               r := right.(*Uint16)
+               return arrayEqualUint16(l, r)
+       case *Uint32:
+               r := right.(*Uint32)
+               return arrayEqualUint32(l, r)
+       case *Uint64:
+               r := right.(*Uint64)
+               return arrayEqualUint64(l, r)
+       case *Float16:
+               r := right.(*Float16)
+               return arrayEqualFloat16(l, r)
+       case *Float32:
+               r := right.(*Float32)
+               return arrayEqualFloat32(l, r)
+       case *Float64:
+               r := right.(*Float64)
+               return arrayEqualFloat64(l, r)
+       case *Decimal128:
+               r := right.(*Decimal128)
+               return arrayEqualDecimal128(l, r)
+       case *Date32:
+               r := right.(*Date32)
+               return arrayEqualDate32(l, r)
+       case *Date64:
+               r := right.(*Date64)
+               return arrayEqualDate64(l, r)
+       case *Time32:
+               r := right.(*Time32)
+               return arrayEqualTime32(l, r)
+       case *Time64:
+               r := right.(*Time64)
+               return arrayEqualTime64(l, r)
+       case *Timestamp:
+               r := right.(*Timestamp)
+               return arrayEqualTimestamp(l, r)
+       case *List:
+               r := right.(*List)
+               return arrayEqualList(l, r)
+       case *FixedSizeList:
+               r := right.(*FixedSizeList)
+               return arrayEqualFixedSizeList(l, r)
+       case *Struct:
+               r := right.(*Struct)
+               return arrayEqualStruct(l, r)
+       case *MonthInterval:
+               r := right.(*MonthInterval)
+               return arrayEqualMonthInterval(l, r)
+       case *DayTimeInterval:
+               r := right.(*DayTimeInterval)
+               return arrayEqualDayTimeInterval(l, r)
+       case *MonthDayNanoInterval:
+               r := right.(*MonthDayNanoInterval)
+               return arrayEqualMonthDayNanoInterval(l, r)
+       case *Duration:
+               r := right.(*Duration)
+               return arrayEqualDuration(l, r)
+       case *Map:
+               r := right.(*Map)
+               return arrayEqualMap(l, r)
+       case ExtensionArray:
+               r := right.(ExtensionArray)
+               return arrayEqualExtension(l, r)
+       default:
+               panic(xerrors.Errorf("arrow/array: unknown array type %T", l))
+       }
+}
+
+// ArraySliceEqual reports whether slices left[lbeg:lend] and right[rbeg:rend] are equal.
+func ArraySliceEqual(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool {
+       l := NewSlice(left, lbeg, lend)
+       defer l.Release()
+       r := NewSlice(right, rbeg, rend)
+       defer r.Release()
+
+       return ArrayEqual(l, r)
+}
+
+// ArraySliceApproxEqual reports whether slices left[lbeg:lend] and right[rbeg:rend] are approximately equal.
+func ArraySliceApproxEqual(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64, opts ...EqualOption) bool {
+       l := NewSlice(left, lbeg, lend)
+       defer l.Release()
+       r := NewSlice(right, rbeg, rend)
+       defer r.Release()
+
+       return ArrayApproxEqual(l, r, opts...)
+}
+
+const defaultAbsoluteTolerance = 1e-5
+
+type equalOption struct {
+       atol   float64 // absolute tolerance
+       nansEq bool    // whether NaNs are considered equal.
+}
+
+func (eq equalOption) f16(f1, f2 float16.Num) bool {
+       v1 := float64(f1.Float32())
+       v2 := float64(f2.Float32())
+       switch {
+       case eq.nansEq:
+               return math.Abs(v1-v2) <= eq.atol || (math.IsNaN(v1) && math.IsNaN(v2))
+       default:
+               return math.Abs(v1-v2) <= eq.atol
+       }
+}
+
+func (eq equalOption) f32(f1, f2 float32) bool {
+       v1 := float64(f1)
+       v2 := float64(f2)
+       switch {
+       case eq.nansEq:
+               return math.Abs(v1-v2) <= eq.atol || (math.IsNaN(v1) && math.IsNaN(v2))
+       default:
+               return math.Abs(v1-v2) <= eq.atol
+       }
+}
+
+func (eq equalOption) f64(v1, v2 float64) bool {
+       switch {
+       case eq.nansEq:
+               return math.Abs(v1-v2) <= eq.atol || (math.IsNaN(v1) && math.IsNaN(v2))
+       default:
+               return math.Abs(v1-v2) <= eq.atol
+       }
+}
+
+func newEqualOption(opts ...EqualOption) equalOption {
+       eq := equalOption{
+               atol:   defaultAbsoluteTolerance,
+               nansEq: false,
+       }
+       for _, opt := range opts {
+               opt(&eq)
+       }
+
+       return eq
+}
+
+// EqualOption is a functional option type used to configure how Records and Arrays are compared.
+type EqualOption func(*equalOption)
+
+// WithNaNsEqual configures the comparison functions so that NaNs are considered equal.
+func WithNaNsEqual(v bool) EqualOption {
+       return func(o *equalOption) {
+               o.nansEq = v
+       }
+}
+
+// WithAbsTolerance configures the comparison functions so that 2 floating point values
+// v1 and v2 are considered equal if |v1-v2| <= atol.
+func WithAbsTolerance(atol float64) EqualOption {
+       return func(o *equalOption) {
+               o.atol = atol
+       }
+}
+
+// ArrayApproxEqual reports whether the two provided arrays are approximately equal.
+// For non-floating point arrays, it is equivalent to ArrayEqual.
+func ArrayApproxEqual(left, right Interface, opts ...EqualOption) bool {
+       opt := newEqualOption(opts...)
+       return arrayApproxEqual(left, right, opt)
+}
+
+func arrayApproxEqual(left, right Interface, opt equalOption) bool {
+       switch {
+       case !baseArrayEqual(left, right):
+               return false
+       case left.Len() == 0:
+               return true
+       case left.NullN() == left.Len():
+               return true
+       }
+
+       // at this point, we know both arrays have same type, same length, same number of nulls
+       // and nulls at the same place.
+       // compare the values.
+
+       switch l := left.(type) {
+       case *Null:
+               return true
+       case *Boolean:
+               r := right.(*Boolean)
+               return arrayEqualBoolean(l, r)
+       case *FixedSizeBinary:
+               r := right.(*FixedSizeBinary)
+               return arrayEqualFixedSizeBinary(l, r)
+       case *Binary:
+               r := right.(*Binary)
+               return arrayEqualBinary(l, r)
+       case *String:
+               r := right.(*String)
+               return arrayEqualString(l, r)
+       case *Int8:
+               r := right.(*Int8)
+               return arrayEqualInt8(l, r)
+       case *Int16:
+               r := right.(*Int16)
+               return arrayEqualInt16(l, r)
+       case *Int32:
+               r := right.(*Int32)
+               return arrayEqualInt32(l, r)
+       case *Int64:
+               r := right.(*Int64)
+               return arrayEqualInt64(l, r)
+       case *Uint8:
+               r := right.(*Uint8)
+               return arrayEqualUint8(l, r)
+       case *Uint16:
+               r := right.(*Uint16)
+               return arrayEqualUint16(l, r)
+       case *Uint32:
+               r := right.(*Uint32)
+               return arrayEqualUint32(l, r)
+       case *Uint64:
+               r := right.(*Uint64)
+               return arrayEqualUint64(l, r)
+       case *Float16:
+               r := right.(*Float16)
+               return arrayApproxEqualFloat16(l, r, opt)
+       case *Float32:
+               r := right.(*Float32)
+               return arrayApproxEqualFloat32(l, r, opt)
+       case *Float64:
+               r := right.(*Float64)
+               return arrayApproxEqualFloat64(l, r, opt)
+       case *Decimal128:
+               r := right.(*Decimal128)
+               return arrayEqualDecimal128(l, r)
+       case *Date32:
+               r := right.(*Date32)
+               return arrayEqualDate32(l, r)
+       case *Date64:
+               r := right.(*Date64)
+               return arrayEqualDate64(l, r)
+       case *Time32:
+               r := right.(*Time32)
+               return arrayEqualTime32(l, r)
+       case *Time64:
+               r := right.(*Time64)
+               return arrayEqualTime64(l, r)
+       case *Timestamp:
+               r := right.(*Timestamp)
+               return arrayEqualTimestamp(l, r)
+       case *List:
+               r := right.(*List)
+               return arrayApproxEqualList(l, r, opt)
+       case *FixedSizeList:
+               r := right.(*FixedSizeList)
+               return arrayApproxEqualFixedSizeList(l, r, opt)
+       case *Struct:
+               r := right.(*Struct)
+               return arrayApproxEqualStruct(l, r, opt)
+       case *MonthInterval:
+               r := right.(*MonthInterval)
+               return arrayEqualMonthInterval(l, r)
+       case *DayTimeInterval:
+               r := right.(*DayTimeInterval)
+               return arrayEqualDayTimeInterval(l, r)
+       case *MonthDayNanoInterval:
+               r := right.(*MonthDayNanoInterval)
+               return arrayEqualMonthDayNanoInterval(l, r)
+       case *Duration:
+               r := right.(*Duration)
+               return arrayEqualDuration(l, r)
+       case *Map:
+               r := right.(*Map)
+               return arrayApproxEqualList(l.List, r.List, opt)
+       case ExtensionArray:
+               r := right.(ExtensionArray)
+               return arrayApproxEqualExtension(l, r, opt)
+       default:
+               panic(xerrors.Errorf("arrow/array: unknown array type %T", l))
+       }
+
+       return false
+}
+
+func baseArrayEqual(left, right Interface) bool {
+       switch {
+       case left.Len() != right.Len():
+               return false
+       case left.NullN() != right.NullN():
+               return false
+       case !arrow.TypeEqual(left.DataType(), right.DataType()): // We do not check for metadata as in the C++ implementation.
+               return false
+       case !validityBitmapEqual(left, right):
+               return false
+       }
+       return true
+}
+
+func validityBitmapEqual(left, right Interface) bool {
+       // TODO(alexandreyc): make it faster by comparing byte slices of the validity bitmap?
+       n := left.Len()
+       if n != right.Len() {
+               return false
+       }
+       for i := 0; i < n; i++ {
+               if left.IsNull(i) != right.IsNull(i) {
+                       return false
+               }
+       }
+       return true
+}
+
+func arrayApproxEqualFloat16(left, right *Float16, opt equalOption) bool {
+       for i := 0; i < left.Len(); i++ {
+               if left.IsNull(i) {
+                       continue
+               }
+               if !opt.f16(left.Value(i), right.Value(i)) {
+                       return false
+               }
+       }
+       return true
+}
+
+func arrayApproxEqualFloat32(left, right *Float32, opt equalOption) bool {
+       for i := 0; i < left.Len(); i++ {
+               if left.IsNull(i) {
+                       continue
+               }
+               if !opt.f32(left.Value(i), right.Value(i)) {
+                       return false
+               }
+       }
+       return true
+}
+
+func arrayApproxEqualFloat64(left, right *Float64, opt equalOption) bool {
+       for i := 0; i < left.Len(); i++ {
+               if left.IsNull(i) {
+                       continue
+               }
+               if !opt.f64(left.Value(i), right.Value(i)) {
+                       return false
+               }
+       }
+       return true
+}
+
+func arrayApproxEqualList(left, right *List, opt equalOption) bool {
+       for i := 0; i < left.Len(); i++ {
+               if left.IsNull(i) {
+                       continue
+               }
+               o := func() bool {
+                       l := left.newListValue(i)
+                       defer l.Release()
+                       r := right.newListValue(i)
+                       defer r.Release()
+                       return arrayApproxEqual(l, r, opt)
+               }()
+               if !o {
+                       return false
+               }
+       }
+       return true
+}
+
+func arrayApproxEqualFixedSizeList(left, right *FixedSizeList, opt equalOption) bool {
+       for i := 0; i < left.Len(); i++ {
+               if left.IsNull(i) {
+                       continue
+               }
+               o := func() bool {
+                       l := left.newListValue(i)
+                       defer l.Release()
+                       r := right.newListValue(i)
+                       defer r.Release()
+                       return arrayApproxEqual(l, r, opt)
+               }()
+               if !o {
+                       return false
+               }
+       }
+       return true
+}
+
+func arrayApproxEqualStruct(left, right *Struct, opt equalOption) bool {
+       for i, lf := range left.fields {
+               rf := right.fields[i]
+               if !arrayApproxEqual(lf, rf, opt) {
+                       return false
+               }
+       }
+       return true
+}