]>
Commit | Line | Data |
---|---|---|
f67539c2 TL |
1 | -- |
2 | -- Licensed to the Apache Software Foundation (ASF) under one | |
3 | -- or more contributor license agreements. See the NOTICE file | |
4 | -- distributed with this work for additional information | |
5 | -- regarding copyright ownership. The ASF licenses this file | |
6 | -- to you under the Apache License, Version 2.0 (the | |
7 | -- "License"); you may not use this file except in compliance | |
8 | -- with the License. You may obtain a copy of the License at | |
9 | -- | |
10 | -- http://www.apache.org/licenses/LICENSE-2.0 | |
11 | -- | |
12 | -- Unless required by applicable law or agreed to in writing, | |
13 | -- software distributed under the License is distributed on an | |
14 | -- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
15 | -- KIND, either express or implied. See the License for the | |
16 | -- specific language governing permissions and limitations | |
17 | -- under the License. | |
18 | -- | |
19 | ||
20 | module Thrift.Transport.Header | |
21 | ( module Thrift.Transport | |
22 | , HeaderTransport(..) | |
23 | , openHeaderTransport | |
24 | , ProtocolType(..) | |
25 | , TransformType(..) | |
26 | , ClientType(..) | |
27 | , tResetProtocol | |
28 | , tSetProtocol | |
29 | ) where | |
30 | ||
import Control.Applicative
import Control.Exception ( evaluate, throw, throwIO )
import Control.Monad
import Data.Bits
import Data.IORef
import Data.Int
import Data.Monoid
import Data.Word

import qualified Data.Attoparsec.ByteString as P
import qualified Data.Binary as Binary
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString.Lazy.Builder as B
import qualified Data.Map as Map

import Thrift.Protocol.Compact
import Thrift.Transport
49 | ||
-- | Serialization protocol of a message, either detected on read or chosen
-- with 'tSetProtocol' for outgoing Header frames.
data ProtocolType = TBinary | TCompact | TJSON deriving (Enum, Eq)

-- | How the peer frames its messages: the Header protocol, a 4-byte
-- length-prefixed frame, or no framing at all.
data ClientType = HeaderClient | Framed | Unframed deriving (Enum, Eq)

-- | Presumably the info-section ID for key/value headers.
--
-- NOTE(review): not referenced by the header parsing/building code in this
-- module.  The signature is polymorphic so that any existing use site keeps
-- type-checking at whatever numeric type it needs.
infoIdKeyValue :: Num a => a
infoIdKeyValue = 1

-- | Key/value headers attached to outgoing Header frames.
type Headers = Map.Map String String

-- | Payload transforms that can be announced in a Header frame.
data TransformType = ZlibTransform deriving (Enum, Eq)
58 | ||
-- | Wire ID of a transform in a Header frame's transform list.
fromTransportType :: TransformType -> Int16
fromTransportType tt = case tt of
  ZlibTransform -> 1

-- | Inverse of 'fromTransportType'.  Any ID other than 1 yields a value that
-- throws 'TransportExn' when forced.
toTransportType :: Int16 -> TransformType
toTransportType tid
  | tid == 1 = ZlibTransform
  | otherwise = throw $ TransportExn "HeaderTransport: Unknown transform ID" TE_UNKNOWN
65 | ||
-- | A transport that speaks the Header protocol on the wire and, on read,
-- also recognises framed and unframed Binary/Compact peers.
--
-- Frames are read from @i@ and written to @o@.
data HeaderTransport i o = (Transport i, Transport o) => HeaderTransport
  { readBuffer :: IORef LBS.ByteString
    -- ^ Bytes of the current incoming message not yet consumed.
  , writeBuffer :: IORef B.Builder
    -- ^ Outgoing bytes held until 'tFlush' while the client type is framed.
  , inTrans :: i
    -- ^ Underlying transport that messages are read from.
  , outTrans :: o
    -- ^ Underlying transport that messages are written to.
  , clientType :: IORef ClientType
    -- ^ Framing detected on read; 'tFlush' uses it to pick the write framing.
  , protocolType :: IORef ProtocolType
    -- ^ Protocol detected on read or set with 'tSetProtocol'.
  , headers :: IORef [(String, String)]
    -- ^ Key/value headers from the most recently received Header frame.
  , writeHeaders :: Headers
    -- ^ Key/value headers encoded into each outgoing Header frame.
  , transforms :: IORef [TransformType]
    -- ^ Transforms announced by the most recently received Header frame.
  , writeTransforms :: [TransformType]
    -- ^ Transforms announced on, and applied to, outgoing Header frames.
  }
78 | ||
-- | Wrap an input and an output transport in a 'HeaderTransport'.
--
-- The result starts as a 'HeaderClient' using 'TCompact', with empty
-- read/write buffers, no received headers or transforms, and no outgoing
-- headers or transforms.
openHeaderTransport :: (Transport i, Transport o) => i -> o -> IO (HeaderTransport i o)
openHeaderTransport i o = do
  readRef <- newIORef LBS.empty
  writeRef <- newIORef mempty
  clientRef <- newIORef HeaderClient
  protoRef <- newIORef TCompact
  recvHeadersRef <- newIORef []
  recvTransRef <- newIORef []
  return HeaderTransport
    { readBuffer = readRef
    , writeBuffer = writeRef
    , inTrans = i
    , outTrans = o
    , clientType = clientRef
    , protocolType = protoRef
    , headers = recvHeadersRef
    , writeHeaders = Map.empty
    , transforms = recvTransRef
    , writeTransforms = []
    }
99 | ||
-- | Whether I/O goes through the frame buffers, i.e. the current client type
-- is anything other than 'Unframed'.
isFramed :: HeaderTransport i o -> IO Bool
isFramed t = (/= Unframed) <$> readIORef (clientType t)
101 | ||
-- | Load the next incoming message into 'readBuffer', detecting how the peer
-- frames it.
--
-- The first (up to) 4 bytes from the input transport decide:
--
-- * no bytes: returns 'False' and leaves the state unchanged;
-- * Binary or Compact magic: the peer is 'Unframed'; the bytes are kept in
--   'readBuffer' so the protocol reads them as part of the message;
-- * otherwise they are an 'Int32' frame length and the whole frame is read.
--   A frame starting with Binary or Compact magic marks a 'Framed' peer.
--   A frame starting with 0x0f 0xff is a Header frame: its header section
--   updates 'protocolType', 'transforms' and 'headers', and its untransformed
--   payload becomes 'readBuffer'.
--
-- Returns 'True' once a message is buffered.  Throws 'TransportExn' for a
-- negative frame length or a frame that matches no known client type.
readFrame :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool
readFrame t = do
  lsz <- tRead (inTrans t) 4
  let sz = LBS.toStrict lsz
  if BS.null sz
    then return False
    else case magicProtocol sz of
      -- Unframed: these bytes are already the start of the message.
      Just pType -> accept Unframed pType lsz
      Nothing -> readFramed (Binary.decode lsz :: Int32)
  where
    -- Record the detected client/protocol type and buffer the message bytes.
    accept cType pType buf = do
      writeIORef (readBuffer t) buf
      writeIORef (clientType t) cType
      writeIORef (protocolType t) pType
      return True

    -- Binary or Compact magic at the start of a buffer, if present.
    magicProtocol buf =
      case (parseBinaryMagic buf, parseCompactMagic buf) of
        (Right _, _) -> Just TBinary
        (_, Right _) -> Just TCompact
        _ -> Nothing

    -- Read a length-prefixed frame and dispatch on its leading magic.
    readFramed len
      | len < 0 =
          throwIO $ TransportExn "HeaderTransport: negative frame size" TE_UNKNOWN
      | otherwise = do
          lbuf <- tReadAll (inTrans t) (fromIntegral len)
          let buf = LBS.toStrict lbuf
          case magicProtocol buf of
            Just pType -> accept Framed pType lbuf
            Nothing ->
              case parseHeaderMagic buf of
                Right _ -> do
                  let (_, _, header, body) = extractHeader buf
                  writeIORef (clientType t) HeaderClient
                  handleHeader t header
                  payload <- untransform t body
                  writeIORef (readBuffer t) (LBS.fromStrict payload)
                  return True
                Left _ ->
                  throwIO $ TransportExn "HeaderTransport: unknown client type" TE_UNKNOWN
153 | ||
-- | Match the Binary protocol magic bytes 0x80 0x01 0x00 at the start of the
-- input and return the byte that follows them.
parseBinaryMagic :: BS.ByteString -> Either String Word8
parseBinaryMagic = P.parseOnly $ P.word8 0x80 *> P.word8 0x01 *> P.word8 0x00 *> P.anyWord8

-- | Match the Compact protocol byte 0x82 followed by a byte whose low five
-- bits equal 1, returning that second byte.
parseCompactMagic :: BS.ByteString -> Either String Word8
parseCompactMagic = P.parseOnly $ P.word8 0x82 *> P.satisfy (\b -> b .&. 0x1f == 0x01)

-- | Match the Header protocol magic 0x0f 0xff and return the two bytes that
-- follow it.
parseHeaderMagic :: BS.ByteString -> Either String [Word8]
parseHeaderMagic = P.parseOnly $ P.word8 0x0f *> P.word8 0xff *> P.count 2 P.anyWord8
157 | ||
-- | Signed 32-bit integer decoded with "Data.Binary" from the next 4 bytes.
parseI32 :: P.Parser Int32
parseI32 = fmap (Binary.decode . LBS.fromStrict) (P.take 4)

-- | Signed 16-bit integer decoded with "Data.Binary" from the next 2 bytes.
parseI16 :: P.Parser Int16
parseI16 = fmap (Binary.decode . LBS.fromStrict) (P.take 2)
162 | ||
-- | Split a Header frame into flags, sequence number, header section and the
-- remaining payload.
--
-- Layout: 0x0f 0xff magic, 16-bit flags, 32-bit sequence number, 16-bit
-- header length counted in 4-byte words, then the header section itself.
-- Throws 'TransportExn' when the frame does not match this layout or is
-- shorter than the declared header.
extractHeader :: BS.ByteString -> (Int16, Int32, BS.ByteString, BS.ByteString)
extractHeader bs =
  case P.parse frameHeader bs of
    P.Done remain (flags, seqNum, header) -> (flags, seqNum, header, remain)
    _ -> throw $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN
  where
    frameHeader = do
      void $ P.word8 0x0f *> P.word8 0xff
      flags <- parseI16
      seqNum <- parseI32
      -- Read the word count as unsigned so counts above 0x7fff do not turn
      -- into a negative byte length.
      headerWords <- iw16 <$> parseI16
      header <- P.take (4 * fromIntegral headerWords)
      return (flags, seqNum, header)
176 | ||
-- | Apply a Header frame's header section to the read-side state: protocol
-- type, announced transforms and key/value headers.
--
-- Throws 'TransportExn' if the section cannot be parsed or names an unknown
-- protocol/transform ID.
handleHeader :: HeaderTransport i o -> BS.ByteString -> IO ()
handleHeader t header =
  case P.parseOnly parseHeader header of
    Right (pType, trans, info) -> do
      -- The ID decoders fail lazily on unknown IDs.  Force them here so the
      -- error surfaces while handling this frame, not whenever the stored
      -- value is next used.
      _ <- evaluate pType
      mapM_ evaluate trans
      writeIORef (protocolType t) pType
      writeIORef (transforms t) trans
      writeIORef (headers t) info
    Left _ -> throwIO $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN
184 | ||
185 | ||
-- Bit-preserving conversions between the signed types used in this module's
-- signatures and the unsigned types passed to 'buildVarint' / produced for
-- 'parseVarint'.  Out-of-range values wrap (two's complement).

-- | Signed 16-bit to unsigned 16-bit.
iw16 :: Int16 -> Word16
iw16 = fromIntegral

-- | Signed 32-bit to unsigned 32-bit.
iw32 :: Int32 -> Word32
iw32 = fromIntegral

-- | Unsigned 16-bit to signed 16-bit.
wi16 :: Word16 -> Int16
wi16 = fromIntegral

-- | Unsigned 32-bit to signed 32-bit.
wi32 :: Word32 -> Int32
wi32 = fromIntegral
194 | ||
-- | Parse a Header frame's header section.
--
-- The section is a varint protocol ID, then a varint transform count followed
-- by that many transform IDs, then the info section.
parseHeader :: P.Parser (ProtocolType, [TransformType], [(String, String)])
parseHeader = do
  pType <- toProtocolType <$> parseVarint wi16
  transCount <- parseVarint wi16
  ts <- replicateM (fromIntegral transCount) parseTransform
  (,,) pType ts <$> parseInfo
202 | ||
-- | Decode the protocol ID carried in a Header frame's header section.
--
-- IDs other than 0, 1 and 2 yield a value that throws 'TransportExn' when
-- forced, in the same way as 'toTransportType'.  Previously this case was an
-- incomplete-pattern crash.
toProtocolType :: Int16 -> ProtocolType
toProtocolType 0 = TBinary
toProtocolType 1 = TJSON
toProtocolType 2 = TCompact
toProtocolType _ = throw $ TransportExn "HeaderTransport: Unknown protocol ID" TE_UNKNOWN

-- | Inverse of 'toProtocolType': the ID written into outgoing headers.
fromProtocolType :: ProtocolType -> Int16
fromProtocolType TBinary = 0
fromProtocolType TJSON = 1
fromProtocolType TCompact = 2

-- | Parse one varint-encoded transform ID (see 'toTransportType').
parseTransform :: P.Parser TransformType
parseTransform = toTransportType <$> parseVarint wi16
215 | ||
-- | Parse the info section that follows the transform list.
--
-- If no input remains, there are no headers.  Otherwise a varint count is
-- read, followed by that many (key, value) pairs.  Each string is a varint
-- length plus that many bytes, converted to 'String' with 'C.unpack'.
--
-- NOTE(review): no info-type ID is read before the count, and
-- 'infoIdKeyValue' is unused.  Other THeader implementations appear to
-- prefix each info block with such an ID ('buildInfo' omits it too) --
-- confirm wire compatibility before exchanging headers with non-Haskell
-- peers.
parseInfo :: P.Parser [(String, String)]
parseInfo = P.eitherP P.endOfInput (parseVarint wi32) >>= either (const (return [])) readPairs
  where
    readPairs count = replicateM (fromIntegral count) entry
    entry = do
      key <- varBytes
      val <- varBytes
      return (C.unpack key, C.unpack val)
    varBytes = parseVarint wi16 >>= P.take . fromIntegral
228 | ||
-- | Parse a varint length (decoded via 'wi32') followed by that many raw
-- bytes.
parseString :: P.Parser BS.ByteString
parseString = do
  n <- parseVarint wi32
  P.take (fromIntegral n)
231 | ||
-- | Build the Header-protocol preamble for an outgoing frame.
--
-- Layout:
--
-- * magic 0x0fff;
-- * 16-bit flags (always 0);
-- * 32-bit sequence number (always 0);
-- * 16-bit header length in 4-byte words;
-- * the header section (protocol ID, transform list, info section),
--   zero-padded to a multiple of 4 bytes.
buildHeader :: HeaderTransport i o -> IO B.Builder
buildHeader t = do
  pType <- readIORef $ protocolType t
  let pId = buildVarint $ iw16 $ fromProtocolType pType
      headerContent = pId <> buildTransforms t <> buildInfo t
      len = LBS.length $ B.toLazyByteString headerContent
      -- TODO: length limit check
      -- Pad up to the next multiple of 4 so the bytes actually written match
      -- the word count announced in 'codedLen'.  For example, 5 content bytes
      -- get 3 padding bytes and a length of 2 words.
      padLen = (4 - len `mod` 4) `mod` 4
      padding = mconcat $ replicate (fromIntegral padLen) $ B.word8 0
      codedLen = B.int16BE $ fromIntegral $ (len + padLen) `quot` 4
      flags = 0
      seqNum = 0
  return $ B.int16BE 0x0fff <> B.int16BE flags <> B.int32BE seqNum <> codedLen <> headerContent <> padding
244 | ||
-- | Encode 'writeTransforms' as a varint count followed by one varint ID per
-- transform.
buildTransforms :: HeaderTransport i o -> B.Builder
buildTransforms t = countB <> mconcat (map idB ts)
  where
    -- TODO: check length limit
    ts = writeTransforms t
    countB = buildVarint . iw16 . fromIntegral $ length ts
    idB = buildVarint . iw16 . fromTransportType
251 | ||
-- | Encode 'writeHeaders' as the info section.
--
-- Nothing is written when there are no headers.  Otherwise the output is a
-- varint count followed by (key, value) entries.  Each string is written as
-- a varint character count followed by its 'B.string8' bytes.
--
-- NOTE(review): no info-type ID precedes the count ('infoIdKeyValue' is
-- unused), which matches 'parseInfo'.  Other THeader implementations appear
-- to expect one -- confirm before exchanging headers with non-Haskell peers.
buildInfo :: HeaderTransport i o -> B.Builder
buildInfo t
  -- TODO: check length limit
  | null kvs = mempty
  | otherwise = varint (length kvs) <> mconcat (map entry kvs)
  where
    kvs = Map.assocs $ writeHeaders t
    varint = buildVarint . iw16 . fromIntegral
    entry (k, v) = str k <> str v
    -- TODO: check length limit
    str s = varint (length s) <> B.string8 s
263 | ||
-- | Forget the detected framing and read the next message afresh.
--
-- Resets 'clientType' to 'HeaderClient' so the read goes through frame
-- detection, then returns the result of 'readFrame'.
tResetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool
tResetProtocol t = do
  writeIORef (clientType t) HeaderClient
  readFrame t
269 | ||
-- | Overwrite the current protocol type.  'buildHeader' writes it into
-- outgoing Header frames; a later read may replace it with the detected one.
tSetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> ProtocolType -> IO ()
tSetProtocol t pType = writeIORef (protocolType t) pType
272 | ||
-- | Apply 'writeTransforms' to an outgoing payload.
--
-- With no transforms the payload is returned unchanged.  No transform is
-- implemented yet: any entry makes the result throw 'TransportExn' when it
-- is forced.
transform :: HeaderTransport i o -> LBS.ByteString -> LBS.ByteString
transform t bs = foldr step bs (writeTransforms t)
  where
    -- TODO: implement ZlibTransform.
    step _trans _acc =
      throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN
281 | ||
-- | Undo the transforms announced by the last received Header frame.
--
-- With no transforms the payload is returned unchanged.  No transform is
-- implemented yet: any entry makes the returned value throw 'TransportExn'
-- when it is forced.
untransform :: HeaderTransport i o -> BS.ByteString -> IO BS.ByteString
untransform t bs = foldl step bs <$> readIORef (transforms t)
  where
    -- TODO: implement ZlibTransform.
    step _acc _trans =
      throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN
291 | ||
instance (Transport i, Transport o) => Transport (HeaderTransport i o) where
  -- Open only when both the input and the output transport report open.
  tIsOpen t = (&&) <$> tIsOpen (inTrans t) <*> tIsOpen (outTrans t)

  -- Close the output side first, then the input side.
  tClose t = do
    tClose (outTrans t)
    tClose (inTrans t)

  -- Serve bytes from the buffered message while any remain.  Once it is
  -- drained, read straight from the input when unframed; otherwise load the
  -- next frame, returning no bytes at end of input.
  tRead t len = do
    rBuf <- readIORef $ readBuffer t
    if not $ LBS.null rBuf
      then do
        let (consumed, remain) = LBS.splitAt (fromIntegral len) rBuf
        writeIORef (readBuffer t) remain
        return consumed
      else do
        framed <- isFramed t
        if not framed
          then tRead (inTrans t) len
          else do
            ok <- readFrame t
            if ok
              then tRead t len
              else return LBS.empty

  -- Same buffering rules as 'tRead', without consuming the byte.
  tPeek t = do
    rBuf <- readIORef (readBuffer t)
    if not $ LBS.null rBuf
      then return $ Just $ LBS.head rBuf
      else do
        framed <- isFramed t
        if not framed
          then tPeek (inTrans t)
          else do
            ok <- readFrame t
            if ok
              then tPeek t
              else return Nothing

  -- Framed modes accumulate output until 'tFlush'; unframed writes pass
  -- straight through to the output transport.
  tWrite t buf = do
    framed <- isFramed t
    if framed
      then modifyIORef' (writeBuffer t) (<> B.lazyByteString buf)
      -- TODO: what should we do when switched to unframed in the middle ?
      else tWrite (outTrans t) buf

  -- Unframed output was already written, so just flush the output.  Framed
  -- output is sent as one length-prefixed frame.  Header output gets a
  -- Header preamble, and its body has the outgoing transforms applied.
  tFlush t = do
    cType <- readIORef $ clientType t
    case cType of
      Unframed -> tFlush $ outTrans t
      Framed -> flushFrame mempty id
      HeaderClient -> buildHeader t >>= \header -> flushFrame header (transform t)
    where
      -- Send @header@ followed by the buffered body passed through @f@.
      -- Only the body is transformed, matching 'readFrame', which untransforms
      -- only the bytes after the header.  The length prefix is computed from
      -- the bytes actually written.
      flushFrame header f = do
        body <- readIORef $ writeBuffer t
        writeIORef (writeBuffer t) mempty
        let frame = B.toLazyByteString header <> f (B.toLazyByteString body)
        tWrite (outTrans t) $ Binary.encode (fromIntegral $ LBS.length frame :: Int32)
        tWrite (outTrans t) frame
        tFlush (outTrans t)