2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
10 * http://www.apache.org/licenses/LICENSE-2.0
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
24 public class THTTPSessionTransport: TAsyncTransport {
/// Creates `THTTPSessionTransport` instances that share a `URLSession`, an endpoint URL and an
/// optional response validator.
public class Factory : TAsyncTransportFactory {
  /// Optional hook called by `validateResponse(_:data:)`; any error it throws propagates to the caller.
  public var responseValidate: ((HTTPURLResponse?, Data?) throws -> Void)?

  /// Session used to create the data task for each request.
  var session: URLSession
  /// Endpoint URL that transports created by this factory send their requests to.
  var url: URL

  /// Applies Thrift defaults to `config`.
  ///
  /// Sets the request cache policy to ignore locally cached data, enables HTTP pipelining and
  /// cookie handling, and installs `Content-Type` / `Accept` headers of `application/x-thrift`
  /// (suffixed with `; p=<protocolName>` when a protocol name is given) plus a Thrift user agent.
  ///
  /// - Parameters:
  ///   - config: Configuration to modify; it is mutated directly.
  ///   - protocolName: Optional protocol identifier appended to the content type.
  public class func setupDefaultsForSessionConfiguration(_ config: URLSessionConfiguration, withProtocolName protocolName: String?) {
    var thriftContentType = "application/x-thrift"
    if let protocolName = protocolName {
      thriftContentType += "; p=\(protocolName)"
    }

    config.requestCachePolicy = .reloadIgnoringLocalCacheData
    config.httpShouldUsePipelining = true
    config.httpShouldSetCookies = true
    config.httpAdditionalHeaders = ["Content-Type": thriftContentType,
                                    "Accept": thriftContentType,
                                    "User-Agent": "Thrift/Swift (Session)"]
  }

  /// Creates a factory that issues requests through `session` to `url`.
  ///
  /// - Parameters:
  ///   - session: Session used to create data tasks.
  ///   - url: Endpoint URL for every request made by transports from this factory.
  public init(session: URLSession, url: URL) {
    self.session = session
    // Previously never assigned, leaving the stored `url` uninitialised.
    self.url = url
  }

  /// Returns a new transport bound to this factory.
  public func newTransport() -> THTTPSessionTransport {
    return THTTPSessionTransport(factory: self)
  }

  /// Runs `responseValidate`, if set, against a received response and its body.
  ///
  /// - Throws: Whatever `responseValidate` throws; does nothing when no validator is set.
  func validateResponse(_ response: HTTPURLResponse?, data: Data?) throws {
    try responseValidate?(response, data)
  }

  /// Creates a data task for `request` on this factory's session.
  ///
  /// The returned task is suspended; the caller must call `resume()` to send the request.
  /// `URLSession.dataTask(with:completionHandler:)` always returns a task, so this never
  /// actually throws; `throws` is kept so existing `try` call sites remain valid.
  func taskWithRequest(_ request: URLRequest, completionHandler: @escaping (Data?, URLResponse?, Error?) -> ()) throws -> URLSessionTask {
    return session.dataTask(with: request, completionHandler: completionHandler)
  }
}
/// Factory supplying the session, endpoint URL and response validation for this transport.
let factory: Factory

/// Bytes accumulated by `write(data:)`, sent as the body of the next flushed request.
var requestData = Data()
/// Body of the most recently received response, consumed by `read(size:)`.
var responseData = Data()
/// Number of bytes of `responseData` already handed out by `read(size:)`.
var responseDataOffset: Int = 0

/// Creates a transport bound to `factory`; the public entry point is `Factory.newTransport()`.
init(factory: Factory) {
  self.factory = factory
}
/// Reads exactly `size` bytes from the buffered response.
///
/// - Parameter size: Number of bytes required.
/// - Returns: The requested bytes.
/// - Throws: `TTransportError` with `.endOfFile` when fewer than `size` bytes are available.
public func readAll(size: Int) throws -> Data {
  let chunk = try read(size: size)
  guard chunk.count == size else {
    throw TTransportError(error: .endOfFile)
  }
  return chunk
}
/// Returns up to `size` bytes of the buffered response and advances the read position past them.
///
/// Fewer bytes (possibly none) are returned when less than `size` remains; a negative `size`
/// is treated as zero instead of forming an invalid range.
///
/// - Parameter size: Maximum number of bytes to return.
/// - Returns: The bytes read, in order.
public func read(size: Int) throws -> Data {
  let remaining = responseData.count - responseDataOffset
  // Clamp at zero: a negative length would build `start..<end` with `end < start` and trap.
  let length = max(0, min(size, remaining))
  // Offset from `startIndex`, not 0: a `Data` slice keeps its parent's indices.
  let start = responseData.startIndex + responseDataOffset
  let chunk = responseData.subdata(in: start..<(start + length))
  responseDataOffset += length
  return chunk
}
/// Buffers `data` at the end of the pending request body.
///
/// Nothing goes over the network here; the buffer becomes the HTTP body when `flush` runs.
///
/// - Parameter data: Bytes to append.
public func write(data: Data) throws {
  requestData.append(contentsOf: data)
}
/// Sends the buffered request bytes as an HTTP POST to the factory's URL and reports the result
/// asynchronously.
///
/// The pending request buffer is cleared as soon as the request is built, so later writes start
/// a new message. When the response arrives, the read position is reset and the read buffer is
/// replaced by the response body (or emptied if an error is reported).
///
/// `completed` is called exactly once: immediately if the data task cannot be created, otherwise
/// from the task's completion handler.
///
/// - Parameter completed: Receives this transport and the error to report, or `nil` on success.
public func flush(_ completed: @escaping (TAsyncTransport, Error?) -> Void) {
  var request = URLRequest(url: factory.url)
  request.httpMethod = "POST"
  request.httpBody = requestData

  // The payload now lives in `request`; without clearing, the next flush would resend it.
  requestData = Data()

  let task: URLSessionTask
  do {
    task = try factory.taskWithRequest(request, completionHandler: { (data, response, taskError) in
      let error = self.responseError(data: data, response: response, taskError: taskError)
      self.responseDataOffset = 0
      self.responseData = error == nil ? (data ?? Data()) : Data()
      // Single exit point, so the callback cannot fire more than once per response.
      completed(self, error)
    })
  } catch {
    // No task exists, so the handler above will never run; report the failure here instead.
    completed(self, error)
    return
  }

  // Data tasks are created suspended; the request is only sent once resumed.
  task.resume()
}

/// Translates the outcome of a data task into the error `flush` should report, or `nil` on success.
///
/// A network-level failure is reported as `TTransportError` `.timedOut`. A non-HTTP response is
/// reported as `.invalidResponse`, status 401 as `.authentication`, and any other non-200 status
/// as `.invalidStatus`. For HTTP responses the factory's validator then runs, and an error it
/// throws replaces the status-derived one.
private func responseError(data: Data?, response: URLResponse?, taskError: Error?) -> Error? {
  guard taskError == nil else {
    return TTransportError(error: .timedOut)
  }
  guard let httpResponse = response as? HTTPURLResponse else {
    return THTTPTransportError(error: .invalidResponse)
  }

  var error: Error?
  switch httpResponse.statusCode {
  case 200:
    break
  case 401:
    error = THTTPTransportError(error: .authentication)
  default:
    error = THTTPTransportError(error: .invalidStatus(statusCode: httpResponse.statusCode))
  }

  // Allow the factory to inspect the response; an error it throws takes precedence.
  do {
    try factory.validateResponse(httpResponse, data: data)
  } catch let validateError {
    error = validateError
  }
  return error
}
169 public func flush() throws {
170 let completed = DispatchSemaphore(value: 0)
171 var internalError: Error?
173 flush() { _, error in
174 internalError = error
178 _ = completed.wait(timeout: DispatchTime.distantFuture)
180 if let error = internalError {