]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / java / flight / flight-core / src / main / java / org / apache / arrow / flight / example / integration / MiddlewareScenario.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.arrow.flight.example.integration;
19
20 import java.nio.charset.StandardCharsets;
21 import java.util.Arrays;
22 import java.util.Collections;
23
24 import org.apache.arrow.flight.CallHeaders;
25 import org.apache.arrow.flight.CallInfo;
26 import org.apache.arrow.flight.CallStatus;
27 import org.apache.arrow.flight.FlightClient;
28 import org.apache.arrow.flight.FlightClientMiddleware;
29 import org.apache.arrow.flight.FlightDescriptor;
30 import org.apache.arrow.flight.FlightInfo;
31 import org.apache.arrow.flight.FlightProducer;
32 import org.apache.arrow.flight.FlightRuntimeException;
33 import org.apache.arrow.flight.FlightServer;
34 import org.apache.arrow.flight.FlightServerMiddleware;
35 import org.apache.arrow.flight.Location;
36 import org.apache.arrow.flight.NoOpFlightProducer;
37 import org.apache.arrow.flight.RequestContext;
38 import org.apache.arrow.memory.BufferAllocator;
39 import org.apache.arrow.vector.types.pojo.Schema;
40
41 /**
42 * Test an edge case in middleware: gRPC-Java consolidates headers and trailers if a call fails immediately. On the
43 * gRPC implementation side, we need to watch for this, or else we'll have a call with "no headers" if we only look
44 * for headers.
45 */
46 final class MiddlewareScenario implements Scenario {
47
48 private static final String HEADER = "x-middleware";
49 private static final String EXPECTED_HEADER_VALUE = "expected value";
50 private static final byte[] COMMAND_SUCCESS = "success".getBytes(StandardCharsets.UTF_8);
51
52 @Override
53 public FlightProducer producer(BufferAllocator allocator, Location location) {
54 return new NoOpFlightProducer() {
55 @Override
56 public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
57 if (descriptor.isCommand()) {
58 if (Arrays.equals(COMMAND_SUCCESS, descriptor.getCommand())) {
59 return new FlightInfo(new Schema(Collections.emptyList()), descriptor, Collections.emptyList(), -1, -1);
60 }
61 }
62 throw CallStatus.UNIMPLEMENTED.toRuntimeException();
63 }
64 };
65 }
66
67 @Override
68 public void buildServer(FlightServer.Builder builder) {
69 builder.middleware(FlightServerMiddleware.Key.of("test"), new InjectingServerMiddleware.Factory());
70 }
71
72 @Override
73 public void client(BufferAllocator allocator, Location location, FlightClient ignored) throws Exception {
74 final ExtractingClientMiddleware.Factory factory = new ExtractingClientMiddleware.Factory();
75 try (final FlightClient client = FlightClient.builder(allocator, location).intercept(factory).build()) {
76 // Should fail immediately
77 IntegrationAssertions.assertThrows(FlightRuntimeException.class,
78 () -> client.getInfo(FlightDescriptor.command(new byte[0])));
79 if (!EXPECTED_HEADER_VALUE.equals(factory.extractedHeader)) {
80 throw new AssertionError(
81 "Expected to extract the header value '" +
82 EXPECTED_HEADER_VALUE +
83 "', but found: " +
84 factory.extractedHeader);
85 }
86
87 // Should not fail
88 factory.extractedHeader = "";
89 client.getInfo(FlightDescriptor.command(COMMAND_SUCCESS));
90 if (!EXPECTED_HEADER_VALUE.equals(factory.extractedHeader)) {
91 throw new AssertionError(
92 "Expected to extract the header value '" +
93 EXPECTED_HEADER_VALUE +
94 "', but found: " +
95 factory.extractedHeader);
96 }
97 }
98 }
99
100 /** Middleware that inserts a constant value in outgoing requests. */
101 static class InjectingServerMiddleware implements FlightServerMiddleware {
102
103 private final String headerValue;
104
105 InjectingServerMiddleware(String incoming) {
106 this.headerValue = incoming;
107 }
108
109 @Override
110 public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
111 outgoingHeaders.insert("x-middleware", headerValue);
112 }
113
114 @Override
115 public void onCallCompleted(CallStatus status) {
116 }
117
118 @Override
119 public void onCallErrored(Throwable err) {
120 }
121
122 /** The factory for the server middleware. */
123 static class Factory implements FlightServerMiddleware.Factory<InjectingServerMiddleware> {
124
125 @Override
126 public InjectingServerMiddleware onCallStarted(CallInfo info, CallHeaders incomingHeaders,
127 RequestContext context) {
128 String incoming = incomingHeaders.get(HEADER);
129 return new InjectingServerMiddleware(incoming == null ? "" : incoming);
130 }
131 }
132 }
133
134 /** Middleware that pulls a value out of incoming responses. */
135 static class ExtractingClientMiddleware implements FlightClientMiddleware {
136
137 private final ExtractingClientMiddleware.Factory factory;
138
139 public ExtractingClientMiddleware(ExtractingClientMiddleware.Factory factory) {
140 this.factory = factory;
141 }
142
143 @Override
144 public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
145 outgoingHeaders.insert(HEADER, EXPECTED_HEADER_VALUE);
146 }
147
148 @Override
149 public void onHeadersReceived(CallHeaders incomingHeaders) {
150 this.factory.extractedHeader = incomingHeaders.get(HEADER);
151 }
152
153 @Override
154 public void onCallCompleted(CallStatus status) {
155 }
156
157 /** The factory for the client middleware. */
158 static class Factory implements FlightClientMiddleware.Factory {
159
160 String extractedHeader = null;
161
162 @Override
163 public FlightClientMiddleware onCallStarted(CallInfo info) {
164 return new ExtractingClientMiddleware(this);
165 }
166 }
167 }
168 }