]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include <memory> | |
19 | #include <sstream> | |
20 | #include <utility> | |
21 | ||
22 | #include "arrow/python/extension_type.h" | |
23 | #include "arrow/python/helpers.h" | |
24 | #include "arrow/python/pyarrow.h" | |
25 | #include "arrow/util/checked_cast.h" | |
26 | #include "arrow/util/logging.h" | |
27 | ||
28 | namespace arrow { | |
29 | ||
30 | using internal::checked_cast; | |
31 | ||
32 | namespace py { | |
33 | ||
34 | namespace { | |
35 | ||
36 | // Serialize a Python ExtensionType instance | |
37 | Status SerializeExtInstance(PyObject* type_instance, std::string* out) { | |
38 | OwnedRef res( | |
39 | cpp_PyObject_CallMethod(type_instance, "__arrow_ext_serialize__", nullptr)); | |
40 | if (!res) { | |
41 | return ConvertPyError(); | |
42 | } | |
43 | if (!PyBytes_Check(res.obj())) { | |
44 | return Status::TypeError( | |
45 | "__arrow_ext_serialize__ should return bytes object, " | |
46 | "got ", | |
47 | internal::PyObject_StdStringRepr(res.obj())); | |
48 | } | |
49 | *out = internal::PyBytes_AsStdString(res.obj()); | |
50 | return Status::OK(); | |
51 | } | |
52 | ||
53 | // Deserialize a Python ExtensionType instance | |
54 | PyObject* DeserializeExtInstance(PyObject* type_class, | |
55 | std::shared_ptr<DataType> storage_type, | |
56 | const std::string& serialized_data) { | |
57 | OwnedRef storage_ref(wrap_data_type(storage_type)); | |
58 | if (!storage_ref) { | |
59 | return nullptr; | |
60 | } | |
61 | OwnedRef data_ref(PyBytes_FromStringAndSize( | |
62 | serialized_data.data(), static_cast<Py_ssize_t>(serialized_data.size()))); | |
63 | if (!data_ref) { | |
64 | return nullptr; | |
65 | } | |
66 | ||
67 | return cpp_PyObject_CallMethod(type_class, "__arrow_ext_deserialize__", "OO", | |
68 | storage_ref.obj(), data_ref.obj()); | |
69 | } | |
70 | ||
71 | } // namespace | |
72 | ||
// Extension name used by PyExtensionType instances that are constructed
// without an explicit name; also returned by PyExtensionName().
// `constexpr` makes the pointer itself immutable (the previous declaration
// was a reassignable pointer to const chars).
static constexpr const char* kExtensionName = "arrow.py_extension_type";
74 | ||
75 | std::string PyExtensionType::ToString() const { | |
76 | PyAcquireGIL lock; | |
77 | ||
78 | std::stringstream ss; | |
79 | OwnedRef instance(GetInstance()); | |
80 | ss << "extension<" << this->extension_name() << "<" << Py_TYPE(instance.obj())->tp_name | |
81 | << ">>"; | |
82 | return ss.str(); | |
83 | } | |
84 | ||
85 | PyExtensionType::PyExtensionType(std::shared_ptr<DataType> storage_type, PyObject* typ, | |
86 | PyObject* inst) | |
87 | : ExtensionType(storage_type), | |
88 | extension_name_(kExtensionName), | |
89 | type_class_(typ), | |
90 | type_instance_(inst) {} | |
91 | ||
92 | PyExtensionType::PyExtensionType(std::shared_ptr<DataType> storage_type, | |
93 | std::string extension_name, PyObject* typ, | |
94 | PyObject* inst) | |
95 | : ExtensionType(storage_type), | |
96 | extension_name_(std::move(extension_name)), | |
97 | type_class_(typ), | |
98 | type_instance_(inst) {} | |
99 | ||
100 | bool PyExtensionType::ExtensionEquals(const ExtensionType& other) const { | |
101 | PyAcquireGIL lock; | |
102 | ||
103 | if (other.extension_name() != extension_name()) { | |
104 | return false; | |
105 | } | |
106 | const auto& other_ext = checked_cast<const PyExtensionType&>(other); | |
107 | int res = -1; | |
108 | if (!type_instance_) { | |
109 | if (other_ext.type_instance_) { | |
110 | return false; | |
111 | } | |
112 | // Compare Python types | |
113 | res = PyObject_RichCompareBool(type_class_.obj(), other_ext.type_class_.obj(), Py_EQ); | |
114 | } else { | |
115 | if (!other_ext.type_instance_) { | |
116 | return false; | |
117 | } | |
118 | // Compare Python instances | |
119 | OwnedRef left(GetInstance()); | |
120 | OwnedRef right(other_ext.GetInstance()); | |
121 | if (!left || !right) { | |
122 | goto error; | |
123 | } | |
124 | res = PyObject_RichCompareBool(left.obj(), right.obj(), Py_EQ); | |
125 | } | |
126 | if (res == -1) { | |
127 | goto error; | |
128 | } | |
129 | return res == 1; | |
130 | ||
131 | error: | |
132 | // Cannot propagate error | |
133 | PyErr_WriteUnraisable(nullptr); | |
134 | return false; | |
135 | } | |
136 | ||
137 | std::shared_ptr<Array> PyExtensionType::MakeArray(std::shared_ptr<ArrayData> data) const { | |
138 | DCHECK_EQ(data->type->id(), Type::EXTENSION); | |
139 | return std::make_shared<ExtensionArray>(data); | |
140 | } | |
141 | ||
142 | std::string PyExtensionType::Serialize() const { | |
143 | DCHECK(type_instance_); | |
144 | return serialized_; | |
145 | } | |
146 | ||
147 | Result<std::shared_ptr<DataType>> PyExtensionType::Deserialize( | |
148 | std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const { | |
149 | PyAcquireGIL lock; | |
150 | ||
151 | if (import_pyarrow()) { | |
152 | return ConvertPyError(); | |
153 | } | |
154 | OwnedRef res(DeserializeExtInstance(type_class_.obj(), storage_type, serialized_data)); | |
155 | if (!res) { | |
156 | return ConvertPyError(); | |
157 | } | |
158 | return unwrap_data_type(res.obj()); | |
159 | } | |
160 | ||
161 | PyObject* PyExtensionType::GetInstance() const { | |
162 | if (!type_instance_) { | |
163 | PyErr_SetString(PyExc_TypeError, "Not an instance"); | |
164 | return nullptr; | |
165 | } | |
166 | DCHECK(PyWeakref_CheckRef(type_instance_.obj())); | |
167 | PyObject* inst = PyWeakref_GET_OBJECT(type_instance_.obj()); | |
168 | if (inst != Py_None) { | |
169 | // Cached instance still alive | |
170 | Py_INCREF(inst); | |
171 | return inst; | |
172 | } else { | |
173 | // Must reconstruct from serialized form | |
174 | // XXX cache again? | |
175 | return DeserializeExtInstance(type_class_.obj(), storage_type_, serialized_); | |
176 | } | |
177 | } | |
178 | ||
179 | Status PyExtensionType::SetInstance(PyObject* inst) const { | |
180 | // Check we have the right type | |
181 | PyObject* typ = reinterpret_cast<PyObject*>(Py_TYPE(inst)); | |
182 | if (typ != type_class_.obj()) { | |
183 | return Status::TypeError("Unexpected Python ExtensionType class ", | |
184 | internal::PyObject_StdStringRepr(typ), " expected ", | |
185 | internal::PyObject_StdStringRepr(type_class_.obj())); | |
186 | } | |
187 | ||
188 | PyObject* wr = PyWeakref_NewRef(inst, nullptr); | |
189 | if (wr == NULL) { | |
190 | return ConvertPyError(); | |
191 | } | |
192 | type_instance_.reset(wr); | |
193 | return SerializeExtInstance(inst, &serialized_); | |
194 | } | |
195 | ||
196 | Status PyExtensionType::FromClass(const std::shared_ptr<DataType> storage_type, | |
197 | const std::string extension_name, PyObject* typ, | |
198 | std::shared_ptr<ExtensionType>* out) { | |
199 | Py_INCREF(typ); | |
200 | out->reset(new PyExtensionType(storage_type, std::move(extension_name), typ)); | |
201 | return Status::OK(); | |
202 | } | |
203 | ||
204 | Status RegisterPyExtensionType(const std::shared_ptr<DataType>& type) { | |
205 | DCHECK_EQ(type->id(), Type::EXTENSION); | |
206 | auto ext_type = std::dynamic_pointer_cast<ExtensionType>(type); | |
207 | return RegisterExtensionType(ext_type); | |
208 | } | |
209 | ||
210 | Status UnregisterPyExtensionType(const std::string& type_name) { | |
211 | return UnregisterExtensionType(type_name); | |
212 | } | |
213 | ||
214 | std::string PyExtensionName() { return kExtensionName; } | |
215 | ||
216 | } // namespace py | |
217 | } // namespace arrow |