]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/python/pyarrow/tests/test_plasma_tf_op.py
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / python / pyarrow / tests / test_plasma_tf_op.py
CommitLineData
1d09f67e
TL
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import numpy as np
19import pytest
20
21
def run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
                                   client, use_gpu, dtype):
    """Round-trip two tensors of ``dtype`` through Plasma using the TF ops.

    A random tensor and a tensor of ones are written into the Plasma store
    with ``tensor_to_plasma`` and read back with ``plasma_to_tensor``; the
    values are compared against the originals. The same object is then
    fetched through the Python Plasma ``client`` and checked for 64-byte
    alignment and equal values.

    Parameters
    ----------
    tf : module
        The ``tensorflow`` module (uses the TF1-style ``Session`` API).
    plasma : module
        The ``pyarrow.plasma`` module; ``plasma.tf_plasma_op`` must have
        been built by the caller.
    plasma_store_name
        Socket name of the running Plasma store, forwarded to the ops as
        ``plasma_store_socket_name``.
    client
        Plasma client connected to the same store.
    use_gpu : bool
        Request '/gpu' placement if true, otherwise '/cpu'. Soft placement
        is enabled, so TensorFlow may still choose another device.
    dtype
        NumPy dtype of both tensors.
    """
    shape = (3, 244, 244)
    force_device = '/gpu' if use_gpu else '/cpu'

    object_id = np.random.bytes(20)

    data = np.random.randn(*shape).astype(dtype)
    ones = np.ones(shape).astype(dtype)

    with tf.device(force_device):
        data_tensor = tf.constant(data)
        ones_tensor = tf.constant(ones)
        to_plasma = plasma.tf_plasma_op.tensor_to_plasma(
            [data_tensor, ones_tensor],
            object_id,
            plasma_store_socket_name=plasma_store_name)
        from_plasma = plasma.tf_plasma_op.plasma_to_tensor(
            object_id,
            dtype=tf.as_dtype(dtype),
            plasma_store_socket_name=plasma_store_name)

    # Feed the op's output into a downstream op; the value is not checked.
    z = from_plasma + 1

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=True)
    # Use the session as a context manager so it is closed after each dtype
    # instead of leaking one session per call.
    with tf.Session(config=config) as sess:
        sess.run(to_plasma)
        # NOTE(zongheng): currently it returns a flat 1D tensor.
        # So reshape manually.
        out = sess.run(from_plasma)
        sess.run(z)

    out0, out1 = (part.reshape(shape) for part in np.split(out, 2))
    assert np.array_equal(data, out0), "Data not equal!"
    assert np.array_equal(ones, out1), "Data not equal!"

    # Try getting the data from Python
    obj = client.get(plasma.ObjectID(object_id))

    # Deserialized Tensor should be 64-byte aligned.
    assert obj.ctypes.data % 64 == 0

    result0, result1 = (part.reshape(shape) for part in np.split(obj, 2))
    assert np.array_equal(data, result0), "Data not equal!"
    assert np.array_equal(ones, result1), "Data not equal!"
81
82
@pytest.mark.plasma
@pytest.mark.tensorflow
@pytest.mark.skip(reason='Until ARROW-4259 is resolved')
def test_plasma_tf_op(use_gpu=False):
    """Exercise the Plasma TensorFlow ops across several numeric dtypes.

    The test:

    * builds the Plasma TF op, skipping if it is unavailable;
    * starts a Plasma store (size argument ``10**8``);
    * runs ``run_tensorflow_test_with_dtype`` for each float and int dtype;
    * checks that no object in the store is still referenced.

    ``use_gpu`` is forwarded to the per-dtype helper to pick the device.
    """
    import pyarrow.plasma as plasma
    import tensorflow as tf

    plasma.build_plasma_tensorflow_op()

    if plasma.tf_plasma_op is None:
        pytest.skip("TensorFlow Op not found")

    dtypes = [np.float32, np.float64,
              np.int8, np.int16, np.int32, np.int64]

    # The second value yielded by start_plasma_store is not needed here.
    with plasma.start_plasma_store(10**8) as (plasma_store_name, _):
        client = plasma.connect(plasma_store_name)
        try:
            for dtype in dtypes:
                run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
                                               client, use_gpu, dtype)

            # Make sure the objects have been released.
            for object_id, info in client.list().items():
                assert info['ref_count'] == 0, (object_id, info)
        finally:
            # Always release the client connection, even if a check failed.
            client.disconnect()