]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | # Licensed to the Apache Software Foundation (ASF) under one |
2 | # or more contributor license agreements. See the NOTICE file | |
3 | # distributed with this work for additional information | |
4 | # regarding copyright ownership. The ASF licenses this file | |
5 | # to you under the Apache License, Version 2.0 (the | |
6 | # "License"); you may not use this file except in compliance | |
7 | # with the License. You may obtain a copy of the License at | |
8 | # | |
9 | # http://www.apache.org/licenses/LICENSE-2.0 | |
10 | # | |
11 | # Unless required by applicable law or agreed to in writing, | |
12 | # software distributed under the License is distributed on an | |
13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | # KIND, either express or implied. See the License for the | |
15 | # specific language governing permissions and limitations | |
16 | # under the License. | |
17 | ||
18 | """ | |
19 | Utility functions for testing | |
20 | """ | |
21 | ||
22 | import contextlib | |
23 | import decimal | |
24 | import gc | |
25 | import numpy as np | |
26 | import os | |
27 | import random | |
28 | import signal | |
29 | import string | |
30 | import subprocess | |
31 | import sys | |
32 | ||
33 | import pytest | |
34 | ||
35 | import pyarrow as pa | |
36 | import pyarrow.fs | |
37 | ||
38 | ||
def randsign():
    """Pick a sign at random.

    Returns
    -------
    sign : int
        Either -1 or 1, drawn with ``random.choice``.
    """
    signs = (-1, 1)
    return random.choice(signs)
47 | ||
48 | ||
@contextlib.contextmanager
def random_seed(seed):
    """Seed the ``random`` module for the duration of a ``with`` block.

    Parameters
    ----------
    seed : int
        The seed passed to ``random.seed``.

    Notes
    -----
    The state of the ``random`` module is captured before seeding and put
    back when the block exits (normally or through an exception), so code
    outside the block does not observe the reseeding.
    """
    with contextlib.ExitStack() as cleanup:
        # Capture the current generator state first, then reseed; the
        # registered callback restores that state on any kind of exit.
        cleanup.callback(random.setstate, random.getstate())
        random.seed(seed)
        yield
69 | ||
70 | ||
def randdecimal(precision, scale):
    """Generate a random decimal value with specified precision and scale.

    Parameters
    ----------
    precision : int
        The maximum number of digits to generate. Must be an integer between 1
        and 38 inclusive.
    scale : int
        The maximum number of digits following the decimal point. Must be an
        integer greater than or equal to 0 and no larger than `precision`.

    Returns
    -------
    decimal_value : decimal.Decimal
        A random decimal.Decimal object with the specified precision and scale.

    Raises
    ------
    ValueError
        If `scale` is negative or greater than `precision`.
    """
    assert 1 <= precision <= 38, 'precision must be between 1 and 38 inclusive'
    if scale < 0:
        raise ValueError(
            'randdecimal does not yet support generating decimals with '
            'negative scale'
        )
    if scale > precision:
        # Otherwise 10 ** (precision - scale) becomes a float and
        # random.randint() fails with an unrelated-looking error.
        raise ValueError(
            'scale ({}) must not be greater than precision ({})'
            .format(scale, precision)
        )

    # Integer part: at most (precision - scale) digits, either sign.
    max_whole_value = 10 ** (precision - scale) - 1
    whole = random.randint(-max_whole_value, max_whole_value)

    if not scale:
        return decimal.Decimal(whole)

    # Fractional part: exactly `scale` digits once left-padded with zeros.
    max_fractional_value = 10 ** scale - 1
    fractional = random.randint(0, max_fractional_value)

    return decimal.Decimal(
        '{}.{}'.format(whole, str(fractional).rjust(scale, '0'))
    )
106 | ||
107 | ||
def random_ascii(length):
    """Return `length` random bytes with values in the range [65, 123).

    The values are drawn with ``np.random.randint`` as signed 8-bit
    integers, so the result depends on NumPy's global random state.
    """
    codes = np.random.randint(65, 123, size=length, dtype='i1')
    return bytes(codes)
110 | ||
111 | ||
def rands(nchars):
    """
    Build one random string of `nchars` ASCII letters and digits.

    Characters are sampled with ``np.random.choice``, so the result
    depends on NumPy's global random state.
    """
    alphabet = string.ascii_letters + string.digits
    pool = np.array(list(alphabet), dtype=(np.str_, 1))
    picked = np.random.choice(pool, nchars)
    return "".join(picked)
119 | ||
120 | ||
def make_dataframe():
    """
    Build a 30-row pandas DataFrame of random floats.

    The frame has four columns named 'A' to 'D' filled from
    ``np.random.randn`` and an index of random 10-character strings
    produced by ``rands``.
    """
    import pandas as pd

    n_rows = 30
    # Columns are generated before the index labels so the sequence of
    # draws from NumPy's global random state stays the same.
    columns = {}
    for label in string.ascii_uppercase[:4]:
        columns[label] = np.random.randn(n_rows)
    row_labels = [rands(10) for _ in range(n_rows)]
    return pd.DataFrame(columns, index=pd.Index(row_labels))
130 | ||
131 | ||
def memory_leak_check(f, metric='rss', threshold=1 << 17, iterations=10,
                      check_interval=1):
    """
    Call `f` repeatedly and raise if the process memory use grows past
    `threshold`, which would point at a leak inside Arrow or at a reference
    counting problem in the Python bindings.

    Parameters
    ----------
    f : callable
        Function to invoke on each iteration
    metric : {'rss', 'vms', 'shared'}, default 'rss'
        Attribute of psutil.Process.memory_info to use for determining current
        memory use
    threshold : int, default 128K
        Threshold in number of bytes to consider a leak
    iterations : int, default 10
        Total number of invocations of f
    check_interval : int, default 1
        Number of invocations of f in between each memory use check

    Raises
    ------
    Exception
        If memory use exceeds the baseline by more than `threshold` bytes
        at one of the checks.
    """
    import psutil
    process = psutil.Process()

    def measure():
        # Collect garbage first so unreachable objects are not counted.
        gc.collect()
        return getattr(process.memory_info(), metric)

    baseline = measure()

    for iteration in range(iterations):
        f()
        if iteration % check_interval != 0:
            continue
        growth = measure() - baseline
        if growth > threshold:
            raise Exception("Memory leak detected. "
                            "Departure from baseline {} after {} iterations"
                            .format(growth, iteration))
173 | ||
174 | ||
def get_modified_env_with_pythonpath():
    """
    Return a copy of ``os.environ`` with the directory containing the
    pyarrow package prepended to PYTHONPATH.

    ``os.environ`` itself is not modified.
    """
    package_dir = os.path.dirname(pa.__file__)
    module_path = os.path.abspath(os.path.dirname(package_dir))

    env = os.environ.copy()
    path_entries = [module_path]
    current = env.get('PYTHONPATH', '')
    if current:
        # Keep any existing entries after the pyarrow root directory.
        path_entries.append(current)
    env['PYTHONPATH'] = os.pathsep.join(path_entries)
    return env
189 | ||
190 | ||
def invoke_script(script_name, *args):
    """
    Run the script `script_name`, located next to this file, in a child
    Python interpreter with `args` as its command-line arguments.

    The child gets the environment built by
    ``get_modified_env_with_pythonpath``. ``subprocess.check_call`` raises
    ``CalledProcessError`` if the script exits with a non-zero status.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    command = [sys.executable, os.path.join(here, script_name), *args]
    subprocess.check_call(command, env=get_modified_env_with_pythonpath())
201 | ||
202 | ||
@contextlib.contextmanager
def changed_environ(name, value):
    """
    Temporarily set environment variable *name* to *value*.

    On exit the variable is restored to the value it had on entry, or
    removed again if it was not set on entry. This also holds when the
    body raises or when the body itself deletes the variable.
    """
    orig_value = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if orig_value is None:
            # pop() instead of del: the body may already have removed the
            # variable, and cleanup must not raise KeyError in that case.
            os.environ.pop(name, None)
        else:
            os.environ[name] = orig_value
217 | ||
218 | ||
@contextlib.contextmanager
def change_cwd(path):
    """
    Temporarily make `path` the current working directory.

    `path` is converted with ``str()`` before being handed to ``os.chdir``.
    The directory that was current on entry is restored on exit, including
    when the body raises.
    """
    previous = os.getcwd()
    target = str(path)
    os.chdir(target)
    try:
        yield
    finally:
        os.chdir(previous)
227 | ||
228 | ||
@contextlib.contextmanager
def disabled_gc():
    """
    Temporarily disable the cyclic garbage collector.

    On exit the collector is put back into the state it had on entry: it is
    only re-enabled if it was enabled before. This keeps nested uses, and
    uses while collection is already disabled, from switching it back on
    prematurely.
    """
    was_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        if was_enabled:
            gc.enable()
236 | ||
237 | ||
def _filesystem_uri(path):
    """
    Format `path` as a ``file://`` URI string.

    The path is inserted as-is; no escaping or separator conversion is done.
    """
    # URIs on Windows must follow 'file:///C:...' or 'file:/C:...' patterns.
    template = 'file:///{}' if os.name == 'nt' else 'file://{}'
    return template.format(path)
245 | ||
246 | ||
class FSProtocolClass:
    """Minimal path-like object that implements only ``__fspath__``."""

    def __init__(self, path):
        # Stored unchanged; the conversion to str happens in __fspath__.
        self._path = path

    def __fspath__(self):
        """Return the wrapped path converted with ``str()``."""
        wrapped = self._path
        return str(wrapped)
253 | ||
254 | ||
class ProxyHandler(pyarrow.fs.FileSystemHandler):
    """
    A filesystem handler that proxies to an underlying filesystem. Useful
    to partially wrap an existing filesystem with partial changes.

    Every method forwards its arguments to the method of the wrapped
    filesystem named in its body and returns that method's result.
    """

    def __init__(self, fs):
        # The wrapped filesystem that all calls are delegated to.
        self._fs = fs

    def __eq__(self, other):
        """Two proxy handlers are equal when their wrapped filesystems are."""
        if isinstance(other, ProxyHandler):
            return self._fs == other._fs
        return NotImplemented

    def __ne__(self, other):
        """Two proxy handlers differ when their wrapped filesystems do."""
        if isinstance(other, ProxyHandler):
            return self._fs != other._fs
        return NotImplemented

    def get_type_name(self):
        """Return the wrapped filesystem's type name prefixed by "proxy::"."""
        return "proxy::" + self._fs.type_name

    def normalize_path(self, path):
        """Delegate to ``normalize_path`` of the wrapped filesystem."""
        return self._fs.normalize_path(path)

    def get_file_info(self, paths):
        """Delegate to ``get_file_info`` of the wrapped filesystem."""
        return self._fs.get_file_info(paths)

    def get_file_info_selector(self, selector):
        """Delegate to ``get_file_info``, passing the selector through."""
        return self._fs.get_file_info(selector)

    def create_dir(self, path, recursive):
        """Delegate to ``create_dir``, passing `recursive` by keyword."""
        return self._fs.create_dir(path, recursive=recursive)

    def delete_dir(self, path):
        """Delegate to ``delete_dir`` of the wrapped filesystem."""
        return self._fs.delete_dir(path)

    def delete_dir_contents(self, path):
        """Delegate to ``delete_dir_contents`` of the wrapped filesystem."""
        return self._fs.delete_dir_contents(path)

    def delete_root_dir_contents(self):
        """Call ``delete_dir_contents`` for "" with ``accept_root_dir=True``."""
        return self._fs.delete_dir_contents("", accept_root_dir=True)

    def delete_file(self, path):
        """Delegate to ``delete_file`` of the wrapped filesystem."""
        return self._fs.delete_file(path)

    def move(self, src, dest):
        """Delegate to ``move`` of the wrapped filesystem."""
        return self._fs.move(src, dest)

    def copy_file(self, src, dest):
        """Delegate to ``copy_file`` of the wrapped filesystem."""
        return self._fs.copy_file(src, dest)

    def open_input_stream(self, path):
        """Delegate to ``open_input_stream`` of the wrapped filesystem."""
        return self._fs.open_input_stream(path)

    def open_input_file(self, path):
        """Delegate to ``open_input_file`` of the wrapped filesystem."""
        return self._fs.open_input_file(path)

    def open_output_stream(self, path, metadata):
        """Delegate to ``open_output_stream``, passing `metadata` by keyword."""
        return self._fs.open_output_stream(path, metadata=metadata)

    def open_append_stream(self, path, metadata):
        """Delegate to ``open_append_stream``, passing `metadata` by keyword."""
        return self._fs.open_append_stream(path, metadata=metadata)
318 | ||
319 | ||
def get_raise_signal():
    """
    Return a callable that raises a signal, given its number, in the
    current process.

    On Python 3.8+ this is ``signal.raise_signal``. On older versions a
    fallback based on ``os.kill`` is returned on non-Windows platforms, and
    the calling test is skipped on Windows.
    """
    if sys.version_info >= (3, 8):
        return signal.raise_signal

    if os.name == 'nt':
        # On Windows, os.kill() doesn't actually send a signal,
        # it just terminates the process with the given exit code.
        pytest.skip("test requires Python 3.8+ on Windows")

    # On Unix, emulate raise_signal() with os.kill().
    def raise_signal(signum):
        os.kill(os.getpid(), signum)
    return raise_signal