]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/python/pyarrow/tests/util.py
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / python / pyarrow / tests / util.py
CommitLineData
1d09f67e
TL
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""
19Utility functions for testing
20"""
21
22import contextlib
23import decimal
24import gc
25import numpy as np
26import os
27import random
28import signal
29import string
30import subprocess
31import sys
32
33import pytest
34
35import pyarrow as pa
36import pyarrow.fs
37
38
def randsign():
    """Pick a sign at random.

    Returns
    -------
    sign : int
        Either -1 or 1.
    """
    signs = (-1, 1)
    return random.choice(signs)
47
48
@contextlib.contextmanager
def random_seed(seed):
    """Context manager that temporarily seeds the ``random`` module.

    Parameters
    ----------
    seed : int
        The seed to set

    Notes
    -----
    The state of the ``random`` module is captured on entry and put back on
    exit, so code running outside the ``with`` block does not observe the
    temporary seed.
    """
    saved_state = random.getstate()
    random.seed(seed)
    with contextlib.ExitStack() as cleanup:
        # Registered callbacks run on exit whether or not the body raised.
        cleanup.callback(random.setstate, saved_state)
        yield
69
70
def randdecimal(precision, scale):
    """Generate a random decimal value with specified precision and scale.

    Parameters
    ----------
    precision : int
        The maximum number of digits to generate. Must be an integer between 1
        and 38 inclusive.
    scale : int
        The maximum number of digits following the decimal point. Must be an
        integer greater than or equal to 0 and no larger than `precision`.

    Returns
    -------
    decimal_value : decimal.Decimal
        A random decimal.Decimal object with the specified precision and scale.

    Raises
    ------
    ValueError
        If `scale` is negative or greater than `precision`.
    """
    assert 1 <= precision <= 38, 'precision must be between 1 and 38 inclusive'
    if scale < 0:
        raise ValueError(
            'randdecimal does not yet support generating decimals with '
            'negative scale'
        )
    if scale > precision:
        # Without this guard 10 ** (precision - scale) is a float and
        # random.randint() fails with an unrelated "non-integer" error.
        raise ValueError(
            'randdecimal does not yet support generating decimals with '
            'a scale greater than the precision'
        )

    # At most (precision - scale) digits go in front of the decimal point.
    max_whole_value = 10 ** (precision - scale) - 1
    whole = random.randint(-max_whole_value, max_whole_value)

    if not scale:
        return decimal.Decimal(whole)

    max_fractional_value = 10 ** scale - 1
    fractional = random.randint(0, max_fractional_value)

    # Left-pad the fraction with zeros so it always has exactly `scale` digits.
    return decimal.Decimal(
        '{}.{}'.format(whole, str(fractional).rjust(scale, '0'))
    )
106
107
def random_ascii(length):
    """Return `length` random bytes, each with a value in [65, 123)."""
    codes = np.random.randint(65, 123, size=length, dtype='i1')
    return bytes(codes)
110
111
def rands(nchars):
    """
    Build a random string of `nchars` ASCII letters and digits.
    """
    alphabet = string.ascii_letters + string.digits
    pool = np.array(list(alphabet), dtype=(np.str_, 1))
    picked = np.random.choice(pool, nchars)
    return "".join(picked)
119
120
def make_dataframe():
    """
    Build a 30-row pandas DataFrame with four columns of random normal
    values ('A' through 'D') and an index of random 10-character strings.
    """
    import pandas as pd

    n_rows = 30
    # Column data is drawn before the index labels.
    columns = {}
    for name in string.ascii_uppercase[:4]:
        columns[name] = np.random.randn(n_rows)
    labels = [rands(10) for _ in range(n_rows)]
    return pd.DataFrame(columns, index=pd.Index(labels))
130
131
def memory_leak_check(f, metric='rss', threshold=1 << 17, iterations=10,
                      check_interval=1):
    """
    Call `f` repeatedly and raise if the process memory use grows past a
    threshold relative to the value measured before the first call.

    Parameters
    ----------
    f : callable
        Function to invoke on each iteration
    metric : {'rss', 'vms', 'shared'}, default 'rss'
        Attribute of psutil.Process.memory_info to use for determining current
        memory use
    threshold : int, default 128K
        Threshold in number of bytes to consider a leak
    iterations : int, default 10
        Total number of invocations of f
    check_interval : int, default 1
        Number of invocations of f in between each memory use check

    Raises
    ------
    Exception
        If memory use exceeds the baseline by more than `threshold` bytes.
    """
    import psutil
    process = psutil.Process()

    def current_usage():
        # Collect garbage first so only still-referenced memory is counted.
        gc.collect()
        return getattr(process.memory_info(), metric)

    baseline = current_usage()

    for iteration in range(iterations):
        f()
        if iteration % check_interval != 0:
            continue
        growth = current_usage() - baseline
        if growth > threshold:
            raise Exception("Memory leak detected. "
                            "Departure from baseline {} after {} iterations"
                            .format(growth, iteration))
173
174
def get_modified_env_with_pythonpath():
    """
    Return a copy of ``os.environ`` whose PYTHONPATH starts with the
    directory containing the imported ``pyarrow`` package.
    """
    env = os.environ.copy()

    package_dir = os.path.dirname(pa.__file__)
    module_path = os.path.abspath(os.path.dirname(package_dir))

    # Keep any pre-existing PYTHONPATH entries after the pyarrow root.
    entries = [module_path]
    existing_pythonpath = env.get('PYTHONPATH', '')
    if existing_pythonpath:
        entries.append(existing_pythonpath)

    env['PYTHONPATH'] = os.pathsep.join(entries)
    return env
189
190
def invoke_script(script_name, *args):
    """
    Run the script `script_name`, located next to this file, with the
    current Python interpreter and the given command-line arguments.

    ``subprocess.check_call`` raises ``CalledProcessError`` if the script
    exits with a non-zero status.
    """
    tests_dir = os.path.dirname(os.path.realpath(__file__))
    command = [sys.executable, os.path.join(tests_dir, script_name), *args]

    subprocess.check_call(command, env=get_modified_env_with_pythonpath())
201
202
@contextlib.contextmanager
def changed_environ(name, value):
    """
    Temporarily set environment variable *name* to *value*.

    On exit the variable is restored to its previous value, or removed if it
    was not set before entering the context.
    """
    orig_value = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if orig_value is None:
            # The body may already have removed the variable; pop() avoids a
            # KeyError that would mask any exception raised by the body.
            os.environ.pop(name, None)
        else:
            os.environ[name] = orig_value
217
218
@contextlib.contextmanager
def change_cwd(path):
    """
    Temporarily make *path* the current working directory.

    The directory that was current on entry is restored on exit.
    """
    previous_dir = os.getcwd()
    target = str(path)
    os.chdir(target)
    try:
        yield
    finally:
        os.chdir(previous_dir)
227
228
@contextlib.contextmanager
def disabled_gc():
    """
    Temporarily disable the cyclic garbage collector.

    On exit the collector is re-enabled only if it was enabled on entry, so
    nesting this context manager (or using it while the collector is already
    off) does not switch the collector on unexpectedly.
    """
    was_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        if was_enabled:
            gc.enable()
236
237
def _filesystem_uri(path):
    """Format *path* as a ``file://`` URI string."""
    # URIs on Windows must follow 'file:///C:...' or 'file:/C:...' patterns.
    template = 'file:///{}' if os.name == 'nt' else 'file://{}'
    return template.format(path)
245
246
class FSProtocolClass:
    """Minimal object implementing ``__fspath__`` around a wrapped path."""

    def __init__(self, path):
        # Stored as given; converted to str only when __fspath__ is called.
        self._path = path

    def __fspath__(self):
        """Return the wrapped path converted with ``str()``."""
        as_text = str(self._path)
        return as_text
253
254
class ProxyHandler(pyarrow.fs.FileSystemHandler):
    """
    A filesystem handler that forwards every operation to a wrapped
    filesystem. Useful as a base for handlers that override only a few
    operations of an existing filesystem.
    """

    def __init__(self, fs):
        # The filesystem every method below delegates to.
        self._fs = fs

    def __eq__(self, other):
        """Equal when `other` is a ProxyHandler wrapping an equal filesystem."""
        if not isinstance(other, ProxyHandler):
            return NotImplemented
        return self._fs == other._fs

    def __ne__(self, other):
        """Unequal when `other` is a ProxyHandler wrapping an unequal one."""
        if not isinstance(other, ProxyHandler):
            return NotImplemented
        return self._fs != other._fs

    def get_type_name(self):
        """Return the wrapped filesystem's type name prefixed by 'proxy::'."""
        return "proxy::" + self._fs.type_name

    def normalize_path(self, path):
        """Delegate to the wrapped filesystem's ``normalize_path``."""
        return self._fs.normalize_path(path)

    def get_file_info(self, paths):
        """Delegate to the wrapped filesystem's ``get_file_info``."""
        return self._fs.get_file_info(paths)

    def get_file_info_selector(self, selector):
        """Pass the selector to the wrapped ``get_file_info``."""
        return self._fs.get_file_info(selector)

    def create_dir(self, path, recursive):
        """Delegate to the wrapped filesystem's ``create_dir``."""
        return self._fs.create_dir(path, recursive=recursive)

    def delete_dir(self, path):
        """Delegate to the wrapped filesystem's ``delete_dir``."""
        return self._fs.delete_dir(path)

    def delete_dir_contents(self, path):
        """Delegate to the wrapped filesystem's ``delete_dir_contents``."""
        return self._fs.delete_dir_contents(path)

    def delete_root_dir_contents(self):
        """Call the wrapped ``delete_dir_contents`` on the root path."""
        return self._fs.delete_dir_contents("", accept_root_dir=True)

    def delete_file(self, path):
        """Delegate to the wrapped filesystem's ``delete_file``."""
        return self._fs.delete_file(path)

    def move(self, src, dest):
        """Delegate to the wrapped filesystem's ``move``."""
        return self._fs.move(src, dest)

    def copy_file(self, src, dest):
        """Delegate to the wrapped filesystem's ``copy_file``."""
        return self._fs.copy_file(src, dest)

    def open_input_stream(self, path):
        """Delegate to the wrapped filesystem's ``open_input_stream``."""
        return self._fs.open_input_stream(path)

    def open_input_file(self, path):
        """Delegate to the wrapped filesystem's ``open_input_file``."""
        return self._fs.open_input_file(path)

    def open_output_stream(self, path, metadata):
        """Delegate to the wrapped filesystem's ``open_output_stream``."""
        return self._fs.open_output_stream(path, metadata=metadata)

    def open_append_stream(self, path, metadata):
        """Delegate to the wrapped filesystem's ``open_append_stream``."""
        return self._fs.open_append_stream(path, metadata=metadata)
318
319
def get_raise_signal():
    """
    Return a callable that raises a signal, by number, in the current process.

    On Python 3.8+ this is ``signal.raise_signal``. On older Pythons an
    ``os.kill``-based emulation is returned on non-Windows platforms, and
    ``pytest.skip`` is called on Windows.
    """
    if sys.version_info >= (3, 8):
        return signal.raise_signal

    if os.name != 'nt':
        # On Unix, emulate raise_signal() with os.kill().
        def raise_signal(signum):
            os.kill(os.getpid(), signum)
        return raise_signal

    # On Windows, os.kill() doesn't actually send a signal,
    # it just terminates the process with the given exit code.
    pytest.skip("test requires Python 3.8+ on Windows")