2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 package org
.apache
.arrow
.flight
.example
.integration
;
20 import java
.nio
.charset
.StandardCharsets
;
21 import java
.util
.Arrays
;
22 import java
.util
.Collections
;
24 import org
.apache
.arrow
.flight
.CallHeaders
;
25 import org
.apache
.arrow
.flight
.CallInfo
;
26 import org
.apache
.arrow
.flight
.CallStatus
;
27 import org
.apache
.arrow
.flight
.FlightClient
;
28 import org
.apache
.arrow
.flight
.FlightClientMiddleware
;
29 import org
.apache
.arrow
.flight
.FlightDescriptor
;
30 import org
.apache
.arrow
.flight
.FlightInfo
;
31 import org
.apache
.arrow
.flight
.FlightProducer
;
32 import org
.apache
.arrow
.flight
.FlightRuntimeException
;
33 import org
.apache
.arrow
.flight
.FlightServer
;
34 import org
.apache
.arrow
.flight
.FlightServerMiddleware
;
35 import org
.apache
.arrow
.flight
.Location
;
36 import org
.apache
.arrow
.flight
.NoOpFlightProducer
;
37 import org
.apache
.arrow
.flight
.RequestContext
;
38 import org
.apache
.arrow
.memory
.BufferAllocator
;
39 import org
.apache
.arrow
.vector
.types
.pojo
.Schema
;
/**
 * Test an edge case in middleware: gRPC-Java consolidates headers and trailers if a call fails immediately. On the
 * gRPC implementation side, we need to watch for this, or else we'll have a call with "no headers" if we only look
 * at headers and ignore trailers.
 */
46 final class MiddlewareScenario
implements Scenario
{
48 private static final String HEADER
= "x-middleware";
49 private static final String EXPECTED_HEADER_VALUE
= "expected value";
50 private static final byte[] COMMAND_SUCCESS
= "success".getBytes(StandardCharsets
.UTF_8
);
53 public FlightProducer
producer(BufferAllocator allocator
, Location location
) {
54 return new NoOpFlightProducer() {
56 public FlightInfo
getFlightInfo(CallContext context
, FlightDescriptor descriptor
) {
57 if (descriptor
.isCommand()) {
58 if (Arrays
.equals(COMMAND_SUCCESS
, descriptor
.getCommand())) {
59 return new FlightInfo(new Schema(Collections
.emptyList()), descriptor
, Collections
.emptyList(), -1, -1);
62 throw CallStatus
.UNIMPLEMENTED
.toRuntimeException();
68 public void buildServer(FlightServer
.Builder builder
) {
69 builder
.middleware(FlightServerMiddleware
.Key
.of("test"), new InjectingServerMiddleware
.Factory());
73 public void client(BufferAllocator allocator
, Location location
, FlightClient ignored
) throws Exception
{
74 final ExtractingClientMiddleware
.Factory factory
= new ExtractingClientMiddleware
.Factory();
75 try (final FlightClient client
= FlightClient
.builder(allocator
, location
).intercept(factory
).build()) {
76 // Should fail immediately
77 IntegrationAssertions
.assertThrows(FlightRuntimeException
.class,
78 () -> client
.getInfo(FlightDescriptor
.command(new byte[0])));
79 if (!EXPECTED_HEADER_VALUE
.equals(factory
.extractedHeader
)) {
80 throw new AssertionError(
81 "Expected to extract the header value '" +
82 EXPECTED_HEADER_VALUE
+
84 factory
.extractedHeader
);
88 factory
.extractedHeader
= "";
89 client
.getInfo(FlightDescriptor
.command(COMMAND_SUCCESS
));
90 if (!EXPECTED_HEADER_VALUE
.equals(factory
.extractedHeader
)) {
91 throw new AssertionError(
92 "Expected to extract the header value '" +
93 EXPECTED_HEADER_VALUE
+
95 factory
.extractedHeader
);
100 /** Middleware that inserts a constant value in outgoing requests. */
101 static class InjectingServerMiddleware
implements FlightServerMiddleware
{
103 private final String headerValue
;
105 InjectingServerMiddleware(String incoming
) {
106 this.headerValue
= incoming
;
110 public void onBeforeSendingHeaders(CallHeaders outgoingHeaders
) {
111 outgoingHeaders
.insert("x-middleware", headerValue
);
115 public void onCallCompleted(CallStatus status
) {
119 public void onCallErrored(Throwable err
) {
122 /** The factory for the server middleware. */
123 static class Factory
implements FlightServerMiddleware
.Factory
<InjectingServerMiddleware
> {
126 public InjectingServerMiddleware
onCallStarted(CallInfo info
, CallHeaders incomingHeaders
,
127 RequestContext context
) {
128 String incoming
= incomingHeaders
.get(HEADER
);
129 return new InjectingServerMiddleware(incoming
== null ?
"" : incoming
);
/** Client middleware that sends the expected header value on outgoing requests and records the value found in the response headers. */
135 static class ExtractingClientMiddleware
implements FlightClientMiddleware
{
/** Factory whose {@code extractedHeader} field receives the header value seen by this middleware. */
private final ExtractingClientMiddleware.Factory factory;

/**
 * Creates middleware that reports what it extracts back to its factory.
 *
 * @param owner the factory to write the extracted header value into
 */
public ExtractingClientMiddleware(ExtractingClientMiddleware.Factory owner) {
  this.factory = owner;
}
144 public void onBeforeSendingHeaders(CallHeaders outgoingHeaders
) {
145 outgoingHeaders
.insert(HEADER
, EXPECTED_HEADER_VALUE
);
149 public void onHeadersReceived(CallHeaders incomingHeaders
) {
150 this.factory
.extractedHeader
= incomingHeaders
.get(HEADER
);
154 public void onCallCompleted(CallStatus status
) {
157 /** The factory for the client middleware. */
158 static class Factory
implements FlightClientMiddleware
.Factory
{
160 String extractedHeader
= null;
163 public FlightClientMiddleware
onCallStarted(CallInfo info
) {
164 return new ExtractingClientMiddleware(this);