]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | # Licensed to the Apache Software Foundation (ASF) under one |
2 | # or more contributor license agreements. See the NOTICE file | |
3 | # distributed with this work for additional information | |
4 | # regarding copyright ownership. The ASF licenses this file | |
5 | # to you under the Apache License, Version 2.0 (the | |
6 | # "License"); you may not use this file except in compliance | |
7 | # with the License. You may obtain a copy of the License at | |
8 | # | |
9 | # http://www.apache.org/licenses/LICENSE-2.0 | |
10 | # | |
11 | # Unless required by applicable law or agreed to in writing, | |
12 | # software distributed under the License is distributed on an | |
13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | # KIND, either express or implied. See the License for the | |
15 | # specific language governing permissions and limitations | |
16 | # under the License. | |
17 | ||
18 | import os | |
19 | import shutil | |
20 | import subprocess | |
21 | import sys | |
22 | ||
23 | import pytest | |
24 | ||
25 | import pyarrow as pa | |
26 | import pyarrow.tests.util as test_util | |
27 | ||
28 | ||
# Directory holding this test module; the example .pyx sources are copied
# from here into each test's temporary workspace.
here = os.path.dirname(os.path.abspath(__file__))
# Optional extra library directory, read from the environment (empty string
# when the variable is not set).
test_ld_path = os.environ.get('PYARROW_TEST_LD_PATH', '')
# Extra compiler flags for the generated extension: an explicit C++ standard
# is only passed on POSIX platforms.
compiler_opts = ['-std=c++11'] if os.name == 'posix' else []
35 | ||
36 | ||
# Source of the setup.py script written into each test's temporary
# workspace.  It is filled in with str.format(), which substitutes the
# repr() of three values:
#   pyx_file      -- the .pyx file handed to cythonize()
#   compiler_opts -- list appended to each extension's extra_compile_args
#   test_ld_path  -- extra library directory, only added when non-empty
# The leading "if 1:" lets the script body be written indented inside the
# string literal.
setup_template = """if 1:
    from setuptools import setup
    from Cython.Build import cythonize

    import numpy as np

    import pyarrow as pa

    ext_modules = cythonize({pyx_file!r})
    compiler_opts = {compiler_opts!r}
    custom_ld_path = {test_ld_path!r}

    for ext in ext_modules:
        # XXX required for numpy/numpyconfig.h,
        # included from arrow/python/api.h
        ext.include_dirs.append(np.get_include())
        ext.include_dirs.append(pa.get_include())
        ext.libraries.extend(pa.get_libraries())
        ext.library_dirs.extend(pa.get_library_dirs())
        if custom_ld_path:
            ext.library_dirs.append(custom_ld_path)
        ext.extra_compile_args.extend(compiler_opts)
        print("Extension module:",
              ext, ext.include_dirs, ext.libraries, ext.library_dirs)

    setup(
        ext_modules=ext_modules,
    )
"""
66 | ||
67 | ||
def check_cython_example_module(mod):
    """
    Run the in-process checks against the compiled example module ``mod``.

    ``mod`` must expose ``get_array_length`` and ``cast_scalar``; both the
    successful calls and the expected error cases are asserted.
    """
    # get_array_length: works on an array, rejects a non-array argument
    three_ints = pa.array([1, 2, 3])
    assert mod.get_array_length(three_ints) == 3
    with pytest.raises(TypeError, match="not an array"):
        mod.get_array_length(None)

    # cast_scalar: int64 -> utf8 succeeds, int64 -> list is not implemented
    int_scalar = pa.scalar(123)
    assert mod.cast_scalar(int_scalar, pa.utf8()) == pa.scalar("123")
    with pytest.raises(NotImplementedError,
                       match="casting scalars of type int64 to type list"):
        mod.cast_scalar(int_scalar, pa.list_(pa.int64()))
80 | ||
81 | ||
@pytest.mark.cython
def test_cython_api(tmpdir):
    """
    Basic test for the Cython API.

    Builds the ``pyarrow_cython_example`` extension inside ``tmpdir``,
    checks it in-process with ``check_cython_example_module``, then checks
    that a separate interpreter can import and use it without importing
    pyarrow first.
    """
    # Fail early if cython is not found
    import cython  # noqa

    with tmpdir.as_cwd():
        # Set up temporary workspace
        pyx_file = 'pyarrow_cython_example.pyx'
        shutil.copyfile(os.path.join(here, pyx_file),
                        os.path.join(str(tmpdir), pyx_file))
        # Create setup.py file
        setup_code = setup_template.format(pyx_file=pyx_file,
                                           compiler_opts=compiler_opts,
                                           test_ld_path=test_ld_path)
        with open('setup.py', 'w') as f:
            f.write(setup_code)

        # ARROW-2263: Make environment with this pyarrow/ package first on the
        # PYTHONPATH, for local dev environments
        subprocess_env = test_util.get_modified_env_with_pythonpath()

        # Compile extension module
        subprocess.check_call([sys.executable, 'setup.py',
                               'build_ext', '--inplace'],
                              env=subprocess_env)

        # Check basic functionality
        orig_path = sys.path[:]
        sys.path.insert(0, str(tmpdir))
        try:
            mod = __import__('pyarrow_cython_example')
            check_cython_example_module(mod)
        finally:
            # Do not leak the temporary directory into later tests
            sys.path = orig_path

        # Check the extension module is loadable from a subprocess without
        # pyarrow imported first.
        code = """if 1:
            import sys

            mod = __import__({mod_name!r})
            arr = mod.make_null_array(5)
            assert mod.get_array_length(arr) == 5
            assert arr.null_count == 5
        """.format(mod_name='pyarrow_cython_example')

        if sys.platform == 'win32':
            delim, var = ';', 'PATH'
        else:
            delim, var = ':', 'LD_LIBRARY_PATH'

        # Put pyarrow's library directories first, then whatever the
        # environment already had.  An unset/empty value is skipped so the
        # joined path does not end in an empty entry (the dynamic loader
        # treats an empty LD_LIBRARY_PATH entry as the current directory).
        search_dirs = list(pa.get_library_dirs())
        inherited = subprocess_env.get(var, '')
        if inherited:
            search_dirs.append(inherited)
        subprocess_env[var] = delim.join(search_dirs)

        # The child's stdout is not needed.  DEVNULL is used rather than
        # PIPE because check_call() never reads the pipe, so a child that
        # writes enough output could block on a full pipe buffer.
        subprocess.check_call([sys.executable, '-c', code],
                              stdout=subprocess.DEVNULL,
                              env=subprocess_env)
143 | ||
144 | ||
@pytest.mark.cython
def test_visit_strings(tmpdir):
    """
    Test the ``_visit_strings`` function of the example extension.

    Builds ``bound_function_visit_strings`` inside ``tmpdir``, then checks
    that the callback is invoked for every string in order and that an
    exception raised by the callback propagates to the caller.
    """
    # Fail early if cython is not found
    import cython  # noqa

    with tmpdir.as_cwd():
        # Set up temporary workspace
        pyx_file = 'bound_function_visit_strings.pyx'
        shutil.copyfile(os.path.join(here, pyx_file),
                        os.path.join(str(tmpdir), pyx_file))
        # Create setup.py file
        setup_code = setup_template.format(pyx_file=pyx_file,
                                           compiler_opts=compiler_opts,
                                           test_ld_path=test_ld_path)
        with open('setup.py', 'w') as f:
            f.write(setup_code)

        subprocess_env = test_util.get_modified_env_with_pythonpath()

        # Compile extension module
        subprocess.check_call([sys.executable, 'setup.py',
                               'build_ext', '--inplace'],
                              env=subprocess_env)

    # Import the freshly built module, restoring sys.path afterwards so the
    # temporary directory does not stay on the import path for later tests.
    orig_path = sys.path[:]
    sys.path.insert(0, str(tmpdir))
    try:
        mod = __import__('bound_function_visit_strings')
    finally:
        sys.path = orig_path

    strings = ['a', 'b', 'c']
    visited = []
    mod._visit_strings(strings, visited.append)

    assert visited == strings

    def raise_on_b(s):
        if s == 'b':
            raise ValueError('wtf')

    # An exception raised inside the callback must reach the caller
    with pytest.raises(ValueError, match="wtf"):
        mod._visit_strings(strings, raise_on_b)