]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/c_glib/test/test-tensor.rb
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / c_glib / test / test-tensor.rb
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
17
# Tests for Arrow::Tensor: equality, value type and buffer accessors,
# shape/strides/dimension metadata, layout predicates, and round-tripping
# a tensor through a buffer output/input stream.
class TestTensor < Test::Unit::TestCase
  include Helper::Omittable

  # Fixture: twelve int8 values viewed as a tensor of shape [3, 2, 2]
  # with dimensions named "a", "b" and "c".
  def setup
    @raw_data = [
      1, 2,
      3, 4,

      5, 6,
      7, 8,

      9, 10,
      11, 12,
    ]
    @shape = [3, 2, 2]
    @tensor = create_tensor
  end

  # A tensor built independently from the same bytes, shape and names
  # must compare equal to the fixture.
  def test_equal
    other_tensor = create_tensor
    assert_equal(@tensor,
                 other_tensor)
  end

  def test_value_data_type
    assert_equal(Arrow::Int8DataType, @tensor.value_data_type.class)
  end

  def test_value_type
    assert_equal(Arrow::Type::INT8, @tensor.value_type)
  end

  # The tensor's buffer must hold exactly the packed fixture bytes.
  def test_buffer
    assert_equal(@raw_data, @tensor.buffer.data.to_s.unpack("c*"))
  end

  def test_shape
    require_gi_bindings(3, 3, 1)
    assert_equal(@shape, @tensor.shape)
  end

  # The tensor was created with empty strides; for 1-byte int8 values and
  # shape [3, 2, 2], the expected strides are [2 * 2, 2, 1] = [4, 2, 1].
  def test_strides
    require_gi_bindings(3, 3, 1)
    assert_equal([4, 2, 1], @tensor.strides)
  end

  def test_n_dimensions
    assert_equal(@shape.size, @tensor.n_dimensions)
  end

  def test_dimension_name
    dimension_names = @tensor.n_dimensions.times.map do |i|
      @tensor.get_dimension_name(i)
    end
    assert_equal(["a", "b", "c"],
                 dimension_names)
  end

  # size is the total element count (3 * 2 * 2), not the byte length.
  def test_size
    assert_equal(@raw_data.size, @tensor.size)
  end

  def test_mutable?
    assert do
      !@tensor.mutable?
    end
  end

  def test_contiguous?
    assert do
      @tensor.contiguous?
    end
  end

  def test_row_major?
    assert do
      @tensor.row_major?
    end
  end

  def test_column_major?
    assert do
      !@tensor.column_major?
    end
  end

  # Writes the tensor into a resizable buffer and reads it back.
  def test_io
    buffer = Arrow::ResizableBuffer.new(0)
    output = Arrow::BufferOutputStream.new(buffer)
    output.write_tensor(@tensor)
    # Close the output stream before reading the buffer back so it is not
    # left open. NOTE(review): presumably closing also trims the buffer to
    # the bytes actually written — confirm against BufferOutputStream docs.
    output.close
    input = Arrow::BufferInputStream.new(buffer)
    assert_equal(@tensor, input.read_tensor)
  end

  private

  # Builds an int8 tensor of shape @shape over a freshly allocated buffer
  # packed from @raw_data. Strides are passed empty (test_strides checks
  # the resulting [4, 2, 1]), and the dimensions are named "a", "b", "c".
  def create_tensor
    data = Arrow::Buffer.new(@raw_data.pack("c*"))
    strides = []
    names = ["a", "b", "c"]
    Arrow::Tensor.new(Arrow::Int8DataType.new,
                      data,
                      @shape,
                      strides,
                      names)
  end
end