]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/cpp/src/arrow/python/extension_type.cc
bump version to 18.2.2-pve1
[ceph.git] / ceph / src / arrow / cpp / src / arrow / python / extension_type.cc
CommitLineData
1d09f67e
TL
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#include <memory>
19#include <sstream>
20#include <utility>
21
22#include "arrow/python/extension_type.h"
23#include "arrow/python/helpers.h"
24#include "arrow/python/pyarrow.h"
25#include "arrow/util/checked_cast.h"
26#include "arrow/util/logging.h"
27
28namespace arrow {
29
30using internal::checked_cast;
31
32namespace py {
33
34namespace {
35
36// Serialize a Python ExtensionType instance
37Status SerializeExtInstance(PyObject* type_instance, std::string* out) {
38 OwnedRef res(
39 cpp_PyObject_CallMethod(type_instance, "__arrow_ext_serialize__", nullptr));
40 if (!res) {
41 return ConvertPyError();
42 }
43 if (!PyBytes_Check(res.obj())) {
44 return Status::TypeError(
45 "__arrow_ext_serialize__ should return bytes object, "
46 "got ",
47 internal::PyObject_StdStringRepr(res.obj()));
48 }
49 *out = internal::PyBytes_AsStdString(res.obj());
50 return Status::OK();
51}
52
53// Deserialize a Python ExtensionType instance
54PyObject* DeserializeExtInstance(PyObject* type_class,
55 std::shared_ptr<DataType> storage_type,
56 const std::string& serialized_data) {
57 OwnedRef storage_ref(wrap_data_type(storage_type));
58 if (!storage_ref) {
59 return nullptr;
60 }
61 OwnedRef data_ref(PyBytes_FromStringAndSize(
62 serialized_data.data(), static_cast<Py_ssize_t>(serialized_data.size())));
63 if (!data_ref) {
64 return nullptr;
65 }
66
67 return cpp_PyObject_CallMethod(type_class, "__arrow_ext_deserialize__", "OO",
68 storage_ref.obj(), data_ref.obj());
69}
70
71} // namespace
72
// Extension name used by PyExtensionType when no explicit name is given.
// Declared as a constexpr array (not a `const char*`) so the name itself is
// immutable and cannot be rebound at runtime.
static constexpr char kExtensionName[] = "arrow.py_extension_type";
74
75std::string PyExtensionType::ToString() const {
76 PyAcquireGIL lock;
77
78 std::stringstream ss;
79 OwnedRef instance(GetInstance());
80 ss << "extension<" << this->extension_name() << "<" << Py_TYPE(instance.obj())->tp_name
81 << ">>";
82 return ss.str();
83}
84
85PyExtensionType::PyExtensionType(std::shared_ptr<DataType> storage_type, PyObject* typ,
86 PyObject* inst)
87 : ExtensionType(storage_type),
88 extension_name_(kExtensionName),
89 type_class_(typ),
90 type_instance_(inst) {}
91
92PyExtensionType::PyExtensionType(std::shared_ptr<DataType> storage_type,
93 std::string extension_name, PyObject* typ,
94 PyObject* inst)
95 : ExtensionType(storage_type),
96 extension_name_(std::move(extension_name)),
97 type_class_(typ),
98 type_instance_(inst) {}
99
100bool PyExtensionType::ExtensionEquals(const ExtensionType& other) const {
101 PyAcquireGIL lock;
102
103 if (other.extension_name() != extension_name()) {
104 return false;
105 }
106 const auto& other_ext = checked_cast<const PyExtensionType&>(other);
107 int res = -1;
108 if (!type_instance_) {
109 if (other_ext.type_instance_) {
110 return false;
111 }
112 // Compare Python types
113 res = PyObject_RichCompareBool(type_class_.obj(), other_ext.type_class_.obj(), Py_EQ);
114 } else {
115 if (!other_ext.type_instance_) {
116 return false;
117 }
118 // Compare Python instances
119 OwnedRef left(GetInstance());
120 OwnedRef right(other_ext.GetInstance());
121 if (!left || !right) {
122 goto error;
123 }
124 res = PyObject_RichCompareBool(left.obj(), right.obj(), Py_EQ);
125 }
126 if (res == -1) {
127 goto error;
128 }
129 return res == 1;
130
131error:
132 // Cannot propagate error
133 PyErr_WriteUnraisable(nullptr);
134 return false;
135}
136
137std::shared_ptr<Array> PyExtensionType::MakeArray(std::shared_ptr<ArrayData> data) const {
138 DCHECK_EQ(data->type->id(), Type::EXTENSION);
139 return std::make_shared<ExtensionArray>(data);
140}
141
142std::string PyExtensionType::Serialize() const {
143 DCHECK(type_instance_);
144 return serialized_;
145}
146
147Result<std::shared_ptr<DataType>> PyExtensionType::Deserialize(
148 std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const {
149 PyAcquireGIL lock;
150
151 if (import_pyarrow()) {
152 return ConvertPyError();
153 }
154 OwnedRef res(DeserializeExtInstance(type_class_.obj(), storage_type, serialized_data));
155 if (!res) {
156 return ConvertPyError();
157 }
158 return unwrap_data_type(res.obj());
159}
160
161PyObject* PyExtensionType::GetInstance() const {
162 if (!type_instance_) {
163 PyErr_SetString(PyExc_TypeError, "Not an instance");
164 return nullptr;
165 }
166 DCHECK(PyWeakref_CheckRef(type_instance_.obj()));
167 PyObject* inst = PyWeakref_GET_OBJECT(type_instance_.obj());
168 if (inst != Py_None) {
169 // Cached instance still alive
170 Py_INCREF(inst);
171 return inst;
172 } else {
173 // Must reconstruct from serialized form
174 // XXX cache again?
175 return DeserializeExtInstance(type_class_.obj(), storage_type_, serialized_);
176 }
177}
178
179Status PyExtensionType::SetInstance(PyObject* inst) const {
180 // Check we have the right type
181 PyObject* typ = reinterpret_cast<PyObject*>(Py_TYPE(inst));
182 if (typ != type_class_.obj()) {
183 return Status::TypeError("Unexpected Python ExtensionType class ",
184 internal::PyObject_StdStringRepr(typ), " expected ",
185 internal::PyObject_StdStringRepr(type_class_.obj()));
186 }
187
188 PyObject* wr = PyWeakref_NewRef(inst, nullptr);
189 if (wr == NULL) {
190 return ConvertPyError();
191 }
192 type_instance_.reset(wr);
193 return SerializeExtInstance(inst, &serialized_);
194}
195
196Status PyExtensionType::FromClass(const std::shared_ptr<DataType> storage_type,
197 const std::string extension_name, PyObject* typ,
198 std::shared_ptr<ExtensionType>* out) {
199 Py_INCREF(typ);
200 out->reset(new PyExtensionType(storage_type, std::move(extension_name), typ));
201 return Status::OK();
202}
203
204Status RegisterPyExtensionType(const std::shared_ptr<DataType>& type) {
205 DCHECK_EQ(type->id(), Type::EXTENSION);
206 auto ext_type = std::dynamic_pointer_cast<ExtensionType>(type);
207 return RegisterExtensionType(ext_type);
208}
209
210Status UnregisterPyExtensionType(const std::string& type_name) {
211 return UnregisterExtensionType(type_name);
212}
213
214std::string PyExtensionName() { return kExtensionName; }
215
216} // namespace py
217} // namespace arrow