+++ /dev/null
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-import ssl
-
-from six.moves import BaseHTTPServer
-
-from thrift.Thrift import TMessageType
-from thrift.server import TServer
-from thrift.transport import TTransport
-
-
-class ResponseException(Exception):
- """Allows handlers to override the HTTP response
-
- Normally, THttpServer always sends a 200 response. If a handler wants
- to override this behavior (e.g., to simulate a misconfigured or
- overloaded web server during testing), it can raise a ResponseException.
- The function passed to the constructor will be called with the
- RequestHandler as its only argument. Note that this is irrelevant
- for ONEWAY requests, as the HTTP response must be sent before the
- RPC is processed.
- """
- def __init__(self, handler):
- self.handler = handler
-
-
-class THttpServer(TServer.TServer):
- """A simple HTTP-based Thrift server
-
- This class is not very performant, but it is useful (for example) for
- acting as a mock version of an Apache-based PHP Thrift endpoint.
- Also important to note the HTTP implementation pretty much violates the
- transport/protocol/processor/server layering, by performing the transport
- functions here. This means things like oneway handling are oddly exposed.
- """
- def __init__(self,
- processor,
- server_address,
- inputProtocolFactory,
- outputProtocolFactory=None,
- server_class=BaseHTTPServer.HTTPServer,
- **kwargs):
- """Set up protocol factories and HTTP (or HTTPS) server.
-
- See BaseHTTPServer for server_address.
- See TServer for protocol factories.
-
- To make a secure server, provide the named arguments:
- * cafile - to validate clients [optional]
- * cert_file - the server cert
- * key_file - the server's key
- """
- if outputProtocolFactory is None:
- outputProtocolFactory = inputProtocolFactory
-
- TServer.TServer.__init__(self, processor, None, None, None,
- inputProtocolFactory, outputProtocolFactory)
-
- thttpserver = self
- self._replied = None
-
- class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
- def do_POST(self):
- # Don't care about the request path.
- thttpserver._replied = False
- iftrans = TTransport.TFileObjectTransport(self.rfile)
- itrans = TTransport.TBufferedTransport(
- iftrans, int(self.headers['Content-Length']))
- otrans = TTransport.TMemoryBuffer()
- iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
- oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
- try:
- thttpserver.processor.on_message_begin(self.on_begin)
- thttpserver.processor.process(iprot, oprot)
- except ResponseException as exn:
- exn.handler(self)
- else:
- if not thttpserver._replied:
- # If the request was ONEWAY we would have replied already
- data = otrans.getvalue()
- self.send_response(200)
- self.send_header("Content-Length", len(data))
- self.send_header("Content-Type", "application/x-thrift")
- self.end_headers()
- self.wfile.write(data)
-
- def on_begin(self, name, type, seqid):
- """
- Inspect the message header.
-
- This allows us to post an immediate transport response
- if the request is a ONEWAY message type.
- """
- if type == TMessageType.ONEWAY:
- self.send_response(200)
- self.send_header("Content-Type", "application/x-thrift")
- self.end_headers()
- thttpserver._replied = True
-
- self.httpd = server_class(server_address, RequestHander)
-
- if (kwargs.get('cafile') or kwargs.get('cert_file') or kwargs.get('key_file')):
- context = ssl.create_default_context(cafile=kwargs.get('cafile'))
- context.check_hostname = False
- context.load_cert_chain(kwargs.get('cert_file'), kwargs.get('key_file'))
- context.verify_mode = ssl.CERT_REQUIRED if kwargs.get('cafile') else ssl.CERT_NONE
- self.httpd.socket = context.wrap_socket(self.httpd.socket, server_side=True)
-
- def serve(self):
- self.httpd.serve_forever()
-
- def shutdown(self):
- self.httpd.socket.close()
- # self.httpd.shutdown() # hangs forever, python doesn't handle POLLNVAL properly!