#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from functools import wraps
from struct import pack, unpack

from ..compat import binary_to_str, str_to_binary
from .TProtocol import (TType, TProtocolBase, TProtocolException,
                        TProtocolFactory, checkIntegerLimits)
25 __all__
= ['TCompactProtocol', 'TCompactProtocolFactory']
38 def make_helper(v_from
, container
):
40 def nested(self
, *args
, **kwargs
):
41 assert self
.state
in (v_from
, container
), (self
.state
, v_from
, container
)
42 return func(self
, *args
, **kwargs
)
# Decorators for the primitive write*/read* methods: a primitive may be
# written/read either as a struct field's value or as a container element.
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
reader = make_helper(VALUE_READ, CONTAINER_READ)
def makeZigZag(n, bits):
    """Zigzag-encode the signed ``bits``-wide integer ``n``.

    Maps 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ... so values of small magnitude,
    of either sign, become small non-negative integers suitable for a varint.
    ``checkIntegerLimits`` (imported from .TProtocol) is called first;
    presumably it rejects ``n`` outside the signed ``bits`` range -- confirm
    in TProtocol.
    """
    checkIntegerLimits(n, bits)
    # The arithmetic right shift smears the sign bit across all bits, so the
    # XOR flips every bit of (n << 1) exactly when n is negative.
    return (n << 1) ^ (n >> (bits - 1))
def fromZigZag(n):
    """Decode a zigzag-encoded non-negative integer back to a signed one.

    Inverse of ``makeZigZag``: 0 -> 0, 1 -> -1, 2 -> 1, 3 -> -2, ...
    """
    # -(n & 1) is 0 for even n and -1 (all bits set) for odd n.
    return (n >> 1) ^ -(n & 1)
def writeVarint(trans, n):
    """Write non-negative integer ``n`` to ``trans`` as an unsigned varint.

    Seven bits are emitted per byte, least-significant group first; every
    byte except the last has its high bit (0x80) set as a continuation flag.
    The encoded bytes are passed to ``trans.write`` in a single call.

    :raises AssertionError: if ``n`` is negative.
    """
    assert n >= 0, "Input to TCompactProtocol writeVarint cannot be negative!"
    out = bytearray()
    # Loop only while more than 7 bits remain.  Unlike a `while True` loop,
    # this terminates even for negative n when asserts are disabled (-O):
    # bytearray.append then raises ValueError instead of looping forever.
    while n > 0x7f:
        out.append((n & 0x7f) | 0x80)
        n >>= 7
    out.append(n)
    trans.write(bytes(out))
def readVarint(trans):
    """Read one unsigned varint from ``trans`` and return it as an int.

    Bytes are pulled one at a time with ``trans.readAll(1)``; each
    contributes its low 7 bits, least-significant group first, until a byte
    without the 0x80 continuation bit is seen.
    """
    result = 0
    shift = 0
    while True:
        # ord() accepts a length-1 str (Python 2) or bytes (Python 3).
        byte = ord(trans.readAll(1))
        result |= (byte & 0x7f) << shift
        if not byte & 0x80:
            return result
        shift += 7
class CompactType(object):
    """Type ids used on the wire by the compact protocol.

    These occupy the low nibble of field headers and collection headers.
    Booleans have two ids so a bool field's value can ride in its header.
    """
    STOP = 0x00
    TRUE = 0x01
    FALSE = 0x02
    BYTE = 0x03
    I16 = 0x04
    I32 = 0x05
    I64 = 0x06
    DOUBLE = 0x07
    BINARY = 0x08
    LIST = 0x09
    SET = 0x0A
    MAP = 0x0B
    STRUCT = 0x0C
# Thrift TType -> compact-protocol type id.
CTYPES = {
    TType.STOP: CompactType.STOP,
    # Bool elements inside collections are tagged with the TRUE id; bool
    # *fields* pick TRUE/FALSE per value in writeBool instead.
    TType.BOOL: CompactType.TRUE,
    TType.BYTE: CompactType.BYTE,
    TType.I16: CompactType.I16,
    TType.I32: CompactType.I32,
    TType.I64: CompactType.I64,
    TType.DOUBLE: CompactType.DOUBLE,
    TType.STRING: CompactType.BINARY,
    TType.STRUCT: CompactType.STRUCT,
    TType.LIST: CompactType.LIST,
    TType.SET: CompactType.SET,
    TType.MAP: CompactType.MAP,
}

# Reverse lookup, compact type id -> TType.  A comprehension is used so no
# loop variables leak into the module namespace.  Both boolean ids decode
# as TType.BOOL.
TTYPES = {ctype: ttype for ttype, ctype in CTYPES.items()}
TTYPES[CompactType.FALSE] = TType.BOOL
class TCompactProtocol(TProtocolBase):
    """Compact implementation of the Thrift protocol driver.

    Integers are zigzag-encoded and written as varints.  A field id is
    written as a 4-bit delta from the previous field id whenever that delta
    is 1..15.  Boolean struct fields carry their value in the field header's
    type nibble instead of a separate byte.  ``self.state`` records where in
    a message the caller is, and assertions reject out-of-order calls.
    """

    # First byte of every message.
    PROTOCOL_ID = 0x82
    # Second byte of a message: version in the low 5 bits, message type in
    # the high 3 bits.
    VERSION = 1
    VERSION_MASK = 0x1f
    TYPE_MASK = 0xe0
    TYPE_BITS = 0x07
    TYPE_SHIFT_AMOUNT = 5

    def __init__(self, trans,
                 string_length_limit=None,
                 container_length_limit=None):
        """Create a compact protocol driver over ``trans``.

        :param trans: transport handed to ``TProtocolBase.__init__``; the
            methods below use it as ``self.trans`` (``write``/``readAll``).
        :param string_length_limit: optional cap on decoded string/binary
            sizes, enforced through ``TProtocolBase._check_length``.
        :param container_length_limit: optional cap on decoded container
            sizes, enforced the same way.
        """
        TProtocolBase.__init__(self, trans)
        self.state = CLEAR
        # Id of the previous field in the current struct (delta base).
        self.__last_fid = 0
        # A bool field's header is deferred until writeBool knows the value,
        # so its id is parked here by writeFieldBegin.
        self.__bool_fid = None
        # Bool value decoded from a field header, returned by readBool.
        self.__bool_value = None
        # (state, last_fid) pairs saved for each enclosing struct.
        self.__structs = []
        # States saved for each enclosing container.
        self.__containers = []
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit

    def _check_string_length(self, length):
        """Validate a decoded string length against string_length_limit."""
        self._check_length(self.string_length_limit, length)

    def _check_container_length(self, length):
        """Validate a decoded container size against container_length_limit."""
        self._check_length(self.container_length_limit, length)

    def __writeVarint(self, n):
        writeVarint(self.trans, n)

    def writeMessageBegin(self, name, type, seqid):
        """Write the message header: protocol id, version/type, seqid, name."""
        assert self.state == CLEAR
        self.__writeUByte(self.PROTOCOL_ID)
        self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
        # The sequence id is a signed 32-bit integer, but it is written as an
        # unsigned varint (writeVarint rejects negatives), so a negative id is
        # mapped to its 32-bit two's-complement value first.
        if seqid < 0:
            seqid += 1 << 32
        self.__writeVarint(seqid)
        self.__writeBinary(str_to_binary(name))
        self.state = VALUE_WRITE

    def writeMessageEnd(self):
        assert self.state == VALUE_WRITE
        self.state = CLEAR

    def writeStructBegin(self, name):
        """Enter a struct: save the outer state and reset the field-id base."""
        assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_WRITE
        self.__last_fid = 0

    def writeStructEnd(self):
        """Leave a struct, restoring the enclosing state and field-id base."""
        assert self.state == FIELD_WRITE
        self.state, self.__last_fid = self.__structs.pop()

    def writeFieldStop(self):
        self.__writeByte(CompactType.STOP)

    def __writeFieldHeader(self, ctype, fid):
        """Write the header for a field of compact type ``ctype``.

        If ``fid`` is 1..15 above the previous field's id, the delta and the
        type share one byte; otherwise the type byte is followed by the full
        id as a zigzag varint.
        """
        delta = fid - self.__last_fid
        if 0 < delta <= 15:
            self.__writeUByte(delta << 4 | ctype)
        else:
            self.__writeByte(ctype)
            self.__writeI16(fid)
        self.__last_fid = fid

    def writeFieldBegin(self, name, type, fid):
        assert self.state == FIELD_WRITE, self.state
        if type == TType.BOOL:
            # The header depends on the value, so writeBool emits it.
            self.state = BOOL_WRITE
            self.__bool_fid = fid
        else:
            self.state = VALUE_WRITE
            self.__writeFieldHeader(CTYPES[type], fid)

    def writeFieldEnd(self):
        assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
        self.state = FIELD_WRITE

    def __writeUByte(self, byte):
        self.trans.write(pack('!B', byte))

    def __writeByte(self, byte):
        self.trans.write(pack('!b', byte))

    def __writeI16(self, i16):
        self.__writeVarint(makeZigZag(i16, 16))

    def __writeSize(self, i32):
        self.__writeVarint(i32)

    def writeCollectionBegin(self, etype, size):
        """Write a list/set header; sizes up to 14 share the type byte."""
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size <= 14:
            self.__writeUByte(size << 4 | CTYPES[etype])
        else:
            # A size nibble of 0xf means "size follows as a varint".
            self.__writeUByte(0xf0 | CTYPES[etype])
            self.__writeSize(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE
    writeSetBegin = writeCollectionBegin
    writeListBegin = writeCollectionBegin

    def writeMapBegin(self, ktype, vtype, size):
        """Write a map header; an empty map is a single zero byte."""
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size == 0:
            self.__writeByte(0)
        else:
            self.__writeSize(size)
            self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE

    def writeCollectionEnd(self):
        assert self.state == CONTAINER_WRITE, self.state
        self.state = self.__containers.pop()
    writeMapEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd
    writeListEnd = writeCollectionEnd

    def writeBool(self, bool):
        """Write a bool as a field (folded into its header) or an element."""
        ctype = CompactType.TRUE if bool else CompactType.FALSE
        if self.state == BOOL_WRITE:
            self.__writeFieldHeader(ctype, self.__bool_fid)
        elif self.state == CONTAINER_WRITE:
            self.__writeByte(ctype)
        else:
            raise AssertionError("Invalid state in compact protocol")

    writeByte = writer(__writeByte)
    writeI16 = writer(__writeI16)

    @writer
    def writeI32(self, i32):
        self.__writeVarint(makeZigZag(i32, 32))

    @writer
    def writeI64(self, i64):
        self.__writeVarint(makeZigZag(i64, 64))

    @writer
    def writeDouble(self, dub):
        # Doubles are written little-endian.
        self.trans.write(pack('<d', dub))

    def __writeBinary(self, s):
        """Write a varint length prefix followed by the raw bytes ``s``."""
        self.__writeSize(len(s))
        self.trans.write(s)
    writeBinary = writer(__writeBinary)

    def readFieldBegin(self):
        """Read a field header; returns (None, ttype, fid)."""
        assert self.state == FIELD_READ, self.state
        header = self.__readUByte()
        if header & 0x0f == TType.STOP:
            return (None, 0, 0)
        delta = header >> 4
        if delta == 0:
            # No delta: the full field id follows as a zigzag varint.
            fid = self.__readI16()
        else:
            fid = self.__last_fid + delta
        self.__last_fid = fid
        ctype = header & 0x0f
        if ctype == CompactType.TRUE:
            self.state = BOOL_READ
            self.__bool_value = True
        elif ctype == CompactType.FALSE:
            self.state = BOOL_READ
            self.__bool_value = False
        else:
            self.state = VALUE_READ
        return (None, self.__getTType(ctype), fid)

    def readFieldEnd(self):
        assert self.state in (VALUE_READ, BOOL_READ), self.state
        self.state = FIELD_READ

    def __readUByte(self):
        result, = unpack('!B', self.trans.readAll(1))
        return result

    def __readByte(self):
        result, = unpack('!b', self.trans.readAll(1))
        return result

    def __readVarint(self):
        return readVarint(self.trans)

    def __readZigZag(self):
        return fromZigZag(self.__readVarint())

    def __readSize(self):
        result = self.__readVarint()
        # Defensive: readVarint only builds non-negative values.  The type
        # code is passed explicitly; the message alone would otherwise land
        # in the exception's `type` slot.
        if result < 0:
            raise TProtocolException(TProtocolException.NEGATIVE_SIZE,
                                     "Length < 0")
        return result

    def readMessageBegin(self):
        """Read a message header; returns (name, message_type, seqid)."""
        assert self.state == CLEAR
        proto_id = self.__readUByte()
        if proto_id != self.PROTOCOL_ID:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad protocol id in the message: %d' % proto_id)
        ver_type = self.__readUByte()
        msg_type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
        version = ver_type & self.VERSION_MASK
        if version != self.VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad version: %d (expect %d)' % (version, self.VERSION))
        seqid = self.__readVarint()
        # The sequence id travels as an unsigned varint but is really a
        # signed 32-bit value, so undo the two's-complement mapping.
        if seqid > 2147483647:
            seqid -= 1 << 32
        name = binary_to_str(self.__readBinary())
        return (name, msg_type, seqid)

    def readMessageEnd(self):
        assert self.state == CLEAR
        assert len(self.__structs) == 0

    def readStructBegin(self):
        """Enter a struct: save the outer state and reset the field-id base."""
        assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_READ
        self.__last_fid = 0

    def readStructEnd(self):
        assert self.state == FIELD_READ
        self.state, self.__last_fid = self.__structs.pop()

    def readCollectionBegin(self):
        """Read a list/set header; returns (element_ttype, size)."""
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size_type = self.__readUByte()
        size = size_type >> 4
        etype = self.__getTType(size_type)
        if size == 15:
            # Size nibble 0xf: the real size follows as a varint.
            size = self.__readSize()
        self._check_container_length(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return etype, size
    readSetBegin = readCollectionBegin
    readListBegin = readCollectionBegin

    def readMapBegin(self):
        """Read a map header; returns (key_ttype, value_ttype, size)."""
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size = self.__readSize()
        self._check_container_length(size)
        # An empty map has no key/value type byte on the wire.
        types = 0
        if size > 0:
            types = self.__readUByte()
        vtype = self.__getTType(types)
        ktype = self.__getTType(types >> 4)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return (ktype, vtype, size)

    def readCollectionEnd(self):
        assert self.state == CONTAINER_READ, self.state
        self.state = self.__containers.pop()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd
    readMapEnd = readCollectionEnd

    def readBool(self):
        """Return a bool from the pending field header or a container byte."""
        if self.state == BOOL_READ:
            # readFieldBegin already stored a Python bool; return it as-is
            # rather than comparing it against the integer CompactType.TRUE.
            return self.__bool_value
        elif self.state == CONTAINER_READ:
            return self.__readByte() == CompactType.TRUE
        else:
            raise AssertionError("Invalid state in compact protocol: %d" %
                                 self.state)

    readByte = reader(__readByte)
    __readI16 = __readZigZag
    readI16 = reader(__readZigZag)
    readI32 = reader(__readZigZag)
    readI64 = reader(__readZigZag)

    @reader
    def readDouble(self):
        # Doubles are read little-endian, matching writeDouble.
        buff = self.trans.readAll(8)
        val, = unpack('<d', buff)
        return val

    def __readBinary(self):
        """Read a varint length prefix, check the limit, return the bytes."""
        size = self.__readSize()
        self._check_string_length(size)
        return self.trans.readAll(size)
    readBinary = reader(__readBinary)

    def __getTType(self, byte):
        # NOTE(review): an unknown type nibble raises KeyError here rather
        # than a TProtocolException.
        return TTYPES[byte & 0x0f]
class TCompactProtocolFactory(TProtocolFactory):
    """Factory that builds TCompactProtocol drivers with shared size limits."""

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None):
        """Remember the limits passed to every protocol this factory builds.

        :param string_length_limit: optional cap on decoded string sizes.
        :param container_length_limit: optional cap on decoded container sizes.
        """
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit

    def getProtocol(self, trans):
        """Return a new TCompactProtocol over ``trans`` with this factory's limits."""
        return TCompactProtocol(trans,
                                self.string_length_limit,
                                self.container_length_limit)
class TCompactProtocolAccelerated(TCompactProtocol):
    """C-Accelerated version of TCompactProtocol.

    This class does not override any of TCompactProtocol's methods,
    but the generated code recognizes it directly and will call into
    our C module to do the encoding, bypassing this object entirely.
    We inherit from TCompactProtocol so that the normal TCompactProtocol
    encoding can happen if the fastbinary module doesn't work for some
    reason.
    To disable this behavior, pass fallback=False constructor argument.

    In order to take advantage of the C module, just use
    TCompactProtocolAccelerated instead of TCompactProtocol.
    """

    def __init__(self, *args, **kwargs):
        """Accept TCompactProtocol's arguments plus keyword ``fallback``.

        :param fallback: when True (default), a missing fastbinary extension
            is tolerated and the pure-Python methods are used; when False the
            ImportError propagates.
        """
        fallback = kwargs.pop('fallback', True)
        super(TCompactProtocolAccelerated, self).__init__(*args, **kwargs)
        try:
            from thrift.protocol import fastbinary
        except ImportError:
            if not fallback:
                raise
        else:
            self._fast_decode = fastbinary.decode_compact
            self._fast_encode = fastbinary.encode_compact
class TCompactProtocolAcceleratedFactory(TProtocolFactory):
    """Factory that builds TCompactProtocolAccelerated drivers."""

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None,
                 fallback=True):
        """Remember the options forwarded to every protocol built.

        :param string_length_limit: optional cap on decoded string sizes.
        :param container_length_limit: optional cap on decoded container sizes.
        :param fallback: forwarded to TCompactProtocolAccelerated; controls
            whether a missing C extension is tolerated.
        """
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit
        self._fallback = fallback

    def getProtocol(self, trans):
        """Return a new TCompactProtocolAccelerated over ``trans``."""
        return TCompactProtocolAccelerated(
            trans,
            string_length_limit=self.string_length_limit,
            container_length_limit=self.container_length_limit,
            fallback=self._fallback)