]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | # Licensed to the Apache Software Foundation (ASF) under one |
2 | # or more contributor license agreements. See the NOTICE file | |
3 | # distributed with this work for additional information | |
4 | # regarding copyright ownership. The ASF licenses this file | |
5 | # to you under the Apache License, Version 2.0 (the | |
6 | # "License"); you may not use this file except in compliance | |
7 | # with the License. You may obtain a copy of the License at | |
8 | # | |
9 | # http://www.apache.org/licenses/LICENSE-2.0 | |
10 | # | |
11 | # Unless required by applicable law or agreed to in writing, | |
12 | # software distributed under the License is distributed on an | |
13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | # KIND, either express or implied. See the License for the | |
15 | # specific language governing permissions and limitations | |
16 | # under the License. | |
17 | ||
18 | import numpy as np | |
19 | import pytest | |
20 | ||
21 | ||
def run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
                                   client, use_gpu, dtype,
                                   shape=(3, 244, 244)):
    """Round-trip two tensors of ``dtype`` through Plasma via the TF ops.

    Writes a random tensor and a tensor of ones into the Plasma store with
    ``tensor_to_plasma``. Reads them back both with ``plasma_to_tensor`` and
    with the Python ``client``, and asserts that both read paths return the
    original data.

    Args:
        tf: The imported ``tensorflow`` module (this uses the TF1-style
            ``Session`` / ``ConfigProto`` API).
        plasma: The imported ``pyarrow.plasma`` module, with
            ``tf_plasma_op`` already built.
        plasma_store_name: Socket name of a running Plasma store.
        client: A Plasma client connected to that store.
        use_gpu: Place the ops on ``/gpu`` if True, otherwise on ``/cpu``.
            Soft placement is enabled, so TF may fall back to another
            device.
        dtype: NumPy dtype used for both tensors.
        shape: Shape of each tensor. Defaults to ``(3, 244, 244)``.
    """
    force_device = '/gpu' if use_gpu else '/cpu'

    # Raw 20-byte id, later wrapped in plasma.ObjectID for the Python read.
    object_id = np.random.bytes(20)

    data = np.random.randn(*shape).astype(dtype)
    ones = np.ones(shape, dtype=dtype)

    def ToPlasma():
        data_tensor = tf.constant(data)
        ones_tensor = tf.constant(ones)
        return plasma.tf_plasma_op.tensor_to_plasma(
            [data_tensor, ones_tensor],
            object_id,
            plasma_store_socket_name=plasma_store_name)

    def FromPlasma():
        return plasma.tf_plasma_op.plasma_to_tensor(
            object_id,
            dtype=tf.as_dtype(dtype),
            plasma_store_socket_name=plasma_store_name)

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=True)
    # Use the session as a context manager so that it is closed once this
    # dtype's round trip is done. Otherwise every call would leave another
    # live session behind.
    with tf.Session(config=config) as sess:
        with tf.device(force_device):
            to_plasma = ToPlasma()
            from_plasma = FromPlasma()

        # Consume the op's output in a downstream TF op as well.
        z = from_plasma + 1

        sess.run(to_plasma)
        out = sess.run(from_plasma)
        sess.run(z)

    # NOTE(zongheng): currently it returns a flat 1D tensor.
    # So reshape manually.
    first, second = np.split(out, 2)
    out0 = first.reshape(shape)
    out1 = second.reshape(shape)

    assert np.array_equal(data, out0), "Data not equal!"
    assert np.array_equal(ones, out1), "Data not equal!"

    # Try getting the data from Python
    plasma_object_id = plasma.ObjectID(object_id)
    obj = client.get(plasma_object_id)

    # Deserialized Tensor should be 64-byte aligned.
    assert obj.ctypes.data % 64 == 0

    first, second = np.split(obj, 2)
    result0 = first.reshape(shape)
    result1 = second.reshape(shape)

    assert np.array_equal(data, result0), "Data not equal!"
    assert np.array_equal(ones, result1), "Data not equal!"
81 | ||
82 | ||
@pytest.mark.plasma
@pytest.mark.tensorflow
@pytest.mark.skip(reason='Until ARROW-4259 is resolved')
def test_plasma_tf_op(use_gpu=False):
    """Exercise the Plasma TensorFlow ops for several numeric dtypes.

    Builds the TF op and skips the test if it is unavailable. It then starts
    a Plasma store with ``10**8`` bytes of memory and round-trips tensors of
    each dtype via ``run_tensorflow_test_with_dtype``. Finally it checks that
    no object in the store is still referenced.

    Args:
        use_gpu: Forwarded to the per-dtype helper to place ops on ``/gpu``
            instead of ``/cpu``. It has a default so that pytest does not
            treat it as a fixture.
    """
    import pyarrow.plasma as plasma
    import tensorflow as tf

    plasma.build_plasma_tensorflow_op()

    if plasma.tf_plasma_op is None:
        pytest.skip("TensorFlow Op not found")

    with plasma.start_plasma_store(10**8) as (plasma_store_name, _proc):
        client = plasma.connect(plasma_store_name)
        try:
            for dtype in [np.float32, np.float64,
                          np.int8, np.int16, np.int32, np.int64]:
                run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
                                               client, use_gpu, dtype)

            # Make sure the objects have been released.
            for info in client.list().values():
                assert info['ref_count'] == 0
        finally:
            # Release the client's connection before the store shuts down.
            client.disconnect()