]>
Commit | Line | Data |
---|---|---|
f67539c2 TL |
1 | # |
2 | # Licensed to the Apache Software Foundation (ASF) under one | |
3 | # or more contributor license agreements. See the NOTICE file | |
4 | # distributed with this work for additional information | |
5 | # regarding copyright ownership. The ASF licenses this file | |
6 | # to you under the Apache License, Version 2.0 (the | |
7 | # "License"); you may not use this file except in compliance | |
8 | # with the License. You may obtain a copy of the License at | |
9 | # | |
10 | # http://www.apache.org/licenses/LICENSE-2.0 | |
11 | # | |
12 | # Unless required by applicable law or agreed to in writing, | |
13 | # software distributed under the License is distributed on an | |
14 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
15 | # KIND, either express or implied. See the License for the | |
16 | # specific language governing permissions and limitations | |
17 | # under the License. | |
18 | # | |
19 | ||
20 | from thrift.Thrift import TProcessor, TMessageType | |
21 | from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol | |
22 | from thrift.protocol.TProtocol import TProtocolException | |
23 | ||
24 | ||
class TMultiplexedProcessor(TProcessor):
    """Dispatch incoming calls to one of several registered processors.

    Clients that use ``TMultiplexedProtocol`` prefix every message name with
    the service name followed by ``TMultiplexedProtocol.SEPARATOR``.  This
    processor reads the message header, strips that prefix, looks the service
    up among the processors added with :meth:`registerProcessor`, and hands
    the call to it.  Messages without a prefix are given to the optional
    default processor (see :meth:`registerDefault`).
    """

    # Callback installed via on_message_begin().  It is kept so that processors
    # registered *after* the callback was installed still receive it.  The
    # class-level default also covers subclasses that do not call __init__.
    _message_begin_callback = None

    def __init__(self):
        self.defaultProcessor = None
        self.services = {}

    def registerDefault(self, processor):
        """
        If a non-multiplexed processor connects to the server and wants to
        communicate, use the given processor to handle it. This mechanism
        allows servers to upgrade from non-multiplexed to multiplexed in a
        backwards-compatible way and still handle old clients.

        A callback previously installed with :meth:`on_message_begin` is
        forwarded to *processor* as well.
        """
        self.defaultProcessor = processor
        if processor is not None and self._message_begin_callback is not None:
            processor.on_message_begin(self._message_begin_callback)

    def registerProcessor(self, serviceName, processor):
        """Route calls prefixed with *serviceName* to *processor*.

        Registering the same *serviceName* again replaces the earlier
        processor.  A callback previously installed with
        :meth:`on_message_begin` is forwarded to *processor* as well.
        """
        self.services[serviceName] = processor
        if self._message_begin_callback is not None:
            processor.on_message_begin(self._message_begin_callback)

    def on_message_begin(self, func):
        """Install *func* on every processor this multiplexer dispatches to.

        This covers all currently registered service processors and the
        default processor (if any).  *func* is also remembered, so processors
        registered later receive it too.
        """
        self._message_begin_callback = func
        for processor in self.services.values():
            processor.on_message_begin(func)
        if self.defaultProcessor is not None:
            self.defaultProcessor.on_message_begin(func)

    def process(self, iprot, oprot):
        """Read one message header from *iprot* and dispatch the call.

        The wrapped processor receives a protocol that replays the already
        consumed header.  For multiplexed messages the service prefix is
        stripped from the name first.  Returns whatever the target
        processor's ``process`` returns.

        :raises TProtocolException: with ``NOT_IMPLEMENTED`` if the message
            type is neither CALL nor ONEWAY, if the name has no service
            prefix and no default processor is registered, or if the
            service prefix names an unregistered service.
        """
        (name, msg_type, seqid) = iprot.readMessageBegin()
        if msg_type not in (TMessageType.CALL, TMessageType.ONEWAY):
            raise TProtocolException(
                TProtocolException.NOT_IMPLEMENTED,
                "TMultiplexedProtocol only supports CALL & ONEWAY")

        # Split on the first separator only; the remainder is the method name.
        serviceName, separator, call = name.partition(
            TMultiplexedProtocol.SEPARATOR)
        if not separator:
            # No service prefix: the client is not using TMultiplexedProtocol.
            if self.defaultProcessor is not None:
                return self.defaultProcessor.process(
                    StoredMessageProtocol(iprot, (name, msg_type, seqid)),
                    oprot)
            raise TProtocolException(
                TProtocolException.NOT_IMPLEMENTED,
                "Service name not found in message name: " + name + ". " +
                "Did you forget to use TMultiplexedProtocol in your client?")

        if serviceName not in self.services:
            raise TProtocolException(
                TProtocolException.NOT_IMPLEMENTED,
                "Service name not found: " + serviceName + ". " +
                "Did you forget to call registerProcessor()?")

        # Replay the header with the prefix removed so the target processor
        # sees the plain method name it was generated for.
        standardMessage = (call, msg_type, seqid)
        return self.services[serviceName].process(
            StoredMessageProtocol(iprot, standardMessage), oprot)
75 | ||
76 | ||
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
    """Protocol decorator that replays a message header that was already read.

    ``TMultiplexedProcessor.process`` consumes the header itself to decide
    where to route the call.  This wrapper hands the saved
    ``(name, type, seqid)`` tuple back from :meth:`readMessageBegin`, so the
    processor receiving the call does not miss it.
    """

    def __init__(self, protocol, messageBegin):
        # *protocol* is not stored here.  NOTE(review): this presumably works
        # because the TProtocolDecorator base wires the wrapped protocol in
        # during construction; confirm against TProtocolDecorator.
        self.messageBegin = messageBegin

    def readMessageBegin(self):
        """Return the stored header tuple without touching the transport."""
        header = self.messageBegin
        return header