]>
Commit | Line | Data |
---|---|---|
f67539c2 TL |
1 | # |
2 | # Licensed to the Apache Software Foundation (ASF) under one | |
3 | # or more contributor license agreements. See the NOTICE file | |
4 | # distributed with this work for additional information | |
5 | # regarding copyright ownership. The ASF licenses this file | |
6 | # to you under the Apache License, Version 2.0 (the | |
7 | # "License"); you may not use this file except in compliance | |
8 | # with the License. You may obtain a copy of the License at | |
9 | # | |
10 | # http://www.apache.org/licenses/LICENSE-2.0 | |
11 | # | |
12 | # Unless required by applicable law or agreed to in writing, | |
13 | # software distributed under the License is distributed on an | |
14 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
15 | # KIND, either express or implied. See the License for the | |
16 | # specific language governing permissions and limitations | |
17 | # under the License. | |
18 | # | |
19 | ||
20 | from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory | |
21 | from struct import pack, unpack | |
22 | ||
23 | ||
class TBinaryProtocol(TProtocolBase):
    """Binary implementation of the Thrift protocol driver.

    Every scalar is packed in network (big-endian) byte order with ``struct``.
    Binaries/strings and containers carry a signed 32-bit length prefix; on
    read those lengths are validated against the optional
    ``string_length_limit`` / ``container_length_limit`` given at construction.
    """

    # NastyHaxx: Python 2.4+ on 32-bit machines turns hex literals with the top
    # bit set into positive longs. Hardcoding the equivalent negative int values
    # keeps the constants in signed 32-bit range, so ``VERSION_1 | type`` can be
    # packed by writeI32 ("!i").

    # VERSION_MASK = 0xffff0000
    VERSION_MASK = -65536

    # VERSION_1 = 0x80010000
    VERSION_1 = -2147418112

    # The low byte of a strict message header holds the message type.
    TYPE_MASK = 0x000000ff

    def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
        """Create a binary protocol bound to ``trans``.

        :param trans: transport used for ``write`` and ``readAll`` calls.
        :param strictRead: if true, ``readMessageBegin`` rejects messages that
            lack the versioned header.
        :param strictWrite: if true, ``writeMessageBegin`` emits the versioned
            header; otherwise the legacy (name-first) layout is written.
        :param kwargs: only ``string_length_limit`` and
            ``container_length_limit`` are read (default ``None``); any other
            keyword is silently ignored.
        """
        TProtocolBase.__init__(self, trans)
        self.strictRead = strictRead
        self.strictWrite = strictWrite
        self.string_length_limit = kwargs.get('string_length_limit', None)
        self.container_length_limit = kwargs.get('container_length_limit', None)

    def _check_string_length(self, length):
        # Validation itself lives in TProtocolBase._check_length (not shown
        # here); presumably a None limit disables the size check — confirm.
        self._check_length(self.string_length_limit, length)

    def _check_container_length(self, length):
        self._check_length(self.container_length_limit, length)

    def writeMessageBegin(self, name, type, seqid):
        """Write a message header in strict or legacy layout (see strictWrite)."""
        if self.strictWrite:
            # Version and message type share one i32.
            self.writeI32(TBinaryProtocol.VERSION_1 | type)
            self.writeString(name)
            self.writeI32(seqid)
        else:
            self.writeString(name)
            self.writeByte(type)
            self.writeI32(seqid)

    # The binary encoding writes nothing for message/struct/field/container
    # end markers, nor for struct begin.
    def writeMessageEnd(self):
        pass

    def writeStructBegin(self, name):
        pass

    def writeStructEnd(self):
        pass

    def writeFieldBegin(self, name, type, id):
        """Write a field header: type byte then i16 id (the name is not sent)."""
        self.writeByte(type)
        self.writeI16(id)

    def writeFieldEnd(self):
        pass

    def writeFieldStop(self):
        """Terminate a struct's field list with a single STOP byte."""
        self.writeByte(TType.STOP)

    def writeMapBegin(self, ktype, vtype, size):
        """Write key type byte, value type byte and i32 element count."""
        self.writeByte(ktype)
        self.writeByte(vtype)
        self.writeI32(size)

    def writeMapEnd(self):
        pass

    def writeListBegin(self, etype, size):
        """Write element type byte and i32 element count."""
        self.writeByte(etype)
        self.writeI32(size)

    def writeListEnd(self):
        pass

    def writeSetBegin(self, etype, size):
        """Write element type byte and i32 element count."""
        self.writeByte(etype)
        self.writeI32(size)

    def writeSetEnd(self):
        pass

    def writeBool(self, bool):
        """Write a boolean as a single byte: 1 if truthy, else 0."""
        if bool:
            self.writeByte(1)
        else:
            self.writeByte(0)

    def writeByte(self, byte):
        """Write a signed 8-bit integer."""
        buff = pack("!b", byte)
        self.trans.write(buff)

    def writeI16(self, i16):
        """Write a big-endian signed 16-bit integer."""
        buff = pack("!h", i16)
        self.trans.write(buff)

    def writeI32(self, i32):
        """Write a big-endian signed 32-bit integer."""
        buff = pack("!i", i32)
        self.trans.write(buff)

    def writeI64(self, i64):
        """Write a big-endian signed 64-bit integer."""
        buff = pack("!q", i64)
        self.trans.write(buff)

    def writeDouble(self, dub):
        """Write a big-endian IEEE-754 double."""
        buff = pack("!d", dub)
        self.trans.write(buff)

    def writeBinary(self, str):
        """Write an i32 length prefix followed by the raw bytes."""
        self.writeI32(len(str))
        self.trans.write(str)

    def readMessageBegin(self):
        """Read a message header and return ``(name, type, seqid)``.

        Two layouts are accepted:

        * strict: a negative i32 whose upper 16 bits must equal VERSION_1's
          and whose low byte is the message type, then the name (via
          readString) and the i32 sequence id;
        * legacy: a non-negative i32 giving the name length, the raw name
          bytes, a type byte and the i32 sequence id. Rejected when
          ``strictRead`` is set.

        :raises TProtocolException: BAD_VERSION for an unknown version, or a
            missing version header while ``strictRead`` is set.
        """
        sz = self.readI32()
        if sz < 0:
            # Strict header: negative because VERSION_1 sets bit 31.
            version = sz & TBinaryProtocol.VERSION_MASK
            if version != TBinaryProtocol.VERSION_1:
                raise TProtocolException(
                    type=TProtocolException.BAD_VERSION,
                    message='Bad version in readMessageBegin: %d' % (sz))
            type = sz & TBinaryProtocol.TYPE_MASK
            name = self.readString()
            seqid = self.readI32()
        else:
            if self.strictRead:
                raise TProtocolException(type=TProtocolException.BAD_VERSION,
                                         message='No protocol version header')
            # Legacy header: ``sz`` is the length of the name that follows.
            # Apply the same string limit readBinary() enforces, so a peer
            # cannot force an unbounded read through this path.
            self._check_string_length(sz)
            # NOTE(review): the name is returned as raw transport bytes here,
            # while the strict path goes through readString(), which may
            # decode it — confirm callers handle both.
            name = self.trans.readAll(sz)
            type = self.readByte()
            seqid = self.readI32()
        return (name, type, seqid)

    def readMessageEnd(self):
        pass

    def readStructBegin(self):
        pass

    def readStructEnd(self):
        pass

    def readFieldBegin(self):
        """Return ``(None, type, id)``; no id is read after a STOP byte (id 0)."""
        type = self.readByte()
        if type == TType.STOP:
            return (None, type, 0)
        id = self.readI16()
        return (None, type, id)

    def readFieldEnd(self):
        pass

    def readMapBegin(self):
        """Return ``(ktype, vtype, size)`` after validating ``size``."""
        ktype = self.readByte()
        vtype = self.readByte()
        size = self.readI32()
        self._check_container_length(size)
        return (ktype, vtype, size)

    def readMapEnd(self):
        pass

    def readListBegin(self):
        """Return ``(etype, size)`` after validating ``size``."""
        etype = self.readByte()
        size = self.readI32()
        self._check_container_length(size)
        return (etype, size)

    def readListEnd(self):
        pass

    def readSetBegin(self):
        """Return ``(etype, size)`` after validating ``size``."""
        etype = self.readByte()
        size = self.readI32()
        self._check_container_length(size)
        return (etype, size)

    def readSetEnd(self):
        pass

    def readBool(self):
        """Read one byte; any non-zero value is True."""
        byte = self.readByte()
        if byte == 0:
            return False
        return True

    def readByte(self):
        """Read a signed 8-bit integer."""
        buff = self.trans.readAll(1)
        val, = unpack('!b', buff)
        return val

    def readI16(self):
        """Read a big-endian signed 16-bit integer."""
        buff = self.trans.readAll(2)
        val, = unpack('!h', buff)
        return val

    def readI32(self):
        """Read a big-endian signed 32-bit integer."""
        buff = self.trans.readAll(4)
        val, = unpack('!i', buff)
        return val

    def readI64(self):
        """Read a big-endian signed 64-bit integer."""
        buff = self.trans.readAll(8)
        val, = unpack('!q', buff)
        return val

    def readDouble(self):
        """Read a big-endian IEEE-754 double."""
        buff = self.trans.readAll(8)
        val, = unpack('!d', buff)
        return val

    def readBinary(self):
        """Read an i32 length prefix, validate it, then read that many bytes."""
        size = self.readI32()
        self._check_string_length(size)
        s = self.trans.readAll(size)
        return s
236 | ||
237 | ||
class TBinaryProtocolFactory(TProtocolFactory):
    """Factory that builds :class:`TBinaryProtocol` instances.

    The strictness flags and the optional length limits captured here are
    handed to every protocol produced by :meth:`getProtocol`. Only the
    ``string_length_limit`` and ``container_length_limit`` keywords are
    read; any other keyword argument is ignored.
    """

    def __init__(self, strictRead=False, strictWrite=True, **kwargs):
        self.strictRead, self.strictWrite = strictRead, strictWrite
        # dict.get already yields None when a limit was not supplied.
        self.string_length_limit = kwargs.get('string_length_limit')
        self.container_length_limit = kwargs.get('container_length_limit')

    def getProtocol(self, trans):
        """Return a fresh TBinaryProtocol wrapping ``trans``."""
        limits = dict(string_length_limit=self.string_length_limit,
                      container_length_limit=self.container_length_limit)
        return TBinaryProtocol(trans, self.strictRead, self.strictWrite,
                               **limits)
250 | ||
251 | ||
class TBinaryProtocolAccelerated(TBinaryProtocol):
    """C-Accelerated version of TBinaryProtocol.

    Apart from ``__init__``, this class overrides none of TBinaryProtocol's
    methods. The generated code recognizes it directly and calls into the
    ``fastbinary`` C module to do the encoding, bypassing this object's
    Python methods entirely. Inheriting from TBinaryProtocol lets the normal
    pure-Python encoding be used if the fastbinary module is unavailable.
    (TODO(dreiss): Make this happen sanely in more cases.)
    Passing ``fallback=False`` to the constructor disables that fallback:
    a failed fastbinary import then raises ImportError.

    In order to take advantage of the C module, just use
    TBinaryProtocolAccelerated instead of TBinaryProtocol.

    NOTE: This code was contributed by an external developer.
          The internal Thrift team has reviewed and tested it,
          but we cannot guarantee that it is production-ready.
          Please feel free to report bugs and/or success stories
          to the public mailing list.
    """

    def __init__(self, *args, **kwargs):
        """Create the protocol and, if possible, attach the C codec hooks.

        All arguments except ``fallback`` are forwarded unchanged to
        TBinaryProtocol.__init__.

        :param fallback: keyword only (popped from ``kwargs``, default True).
            When true, a failed ``fastbinary`` import is tolerated and the
            fast hooks are simply not set here.
        :raises ImportError: if ``fastbinary`` cannot be imported and
            ``fallback`` is false.
        """
        # Pop before delegating: TBinaryProtocol.__init__ does not know it.
        fallback = kwargs.pop('fallback', True)
        super(TBinaryProtocolAccelerated, self).__init__(*args, **kwargs)
        try:
            from thrift.protocol import fastbinary
        except ImportError:
            if not fallback:
                raise
        else:
            # Presumably the hooks consulted by generated code to bypass the
            # pure-Python methods — confirm against the code generator.
            self._fast_decode = fastbinary.decode_binary
            self._fast_encode = fastbinary.encode_binary
285 | ||
286 | ||
class TBinaryProtocolAcceleratedFactory(TProtocolFactory):
    """Factory that builds :class:`TBinaryProtocolAccelerated` instances.

    The length limits and the ``fallback`` flag given here are forwarded to
    every protocol created. strictRead/strictWrite are not forwarded, so the
    produced protocols use TBinaryProtocol's defaults for them
    (strictRead=False, strictWrite=True).
    """

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None,
                 fallback=True):
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit
        # Private: only forwarded to the protocol's constructor.
        self._fallback = fallback

    def getProtocol(self, trans):
        """Return a fresh TBinaryProtocolAccelerated wrapping ``trans``."""
        options = {
            'string_length_limit': self.string_length_limit,
            'container_length_limit': self.container_length_limit,
            'fallback': self._fallback,
        }
        return TBinaryProtocolAccelerated(trans, **options)