]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | // Functions for comparing Arrow data structures | |
19 | ||
20 | #pragma once | |
21 | ||
22 | #include <cstdint> | |
23 | #include <iosfwd> | |
24 | ||
25 | #include "arrow/util/macros.h" | |
26 | #include "arrow/util/visibility.h" | |
27 | ||
28 | namespace arrow { | |
29 | ||
// Forward declarations: the comparison functions below only take these types
// by reference, so their full definitions are not needed in this header.
class Array;
class DataType;
class SparseTensor;
class Tensor;
struct Scalar;
35 | ||
/// Absolute tolerance that EqualOptions uses for approximate floating-point
/// comparisons unless a different value is requested.
// A namespace-scope constexpr variable already has internal linkage, so each
// translation unit including this header gets its own copy.
constexpr double kDefaultAbsoluteTolerance = 1e-5;
37 | ||
38 | /// A container of options for equality comparisons | |
39 | class EqualOptions { | |
40 | public: | |
41 | /// Whether or not NaNs are considered equal. | |
42 | bool nans_equal() const { return nans_equal_; } | |
43 | ||
44 | /// Return a new EqualOptions object with the "nans_equal" property changed. | |
45 | EqualOptions nans_equal(bool v) const { | |
46 | auto res = EqualOptions(*this); | |
47 | res.nans_equal_ = v; | |
48 | return res; | |
49 | } | |
50 | ||
51 | /// The absolute tolerance for approximate comparisons of floating-point values. | |
52 | double atol() const { return atol_; } | |
53 | ||
54 | /// Return a new EqualOptions object with the "atol" property changed. | |
55 | EqualOptions atol(double v) const { | |
56 | auto res = EqualOptions(*this); | |
57 | res.atol_ = v; | |
58 | return res; | |
59 | } | |
60 | ||
61 | /// The ostream to which a diff will be formatted if arrays disagree. | |
62 | /// If this is null (the default) no diff will be formatted. | |
63 | std::ostream* diff_sink() const { return diff_sink_; } | |
64 | ||
65 | /// Return a new EqualOptions object with the "diff_sink" property changed. | |
66 | /// This option will be ignored if diff formatting of the types of compared arrays is | |
67 | /// not supported. | |
68 | EqualOptions diff_sink(std::ostream* diff_sink) const { | |
69 | auto res = EqualOptions(*this); | |
70 | res.diff_sink_ = diff_sink; | |
71 | return res; | |
72 | } | |
73 | ||
74 | static EqualOptions Defaults() { return {}; } | |
75 | ||
76 | protected: | |
77 | double atol_ = kDefaultAbsoluteTolerance; | |
78 | bool nans_equal_ = false; | |
79 | std::ostream* diff_sink_ = NULLPTR; | |
80 | }; | |
81 | ||
82 | /// Returns true if the arrays are exactly equal | |
83 | bool ARROW_EXPORT ArrayEquals(const Array& left, const Array& right, | |
84 | const EqualOptions& = EqualOptions::Defaults()); | |
85 | ||
86 | /// Returns true if the arrays are approximately equal. For non-floating point | |
87 | /// types, this is equivalent to ArrayEquals(left, right) | |
88 | bool ARROW_EXPORT ArrayApproxEquals(const Array& left, const Array& right, | |
89 | const EqualOptions& = EqualOptions::Defaults()); | |
90 | ||
91 | /// Returns true if indicated equal-length segment of arrays are exactly equal | |
92 | bool ARROW_EXPORT ArrayRangeEquals(const Array& left, const Array& right, | |
93 | int64_t start_idx, int64_t end_idx, | |
94 | int64_t other_start_idx, | |
95 | const EqualOptions& = EqualOptions::Defaults()); | |
96 | ||
97 | /// Returns true if indicated equal-length segment of arrays are approximately equal | |
98 | bool ARROW_EXPORT ArrayRangeApproxEquals(const Array& left, const Array& right, | |
99 | int64_t start_idx, int64_t end_idx, | |
100 | int64_t other_start_idx, | |
101 | const EqualOptions& = EqualOptions::Defaults()); | |
102 | ||
103 | bool ARROW_EXPORT TensorEquals(const Tensor& left, const Tensor& right, | |
104 | const EqualOptions& = EqualOptions::Defaults()); | |
105 | ||
106 | /// EXPERIMENTAL: Returns true if the given sparse tensors are exactly equal | |
107 | bool ARROW_EXPORT SparseTensorEquals(const SparseTensor& left, const SparseTensor& right, | |
108 | const EqualOptions& = EqualOptions::Defaults()); | |
109 | ||
110 | /// Returns true if the type metadata are exactly equal | |
111 | /// \param[in] left a DataType | |
112 | /// \param[in] right a DataType | |
113 | /// \param[in] check_metadata whether to compare KeyValueMetadata for child | |
114 | /// fields | |
115 | bool ARROW_EXPORT TypeEquals(const DataType& left, const DataType& right, | |
116 | bool check_metadata = true); | |
117 | ||
118 | /// Returns true if scalars are equal | |
119 | /// \param[in] left a Scalar | |
120 | /// \param[in] right a Scalar | |
121 | /// \param[in] options comparison options | |
122 | bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right, | |
123 | const EqualOptions& options = EqualOptions::Defaults()); | |
124 | ||
125 | /// Returns true if scalars are approximately equal | |
126 | /// \param[in] left a Scalar | |
127 | /// \param[in] right a Scalar | |
128 | /// \param[in] options comparison options | |
129 | bool ARROW_EXPORT | |
130 | ScalarApproxEquals(const Scalar& left, const Scalar& right, | |
131 | const EqualOptions& options = EqualOptions::Defaults()); | |
132 | ||
133 | } // namespace arrow |