--- /dev/null
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from six.moves import queue
+import logging
+import os
+import threading
+
+from thrift.protocol import TBinaryProtocol
+from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
+from thrift.transport import TTransport
+
+logger = logging.getLogger(__name__)
+
+
+class TServer(object):
+ """Base interface for a server, which must have a serve() method.
+
+ Three constructors for all servers:
+ 1) (processor, serverTransport)
+ 2) (processor, serverTransport, transportFactory, protocolFactory)
+ 3) (processor, serverTransport,
+ inputTransportFactory, outputTransportFactory,
+ inputProtocolFactory, outputProtocolFactory)
+ """
+ def __init__(self, *args):
+ if (len(args) == 2):
+ self.__initArgs__(args[0], args[1],
+ TTransport.TTransportFactoryBase(),
+ TTransport.TTransportFactoryBase(),
+ TBinaryProtocol.TBinaryProtocolFactory(),
+ TBinaryProtocol.TBinaryProtocolFactory())
+ elif (len(args) == 4):
+ self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
+ elif (len(args) == 6):
+ self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
+
+ def __initArgs__(self, processor, serverTransport,
+ inputTransportFactory, outputTransportFactory,
+ inputProtocolFactory, outputProtocolFactory):
+ self.processor = processor
+ self.serverTransport = serverTransport
+ self.inputTransportFactory = inputTransportFactory
+ self.outputTransportFactory = outputTransportFactory
+ self.inputProtocolFactory = inputProtocolFactory
+ self.outputProtocolFactory = outputProtocolFactory
+
+ input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
+ output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory)
+ if any((input_is_header, output_is_header)) and input_is_header != output_is_header:
+ raise ValueError("THeaderProtocol servers require that both the input and "
+ "output protocols are THeaderProtocol.")
+
+ def serve(self):
+ pass
+
+
+class TSimpleServer(TServer):
+ """Simple single-threaded server that just pumps around one transport."""
+
+ def __init__(self, *args):
+ TServer.__init__(self, *args)
+
+ def serve(self):
+ self.serverTransport.listen()
+ while True:
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+
+ itrans = self.inputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+
+ # for THeaderProtocol, we must use the same protocol instance for
+ # input and output so that the response is in the same dialect that
+ # the server detected the request was in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ try:
+ while True:
+ self.processor.process(iprot, oprot)
+ except TTransport.TTransportException:
+ pass
+ except Exception as x:
+ logger.exception(x)
+
+ itrans.close()
+ if otrans:
+ otrans.close()
+
+
+class TThreadedServer(TServer):
+ """Threaded server that spawns a new thread per each connection."""
+
+ def __init__(self, *args, **kwargs):
+ TServer.__init__(self, *args)
+ self.daemon = kwargs.get("daemon", False)
+
+ def serve(self):
+ self.serverTransport.listen()
+ while True:
+ try:
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+ t = threading.Thread(target=self.handle, args=(client,))
+ t.setDaemon(self.daemon)
+ t.start()
+ except KeyboardInterrupt:
+ raise
+ except Exception as x:
+ logger.exception(x)
+
+ def handle(self, client):
+ itrans = self.inputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+
+ # for THeaderProtocol, we must use the same protocol instance for input
+ # and output so that the response is in the same dialect that the
+ # server detected the request was in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ try:
+ while True:
+ self.processor.process(iprot, oprot)
+ except TTransport.TTransportException:
+ pass
+ except Exception as x:
+ logger.exception(x)
+
+ itrans.close()
+ if otrans:
+ otrans.close()
+
+
+class TThreadPoolServer(TServer):
+ """Server with a fixed size pool of threads which service requests."""
+
+ def __init__(self, *args, **kwargs):
+ TServer.__init__(self, *args)
+ self.clients = queue.Queue()
+ self.threads = 10
+ self.daemon = kwargs.get("daemon", False)
+
+ def setNumThreads(self, num):
+ """Set the number of worker threads that should be created"""
+ self.threads = num
+
+ def serveThread(self):
+ """Loop around getting clients from the shared queue and process them."""
+ while True:
+ try:
+ client = self.clients.get()
+ self.serveClient(client)
+ except Exception as x:
+ logger.exception(x)
+
+ def serveClient(self, client):
+ """Process input/output from a client for as long as possible"""
+ itrans = self.inputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+
+ # for THeaderProtocol, we must use the same protocol instance for input
+ # and output so that the response is in the same dialect that the
+ # server detected the request was in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ try:
+ while True:
+ self.processor.process(iprot, oprot)
+ except TTransport.TTransportException:
+ pass
+ except Exception as x:
+ logger.exception(x)
+
+ itrans.close()
+ if otrans:
+ otrans.close()
+
+ def serve(self):
+ """Start a fixed number of worker threads and put client into a queue"""
+ for i in range(self.threads):
+ try:
+ t = threading.Thread(target=self.serveThread)
+ t.setDaemon(self.daemon)
+ t.start()
+ except Exception as x:
+ logger.exception(x)
+
+ # Pump the socket for clients
+ self.serverTransport.listen()
+ while True:
+ try:
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+ self.clients.put(client)
+ except Exception as x:
+ logger.exception(x)
+
+
+class TForkingServer(TServer):
+ """A Thrift server that forks a new process for each request
+
+ This is more scalable than the threaded server as it does not cause
+ GIL contention.
+
+ Note that this has different semantics from the threading server.
+ Specifically, updates to shared variables will no longer be shared.
+ It will also not work on windows.
+
+ This code is heavily inspired by SocketServer.ForkingMixIn in the
+ Python stdlib.
+ """
+ def __init__(self, *args):
+ TServer.__init__(self, *args)
+ self.children = []
+
+ def serve(self):
+ def try_close(file):
+ try:
+ file.close()
+ except IOError as e:
+ logger.warning(e, exc_info=True)
+
+ self.serverTransport.listen()
+ while True:
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+ try:
+ pid = os.fork()
+
+ if pid: # parent
+ # add before collect, otherwise you race w/ waitpid
+ self.children.append(pid)
+ self.collect_children()
+
+ # Parent must close socket or the connection may not get
+ # closed promptly
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ try_close(itrans)
+ try_close(otrans)
+ else:
+ itrans = self.inputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+
+ # for THeaderProtocol, we must use the same protocol
+ # instance for input and output so that the response is in
+ # the same dialect that the server detected the request was
+ # in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ ecode = 0
+ try:
+ try:
+ while True:
+ self.processor.process(iprot, oprot)
+ except TTransport.TTransportException:
+ pass
+ except Exception as e:
+ logger.exception(e)
+ ecode = 1
+ finally:
+ try_close(itrans)
+ if otrans:
+ try_close(otrans)
+
+ os._exit(ecode)
+
+ except TTransport.TTransportException:
+ pass
+ except Exception as x:
+ logger.exception(x)
+
+ def collect_children(self):
+ while self.children:
+ try:
+ pid, status = os.waitpid(0, os.WNOHANG)
+ except os.error:
+ pid = None
+
+ if pid:
+ self.children.remove(pid)
+ else:
+ break