diff --git a/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/AbstractDirectGraphQlTransport.java b/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/AbstractDirectGraphQlTransport.java index 359587f13..bdd3c6b54 100644 --- a/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/AbstractDirectGraphQlTransport.java +++ b/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/AbstractDirectGraphQlTransport.java @@ -73,6 +73,11 @@ public Flux executeSubscription(GraphQlRequest request) { }); } + @Override + public Mono executeFileUpload(GraphQlRequest request) { + throw new UnsupportedOperationException("File upload is not supported"); + } + private ExecutionGraphQlRequest toExecutionRequest(GraphQlRequest request) { return new DefaultExecutionGraphQlRequest( request.getDocument(), request.getOperationName(), request.getVariables(), request.getExtensions(), diff --git a/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/AbstractGraphQlTesterBuilder.java b/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/AbstractGraphQlTesterBuilder.java index 287770ea7..5a8ed8e15 100644 --- a/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/AbstractGraphQlTesterBuilder.java +++ b/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/AbstractGraphQlTesterBuilder.java @@ -176,7 +176,12 @@ public Flux executeSubscription(GraphQlRequest request) { .executeSubscription() .cast(GraphQlResponse.class); } - }; + + @Override + public Mono executeFileUpload(GraphQlRequest request) { + throw new UnsupportedOperationException("File upload is not supported"); + } + }; } diff --git a/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/DefaultGraphQlTester.java b/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/DefaultGraphQlTester.java index 36157ce0a..cc1327b20 100644 --- a/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/DefaultGraphQlTester.java +++ b/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/DefaultGraphQlTester.java @@ -18,11 +18,7 @@ import java.lang.reflect.Type; import java.time.Duration; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.function.Consumer; import java.util.function.Predicate; import java.util.function.Supplier; @@ -38,6 +34,7 @@ import org.springframework.graphql.GraphQlResponse; import org.springframework.graphql.ResponseError; import org.springframework.graphql.client.GraphQlTransport; +import org.springframework.graphql.client.MultipartClientGraphQlRequest; import org.springframework.graphql.support.DefaultGraphQlRequest; import org.springframework.graphql.support.DocumentSource; import org.springframework.lang.Nullable; @@ -127,7 +124,9 @@ private final class DefaultRequest implements Request { private final Map extensions = new LinkedHashMap<>(); - private DefaultRequest(String document) { + private final Map fileVariables = new LinkedHashMap<>(); + + private DefaultRequest(String document) { Assert.notNull(document, "`document` is required"); this.document = document; } @@ -144,7 +143,19 @@ public DefaultRequest variable(String name, @Nullable Object value) { return this; } - @Override + @Override + public DefaultRequest fileVariable(String name, Object value) { + this.fileVariables.put(name, value); + return this; + } + + @Override + public DefaultRequest fileVariables(Map variables) { + this.fileVariables.putAll(variables); + return this; + } + + @Override public DefaultRequest extension(String name, Object value) { this.extensions.put(name, value); return this; @@ -156,6 +167,16 @@ public Response execute() { return transport.execute(request()).map(response -> mapResponse(response, request())).block(responseTimeout); } + @Override + public Response executeFileUpload() { + return transport.executeFileUpload(requestFileUpload()).map(response -> mapResponse(response, requestFileUpload())).block(responseTimeout); + } + + @Override + public void executeFileUploadAndVerify() { + executeFileUpload().path("$.errors").pathDoesNotExist(); + } + @Override public void executeAndVerify() { execute().path("$.errors").pathDoesNotExist(); @@ -170,6 +191,10 @@ private GraphQlRequest request() { return new DefaultGraphQlRequest(this.document, this.operationName, this.variables, this.extensions); } + private GraphQlRequest requestFileUpload() { + return new MultipartClientGraphQlRequest(this.document, this.operationName, this.variables, this.extensions, new HashMap<>(), this.fileVariables); + } + private DefaultResponse mapResponse(GraphQlResponse response, GraphQlRequest request) { return new DefaultResponse(response, errorFilter, assertDecorator(request), jsonPathConfig); } diff --git a/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/GraphQlTester.java b/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/GraphQlTester.java index 203a25b67..acdbd3f64 100644 --- a/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/GraphQlTester.java +++ b/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/GraphQlTester.java @@ -18,9 +18,11 @@ import java.time.Duration; import java.util.List; +import java.util.Map; import java.util.function.Consumer; import java.util.function.Predicate; +import org.springframework.graphql.client.GraphQlClient; import reactor.core.publisher.Flux; import org.springframework.core.ParameterizedTypeReference; @@ -149,6 +151,10 @@ interface Request> { */ T variable(String name, @Nullable Object value); + T fileVariable(String name, Object value); + + T fileVariables(Map variables); + /** * Add a value for a protocol extension. * @param name the protocol extension name @@ -166,7 +172,9 @@ interface Request> { */ Response execute(); - /** + void executeFileUploadAndVerify(); + + /** * Execute the GraphQL request and verify the response contains no errors. */ void executeAndVerify(); @@ -180,6 +188,8 @@ interface Request> { */ Subscription executeSubscription(); + Response executeFileUpload(); + } /** diff --git a/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/WebTestClientTransport.java b/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/WebTestClientTransport.java index 904697599..21fe4aa74 100644 --- a/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/WebTestClientTransport.java +++ b/spring-graphql-test/src/main/java/org/springframework/graphql/test/tester/WebTestClientTransport.java @@ -19,6 +19,9 @@ import java.util.Collections; import java.util.Map; +import org.springframework.http.client.MultipartBodyBuilder; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.reactive.function.BodyInserters; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -30,6 +33,8 @@ import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.util.Assert; +import static org.springframework.graphql.client.MultipartBodyCreator.convertRequestToMultipartData; + /** * {@code GraphQlTransport} for GraphQL over HTTP via {@link WebTestClient}. * @@ -70,6 +75,25 @@ public Mono execute(GraphQlRequest request) { return Mono.just(response); } + @Override + public Mono executeFileUpload(GraphQlRequest request) { + + Map responseMap = this.webTestClient.post() + .contentType(MediaType.MULTIPART_FORM_DATA) + .accept(MediaType.APPLICATION_JSON) + .body(BodyInserters.fromMultipartData(convertRequestToMultipartData(request))) + .exchange() + .expectStatus().isOk() + .expectHeader().contentTypeCompatibleWith(MediaType.APPLICATION_JSON) + .expectBody(MAP_TYPE) + .returnResult() + .getResponseBody(); + + responseMap = (responseMap != null ? responseMap : Collections.emptyMap()); + GraphQlResponse response = GraphQlTransport.createResponse(responseMap); + return Mono.just(response); + } + @Override public Flux executeSubscription(GraphQlRequest request) { throw new UnsupportedOperationException("Subscriptions not supported over HTTP"); diff --git a/spring-graphql-test/src/test/java/org/springframework/graphql/test/tester/HttpGraphQlTesterTests.java b/spring-graphql-test/src/test/java/org/springframework/graphql/test/tester/HttpGraphQlTesterTests.java new file mode 100644 index 000000000..d005bb894 --- /dev/null +++ b/spring-graphql-test/src/test/java/org/springframework/graphql/test/tester/HttpGraphQlTesterTests.java @@ -0,0 +1,67 @@ +package org.springframework.graphql.test.tester; + +import org.junit.jupiter.api.Test; +import org.springframework.core.io.ClassPathResource; +import org.springframework.graphql.server.webflux.GraphQlHttpHandler; +import org.springframework.http.codec.multipart.FilePart; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.ServerResponse; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.web.reactive.function.server.RouterFunctions.route; + +public class HttpGraphQlTesterTests { + + private static final String DOCUMENT = "{ Mutation }"; + + @Test + void shouldSendOneFile() { + MultipartHttpBuilderSetup testerSetup = new MultipartHttpBuilderSetup(); + + HttpGraphQlTester.Builder builder = testerSetup.initBuilder(); + HttpGraphQlTester tester = builder.build(); + tester.document(DOCUMENT) + .variable("existingVar", "itsValue") + .fileVariable("fileInput", new ClassPathResource("/foo.txt")) + .executeFileUpload(); + assertThat(testerSetup.getWebGraphQlRequest().getVariables().get("existingVar")).isEqualTo("itsValue"); + assertThat(testerSetup.getWebGraphQlRequest().getVariables().get("fileInput")).isNotNull(); + assertThat(((FilePart)testerSetup.getWebGraphQlRequest().getVariables().get("fileInput")).filename()).isEqualTo("foo.txt"); + } + + @Test + void shouldSendOneCollectionOfFiles() { + MultipartHttpBuilderSetup testerSetup = new MultipartHttpBuilderSetup(); + + HttpGraphQlTester.Builder builder = testerSetup.initBuilder(); + HttpGraphQlTester tester = builder.build(); + List resources = new ArrayList<>(); + resources.add(new ClassPathResource("/foo.txt")); + resources.add(new ClassPathResource("/bar.txt")); + tester.document(DOCUMENT) + .variable("existingVar", "itsValue") + .fileVariable("filesInput", resources) + .executeFileUpload(); + assertThat(testerSetup.getWebGraphQlRequest().getVariables().get("existingVar")).isEqualTo("itsValue"); + assertThat(testerSetup.getWebGraphQlRequest().getVariables().get("filesInput")).isNotNull(); + assertThat(((Collection)testerSetup.getWebGraphQlRequest().getVariables().get("filesInput")).size()).isEqualTo(2); + assertThat(((Collection)testerSetup.getWebGraphQlRequest().getVariables().get("filesInput")).stream().map(filePart -> filePart.filename()).collect(Collectors.toSet())).contains("foo.txt", "bar.txt"); + } + + private static class MultipartHttpBuilderSetup extends WebGraphQlTesterBuilderTests.WebBuilderSetup { + + @Override + public HttpGraphQlTester.Builder initBuilder() { + GraphQlHttpHandler handler = new GraphQlHttpHandler(webGraphQlHandler()); + RouterFunction routerFunction = route().POST("/**", handler::handleMultipartRequest).build(); + return HttpGraphQlTester.builder(WebTestClient.bindToRouterFunction(routerFunction).configureClient()); + } + + } +} diff --git a/spring-graphql-test/src/test/java/org/springframework/graphql/test/tester/WebGraphQlTesterBuilderTests.java b/spring-graphql-test/src/test/java/org/springframework/graphql/test/tester/WebGraphQlTesterBuilderTests.java index d2340015e..c9ada364c 100644 --- a/spring-graphql-test/src/test/java/org/springframework/graphql/test/tester/WebGraphQlTesterBuilderTests.java +++ b/spring-graphql-test/src/test/java/org/springframework/graphql/test/tester/WebGraphQlTesterBuilderTests.java @@ -204,7 +204,7 @@ private interface TesterBuilderSetup { } - private static class WebBuilderSetup implements TesterBuilderSetup { + static class WebBuilderSetup implements TesterBuilderSetup { @Nullable private WebGraphQlRequest request; diff --git a/spring-graphql-test/src/test/resources/bar.txt b/spring-graphql-test/src/test/resources/bar.txt new file mode 100644 index 000000000..e00821ec8 --- /dev/null +++ b/spring-graphql-test/src/test/resources/bar.txt @@ -0,0 +1 @@ +hello from bar here! \ No newline at end of file diff --git a/spring-graphql-test/src/test/resources/foo.txt b/spring-graphql-test/src/test/resources/foo.txt new file mode 100644 index 000000000..6e0f704fe --- /dev/null +++ b/spring-graphql-test/src/test/resources/foo.txt @@ -0,0 +1 @@ +hello here! \ No newline at end of file diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/AbstractGraphQlClientBuilder.java b/spring-graphql/src/main/java/org/springframework/graphql/client/AbstractGraphQlClientBuilder.java index 9ecb158a4..b4b04c140 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/client/AbstractGraphQlClientBuilder.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/AbstractGraphQlClientBuilder.java @@ -170,7 +170,7 @@ protected GraphQlClient buildGraphQlClient(GraphQlTransport transport) { } return new DefaultGraphQlClient( - this.documentSource, createExecuteChain(transport), createExecuteSubscriptionChain(transport)); + this.documentSource, createExecuteChain(transport), createFileUploadChain(transport), createExecuteSubscriptionChain(transport)); } /** @@ -195,7 +195,18 @@ private Chain createExecuteChain(GraphQlTransport transport) { .orElse(chain); } - private SubscriptionChain createExecuteSubscriptionChain(GraphQlTransport transport) { + private Chain createFileUploadChain(GraphQlTransport transport) { + + Chain chain = request -> transport.executeFileUpload(request).map(response -> + new DefaultClientGraphQlResponse(request, response, getEncoder(), getDecoder())); + + return this.interceptors.stream() + .reduce(GraphQlClientInterceptor::andThen) + .map(interceptor -> (Chain) (request) -> interceptor.intercept(request, chain)) + .orElse(chain); + } + + private SubscriptionChain createExecuteSubscriptionChain(GraphQlTransport transport) { SubscriptionChain chain = request -> transport.executeSubscription(request) .map(response -> new DefaultClientGraphQlResponse(request, response, getEncoder(), getDecoder())); diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultClientGraphQlRequest.java b/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultClientGraphQlRequest.java index ad576e45b..e4908ef13 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultClientGraphQlRequest.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultClientGraphQlRequest.java @@ -29,7 +29,7 @@ * @author Rossen Stoyanchev * @since 1.0.0 */ -final class DefaultClientGraphQlRequest extends DefaultGraphQlRequest implements ClientGraphQlRequest { +class DefaultClientGraphQlRequest extends DefaultGraphQlRequest implements ClientGraphQlRequest { private final Map attributes = new ConcurrentHashMap<>(); diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultGraphQlClient.java b/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultGraphQlClient.java index 929578c55..e8b5e1083 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultGraphQlClient.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultGraphQlClient.java @@ -42,20 +42,25 @@ final class DefaultGraphQlClient implements GraphQlClient { private final GraphQlClientInterceptor.Chain executeChain; - private final GraphQlClientInterceptor.SubscriptionChain executeSubscriptionChain; + private final GraphQlClientInterceptor.Chain fileUploadChain; + + private final GraphQlClientInterceptor.SubscriptionChain executeSubscriptionChain; DefaultGraphQlClient( DocumentSource documentSource, GraphQlClientInterceptor.Chain executeChain, + GraphQlClientInterceptor.Chain fileUploadChain, GraphQlClientInterceptor.SubscriptionChain executeSubscriptionChain) { Assert.notNull(documentSource, "DocumentSource is required"); Assert.notNull(executeChain, "GraphQlClientInterceptor.Chain is required"); - Assert.notNull(executeSubscriptionChain, "GraphQlClientInterceptor.SubscriptionChain is required"); + Assert.notNull(fileUploadChain, "GraphQlClientInterceptor.Chain is required"); + Assert.notNull(executeSubscriptionChain, "GraphQlClientInterceptor.SubscriptionChain is required"); this.documentSource = documentSource; this.executeChain = executeChain; - this.executeSubscriptionChain = executeSubscriptionChain; + this.fileUploadChain = fileUploadChain; + this.executeSubscriptionChain = executeSubscriptionChain; } @@ -96,7 +101,9 @@ private final class DefaultRequestSpec implements RequestSpec { private final Map extensions = new LinkedHashMap<>(); - DefaultRequestSpec(Mono documentMono) { + private final Map fileVariables = new LinkedHashMap<>(); + + DefaultRequestSpec(Mono documentMono) { Assert.notNull(documentMono, "'document' is required"); this.documentMono = documentMono; } @@ -119,6 +126,20 @@ public RequestSpec variables(Map variables) { return this; } + @Override + public DefaultRequestSpec fileVariable(String name, Object value) { + Assert.notNull(name, "'name' is required"); + Assert.notNull(value, "'value' is required"); + this.fileVariables.put(name, value); + return this; + } + + @Override + public RequestSpec fileVariables(Map files) { + this.fileVariables.putAll(files); + return this; + } + @Override public RequestSpec extension(String name, Object value) { this.extensions.put(name, value); @@ -161,6 +182,14 @@ public Mono execute() { ex -> Mono.error(new GraphQlTransportException(ex, request)))); } + @Override + public Mono executeFileUpload() { + return initFileUploadRequest().flatMap(request -> fileUploadChain.next(request) + .onErrorResume( + ex -> !(ex instanceof GraphQlClientException), + ex -> Mono.error(new GraphQlTransportException(ex, request)))); + } + @Override public Flux executeSubscription() { return initRequest().flatMapMany(request -> executeSubscriptionChain.next(request) @@ -174,6 +203,11 @@ private Mono initRequest() { new DefaultClientGraphQlRequest(document, this.operationName, this.variables, this.extensions, this.attributes)); } + private Mono initFileUploadRequest() { + return this.documentMono.map(document -> + new MultipartClientGraphQlRequest(document, this.operationName, this.variables, this.extensions, this.attributes, this.fileVariables)); + } + } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/GraphQlClient.java b/spring-graphql/src/main/java/org/springframework/graphql/client/GraphQlClient.java index 42c6e49d5..7c18057c2 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/client/GraphQlClient.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/GraphQlClient.java @@ -151,6 +151,10 @@ interface RequestSpec { */ RequestSpec variables(Map variables); + RequestSpec fileVariable(String name, Object value); + + RequestSpec fileVariables(Map variables); + /** * Add a value for a protocol extension. * @param name the protocol extension name @@ -207,7 +211,9 @@ interface RequestSpec { */ Mono execute(); - /** + Mono executeFileUpload(); + + /** * Execute a "subscription" request and return a stream of responses. * @return a {@code Flux} with responses that provide further options for * decoding of each response. The {@code Flux} may terminate as follows: diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/GraphQlTransport.java b/spring-graphql/src/main/java/org/springframework/graphql/client/GraphQlTransport.java index 44b0a7a70..9f9367e93 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/client/GraphQlTransport.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/GraphQlTransport.java @@ -59,6 +59,7 @@ public interface GraphQlTransport { */ Flux executeSubscription(GraphQlRequest request); + Mono executeFileUpload(GraphQlRequest request); /** * Factory method to create {@link GraphQlResponse} from a GraphQL response diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/HttpGraphQlTransport.java b/spring-graphql/src/main/java/org/springframework/graphql/client/HttpGraphQlTransport.java index c82f8bbc5..eb8e352f8 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/client/HttpGraphQlTransport.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/HttpGraphQlTransport.java @@ -16,8 +16,14 @@ package org.springframework.graphql.client; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; import java.util.Map; +import org.springframework.http.client.MultipartBodyBuilder; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyInserters; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -29,6 +35,8 @@ import org.springframework.util.Assert; import org.springframework.web.reactive.function.client.WebClient; +import static org.springframework.graphql.client.MultipartBodyCreator.convertRequestToMultipartData; + /** * Transport to execute GraphQL requests over HTTP via {@link WebClient}. @@ -75,7 +83,19 @@ public Mono execute(GraphQlRequest request) { .map(ResponseMapGraphQlResponse::new); } - @Override + @Override + public Mono executeFileUpload(GraphQlRequest request) { + return this.webClient.post() + .contentType(MediaType.MULTIPART_FORM_DATA) + .accept(MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL) + .body(BodyInserters.fromMultipartData(convertRequestToMultipartData(request))) + .retrieve() + .bodyToMono(MAP_TYPE) + .map(ResponseMapGraphQlResponse::new); + } + + + @Override public Flux executeSubscription(GraphQlRequest request) { throw new UnsupportedOperationException("Subscriptions not supported over HTTP"); } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/MultipartBodyCreator.java b/spring-graphql/src/main/java/org/springframework/graphql/client/MultipartBodyCreator.java new file mode 100644 index 000000000..4f3fe5b05 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/MultipartBodyCreator.java @@ -0,0 +1,60 @@ +package org.springframework.graphql.client; + +import org.springframework.graphql.GraphQlRequest; +import org.springframework.http.client.MultipartBodyBuilder; +import org.springframework.util.MultiValueMap; + +import java.util.*; +import java.util.function.BiConsumer; + +public final class MultipartBodyCreator { + + public static MultiValueMap convertRequestToMultipartData(GraphQlRequest request) { + MultipartClientGraphQlRequest multipartRequest = (MultipartClientGraphQlRequest) request; + MultipartBodyBuilder builder = new MultipartBodyBuilder(); + + Map> partMappings = new HashMap<>(); + Map operations = multipartRequest.toMap(); + Map variables = new HashMap<>(multipartRequest.getVariables()); + createFilePartsAndMapping(multipartRequest.getFileVariables(), variables, partMappings, builder::part); + operations.put("variables", variables); + builder.part("operations", operations); + + builder.part("map", partMappings); + return builder.build(); + } + + public static void createFilePartsAndMapping( + Map fileVariables, + Map variables, + Map> partMappings, + BiConsumer partConsumer) { + int partNumber = 0; + for (Map.Entry entry : fileVariables.entrySet()) { + Object resource = entry.getValue(); + String variableName = entry.getKey(); + if (resource instanceof Collection) { + List placeholders = new ArrayList<>(); + int inMappingNumber = 0; + for (Object fileResourceItem: (Collection)resource) { + placeholders.add(null); + String partName = "uploadPart" + partNumber; + partConsumer.accept(partName, fileResourceItem); + partMappings.put(partName, Collections.singletonList( + "variables." + variableName + "." + inMappingNumber + )); + partNumber++; + inMappingNumber++; + } + variables.put(variableName, placeholders); + } else { + String partName = "uploadPart" + partNumber; + partConsumer.accept(partName, resource); + variables.put(variableName, null); + partMappings.put(partName, Collections.singletonList("variables." + variableName)); + partNumber++; + } + } + } + +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/MultipartClientGraphQlRequest.java b/spring-graphql/src/main/java/org/springframework/graphql/client/MultipartClientGraphQlRequest.java new file mode 100644 index 000000000..b5375c931 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/MultipartClientGraphQlRequest.java @@ -0,0 +1,49 @@ +/* + * Copyright 2020-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.client; + + +import org.springframework.lang.Nullable; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Multipart implementation of {@link ClientGraphQlRequest}. + * + * @author Nikita Konev + * @since 1.0.0 + */ +public final class MultipartClientGraphQlRequest extends DefaultClientGraphQlRequest implements ClientGraphQlRequest { + + private final Map fileVariables = new ConcurrentHashMap<>(); + + public MultipartClientGraphQlRequest( + String document, @Nullable String operationName, + Map variables, Map extensions, + Map attributes, + Map fileVariables) { + + super(document, operationName, variables, extensions, attributes); + this.fileVariables.putAll(fileVariables); + } + + public Map getFileVariables() { + return this.fileVariables; + } + +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/RSocketGraphQlTransport.java b/spring-graphql/src/main/java/org/springframework/graphql/client/RSocketGraphQlTransport.java index 324c1fedc..f79e56528 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/client/RSocketGraphQlTransport.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/RSocketGraphQlTransport.java @@ -87,6 +87,11 @@ public Flux executeSubscription(GraphQlRequest request) { .map(ResponseMapGraphQlResponse::new); } + @Override + public Mono executeFileUpload(GraphQlRequest request) { + throw new UnsupportedOperationException("File upload is not supported"); + } + @SuppressWarnings("unchecked") private Exception decodeErrors(GraphQlRequest request, RejectedException ex) { try { diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlTransport.java b/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlTransport.java index f89a00417..30a4390a2 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlTransport.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlTransport.java @@ -154,6 +154,11 @@ public Flux executeSubscription(GraphQlRequest request) { return this.graphQlSessionMono.flatMapMany(session -> session.executeSubscription(request)); } + @Override + public Mono executeFileUpload(GraphQlRequest request) { + throw new UnsupportedOperationException("File upload is not supported"); + } + /** * Client {@code WebSocketHandler} for GraphQL that deals with WebSocket diff --git a/spring-graphql/src/main/java/org/springframework/graphql/coercing/webflux/UploadCoercing.java b/spring-graphql/src/main/java/org/springframework/graphql/coercing/webflux/UploadCoercing.java new file mode 100644 index 000000000..ea7f5dea0 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/coercing/webflux/UploadCoercing.java @@ -0,0 +1,32 @@ +package org.springframework.graphql.coercing.webflux; + +import graphql.schema.Coercing; +import graphql.schema.CoercingParseLiteralException; +import graphql.schema.CoercingParseValueException; +import graphql.schema.CoercingSerializeException; +import org.springframework.http.codec.multipart.FilePart; + +public class UploadCoercing implements Coercing { + + @Override + public Object serialize(Object dataFetcherResult) throws CoercingSerializeException { + throw new CoercingSerializeException("Upload is an input-only type"); + } + + @Override + public FilePart parseValue(Object input) throws CoercingParseValueException { + if (input instanceof FilePart) { + return (FilePart) input; + } + throw new CoercingParseValueException( + String.format("Expected 'FilePart' like object but was '%s'.", + input != null ? input.getClass() : null) + ); + } + + @Override + public FilePart parseLiteral(Object input) throws CoercingParseLiteralException { + throw new CoercingParseLiteralException("Parsing literal of 'MultipartFile' is not supported"); + } +} + diff --git a/spring-graphql/src/main/java/org/springframework/graphql/coercing/webmvc/UploadCoercing.java b/spring-graphql/src/main/java/org/springframework/graphql/coercing/webmvc/UploadCoercing.java new file mode 100644 index 000000000..80e43a039 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/coercing/webmvc/UploadCoercing.java @@ -0,0 +1,31 @@ +package org.springframework.graphql.coercing.webmvc; + +import graphql.schema.Coercing; +import graphql.schema.CoercingParseLiteralException; +import graphql.schema.CoercingParseValueException; +import graphql.schema.CoercingSerializeException; +import org.springframework.web.multipart.MultipartFile; + +public class UploadCoercing implements Coercing { + + @Override + public Object serialize(Object dataFetcherResult) throws CoercingSerializeException { + throw new CoercingSerializeException("Upload is an input-only type"); + } + + @Override + public MultipartFile parseValue(Object input) throws CoercingParseValueException { + if (input instanceof MultipartFile) { + return (MultipartFile) input; + } + throw new CoercingParseValueException( + String.format("Expected a 'MultipartFile' like object but was '%s'.", + input != null ? input.getClass() : null) + ); + } + + @Override + public MultipartFile parseLiteral(Object input) throws CoercingParseLiteralException { + throw new CoercingParseLiteralException("Parsing literal of 'MultipartFile' is not supported"); + } +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/support/MultipartVariableMapper.java b/spring-graphql/src/main/java/org/springframework/graphql/server/support/MultipartVariableMapper.java new file mode 100644 index 000000000..8aca83958 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/support/MultipartVariableMapper.java @@ -0,0 +1,87 @@ +package org.springframework.graphql.server.support; + +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * Maps http request's file to GraphQL variables. + * Borrowed from graphql-java-kickstart. + */ +public class MultipartVariableMapper { + + private static final Pattern PERIOD = Pattern.compile("\\."); + + private static final MultipartVariableMapper.Mapper> MAP_MAPPER = + new MultipartVariableMapper.Mapper>() { + @Override + public

Object set(Map location, String target, P value) { + return location.put(target, value); + } + + @Override + public Object recurse(Map location, String target) { + return location.get(target); + } + }; + private static final MultipartVariableMapper.Mapper> LIST_MAPPER = + new MultipartVariableMapper.Mapper>() { + @Override + public

Object set(List location, String target, P value) { + return location.set(Integer.parseInt(target), value); + } + + @Override + public Object recurse(List location, String target) { + return location.get(Integer.parseInt(target)); + } + }; + + @SuppressWarnings({"unchecked", "rawtypes"}) + public static

void mapVariable(String objectPath, Map variables, P file) { + String[] segments = PERIOD.split(objectPath); + + if (segments.length < 2) { + throw new RuntimeException("object-path in map must have at least two segments"); + } else if (!"variables".equals(segments[0])) { + throw new RuntimeException("can only map into variables"); + } + + Object currentLocation = variables; + for (int i = 1; i < segments.length; i++) { + String segmentName = segments[i]; + MultipartVariableMapper.Mapper mapper = determineMapper(currentLocation, objectPath, segmentName); + + if (i == segments.length - 1) { + if (null != mapper.set(currentLocation, segmentName, file)) { + throw new RuntimeException("expected null value when mapping " + objectPath); + } + } else { + currentLocation = mapper.recurse(currentLocation, segmentName); + if (null == currentLocation) { + throw new RuntimeException( + "found null intermediate value when trying to map " + objectPath); + } + } + } + } + + private static MultipartVariableMapper.Mapper determineMapper( + Object currentLocation, String objectPath, String segmentName) { + if (currentLocation instanceof Map) { + return MAP_MAPPER; + } else if (currentLocation instanceof List) { + return LIST_MAPPER; + } + + throw new RuntimeException( + "expected a map or list at " + segmentName + " when trying to map " + objectPath); + } + + interface Mapper { + +

Object set(T location, String target, P value); + + Object recurse(T location, String target); + } +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlHttpHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlHttpHandler.java index e3bccd296..a22c17d59 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlHttpHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlHttpHandler.java @@ -17,11 +17,20 @@ package org.springframework.graphql.server.webflux; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Decoder; +import org.springframework.graphql.server.support.MultipartVariableMapper; +import org.springframework.http.codec.json.Jackson2JsonDecoder; +import org.springframework.http.codec.multipart.FilePart; +import org.springframework.http.codec.multipart.Part; +import org.springframework.util.StringUtils; import reactor.core.publisher.Mono; import org.springframework.core.ParameterizedTypeReference; @@ -31,12 +40,14 @@ import org.springframework.util.Assert; import org.springframework.web.reactive.function.server.ServerRequest; import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.util.function.Tuple2; /** * WebFlux.fn Handler for GraphQL over HTTP requests. * * @author Rossen Stoyanchev * @author Brian Clozel + * @author Nikita Konev * @since 1.0.0 */ public class GraphQlHttpHandler { @@ -46,11 +57,16 @@ public class GraphQlHttpHandler { private static final ParameterizedTypeReference> MAP_PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference>() {}; - private static final List SUPPORTED_MEDIA_TYPES = + private static final ParameterizedTypeReference>> LIST_PARAMETERIZED_TYPE_REF = + new ParameterizedTypeReference>>() {}; + + private static final List SUPPORTED_MEDIA_TYPES = Arrays.asList(MediaType.APPLICATION_GRAPHQL, MediaType.APPLICATION_JSON); private final WebGraphQlHandler graphQlHandler; + private final Decoder jsonDecoder; + /** * Create a new instance. * @param graphQlHandler common handler for GraphQL over HTTP requests @@ -58,8 +74,16 @@ public class GraphQlHttpHandler { public GraphQlHttpHandler(WebGraphQlHandler graphQlHandler) { Assert.notNull(graphQlHandler, "WebGraphQlHandler is required"); this.graphQlHandler = graphQlHandler; + this.jsonDecoder = new Jackson2JsonDecoder(); } + public GraphQlHttpHandler(WebGraphQlHandler graphQlHandler, Decoder jsonDecoder) { + Assert.notNull(graphQlHandler, "WebGraphQlHandler is required"); + Assert.notNull(jsonDecoder, "Decoder is required"); + this.graphQlHandler = graphQlHandler; + this.jsonDecoder = jsonDecoder; + } + /** * Handle GraphQL requests over HTTP. * @param serverRequest the incoming HTTP request @@ -88,6 +112,91 @@ public Mono handleRequest(ServerRequest serverRequest) { }); } + @SuppressWarnings("unchecked") + public Mono handleMultipartRequest(ServerRequest serverRequest) { + return serverRequest.multipartData() + .flatMap(multipartMultiMap -> { + Map allParts = multipartMultiMap.toSingleValueMap(); + + Optional operation = Optional.ofNullable(allParts.get("operations")); + Optional mapParam = Optional.ofNullable(allParts.get("map")); + + Decoder> mapJsonDecoder = (Decoder>) jsonDecoder; + Decoder>> listJsonDecoder = (Decoder>>) jsonDecoder; + + Mono> inputQueryMono = operation + .map(part -> mapJsonDecoder.decodeToMono( + part.content(), ResolvableType.forType(MAP_PARAMETERIZED_TYPE_REF), + MediaType.APPLICATION_JSON, null + )).orElse(Mono.just(new HashMap<>())); + + Mono>> fileMapInputMono = mapParam + .map(part -> listJsonDecoder.decodeToMono(part.content(), + ResolvableType.forType(LIST_PARAMETERIZED_TYPE_REF), + MediaType.APPLICATION_JSON, null + )).orElse(Mono.just(new HashMap<>())); + + return Mono.zip(inputQueryMono, fileMapInputMono) + .flatMap((Tuple2, Map>> objects) -> { + Map inputQuery = objects.getT1(); + Map> fileMapInput = objects.getT2(); + + final Map queryVariables = getFromMapOrEmpty(inputQuery, "variables"); + final Map extensions = getFromMapOrEmpty(inputQuery, "extensions"); + + fileMapInput.forEach((String fileKey, List objectPaths) -> { + Part part = allParts.get(fileKey); + if (part != null) { + Assert.isInstanceOf(FilePart.class, part, "Part should be of type FilePart"); + FilePart file = (FilePart) part; + objectPaths.forEach((String objectPath) -> { + MultipartVariableMapper.mapVariable( + objectPath, + queryVariables, + file + ); + }); + } + }); + + String query = (String) inputQuery.get("query"); + String opName = (String) inputQuery.get("operationName"); + + Map body = Map.of( + "query", query, "operationName", StringUtils.hasText(opName) ? opName : "", "variables", queryVariables, "extensions", extensions); + + WebGraphQlRequest graphQlRequest = new WebGraphQlRequest( + serverRequest.uri(), serverRequest.headers().asHttpHeaders(), + body, + serverRequest.exchange().getRequest().getId(), + serverRequest.exchange().getLocaleContext().getLocale()); + + if (logger.isDebugEnabled()) { + logger.debug("Executing: " + graphQlRequest); + } + return this.graphQlHandler.handleRequest(graphQlRequest); + }); + }) + .flatMap(response -> { + if (logger.isDebugEnabled()) { + logger.debug("Execution complete"); + } + ServerResponse.BodyBuilder builder = ServerResponse.ok(); + builder.headers(headers -> headers.putAll(response.getResponseHeaders())); + builder.contentType(selectResponseMediaType(serverRequest)); + return builder.bodyValue(response.toMap()); + }); + } + + @SuppressWarnings("unchecked") + private Map getFromMapOrEmpty(Map input, String key) { + if (input.containsKey(key)) { + return (Map)input.get(key); + } else { + return new HashMap<>(); + } + } + private static MediaType selectResponseMediaType(ServerRequest serverRequest) { for (MediaType accepted : serverRequest.headers().accept()) { if (SUPPORTED_MEDIA_TYPES.contains(accepted)) { diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlHttpHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlHttpHandler.java index 4b8eced71..9e74a5b0c 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlHttpHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlHttpHandler.java @@ -17,14 +17,25 @@ package org.springframework.graphql.server.webmvc; import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Type; import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.HashMap; import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.Part; +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.graphql.server.support.MultipartVariableMapper; +import org.springframework.util.StringUtils; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.MultipartHttpServletRequest; import reactor.core.publisher.Mono; import org.springframework.context.i18n.LocaleContextHolder; @@ -46,6 +57,7 @@ * * @author Rossen Stoyanchev * @author Brian Clozel + * @author Nikita Konev * @since 1.0.0 */ public class GraphQlHttpHandler { @@ -55,6 +67,9 @@ public class GraphQlHttpHandler { private static final ParameterizedTypeReference> MAP_PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference>() {}; + private static final ParameterizedTypeReference>> LIST_PARAMETERIZED_TYPE_REF = + new ParameterizedTypeReference>>() {}; + private static final List SUPPORTED_MEDIA_TYPES = Arrays.asList(MediaType.APPLICATION_GRAPHQL, MediaType.APPLICATION_JSON); @@ -62,6 +77,8 @@ public class GraphQlHttpHandler { private final WebGraphQlHandler graphQlHandler; + private final PartReader partReader; + /** * Create a new instance. * @param graphQlHandler common handler for GraphQL over HTTP requests @@ -69,8 +86,16 @@ public class GraphQlHttpHandler { public GraphQlHttpHandler(WebGraphQlHandler graphQlHandler) { Assert.notNull(graphQlHandler, "WebGraphQlHandler is required"); this.graphQlHandler = graphQlHandler; + this.partReader = new JacksonPartReader(new ObjectMapper()); } + public GraphQlHttpHandler(WebGraphQlHandler graphQlHandler, PartReader partReader) { + Assert.notNull(graphQlHandler, "WebGraphQlHandler is required"); + Assert.notNull(partReader, "PartConverter is required"); + this.graphQlHandler = graphQlHandler; + this.partReader = partReader; + } + /** * Handle GraphQL requests over HTTP. * @param serverRequest the incoming HTTP request @@ -102,6 +127,96 @@ public ServerResponse handleRequest(ServerRequest serverRequest) throws ServletE return ServerResponse.async(responseMono); } + public ServerResponse handleMultipartRequest(ServerRequest serverRequest) throws ServletException { + HttpServletRequest httpServletRequest = serverRequest.servletRequest(); + + Map inputQuery = Optional.ofNullable(this.>deserializePart( + httpServletRequest, + "operations", + MAP_PARAMETERIZED_TYPE_REF.getType() + )).orElse(new HashMap<>()); + + final Map queryVariables = getFromMapOrEmpty(inputQuery, "variables"); + final Map extensions = getFromMapOrEmpty(inputQuery, "extensions"); + + Map fileParams = readMultipartFiles(httpServletRequest); + + Map> fileMappings = Optional.ofNullable(this.>>deserializePart( + httpServletRequest, + "map", + LIST_PARAMETERIZED_TYPE_REF.getType() + )).orElse(new HashMap<>()); + + fileMappings.forEach((String fileKey, List objectPaths) -> { + MultipartFile file = fileParams.get(fileKey); + if (file != null) { + objectPaths.forEach((String objectPath) -> { + MultipartVariableMapper.mapVariable( + objectPath, + queryVariables, + file + ); + }); + } + }); + + String query = (String) inputQuery.get("query"); + String opName = (String) inputQuery.get("operationName"); + + Map body = Map.of( + "query", query, "operationName", StringUtils.hasText(opName) ? opName : "", "variables", queryVariables, "extensions", extensions); + + WebGraphQlRequest graphQlRequest = new WebGraphQlRequest( + serverRequest.uri(), serverRequest.headers().asHttpHeaders(), + body, + this.idGenerator.generateId().toString(), LocaleContextHolder.getLocale()); + + if (logger.isDebugEnabled()) { + logger.debug("Executing: " + graphQlRequest); + } + + Mono responseMono = this.graphQlHandler.handleRequest(graphQlRequest) + .map(response -> { + if (logger.isDebugEnabled()) { + logger.debug("Execution complete"); + } + ServerResponse.BodyBuilder builder = ServerResponse.ok(); + builder.headers(headers -> headers.putAll(response.getResponseHeaders())); + builder.contentType(selectResponseMediaType(serverRequest)); + return builder.body(response.toMap()); + }); + + return ServerResponse.async(responseMono); + } + + private T deserializePart(HttpServletRequest httpServletRequest, String name, Type type) { + try { + Part part = httpServletRequest.getPart(name); + if (part == null) { + return null; + } + return partReader.readPart(part, type); + } catch (IOException | ServletException e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + private Map getFromMapOrEmpty(Map input, String key) { + if (input.containsKey(key)) { + return (Map)input.get(key); + } else { + return new HashMap<>(); + } + } + + private static Map readMultipartFiles(HttpServletRequest httpServletRequest) { + Assert.isInstanceOf(MultipartHttpServletRequest.class, httpServletRequest, + "Request should be of type MultipartHttpServletRequest"); + MultipartHttpServletRequest multipartHttpServletRequest = (MultipartHttpServletRequest) httpServletRequest; + return multipartHttpServletRequest.getFileMap(); + } + private static Map readBody(ServerRequest request) throws ServletException { try { return request.body(MAP_PARAMETERIZED_TYPE_REF); diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/JacksonPartReader.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/JacksonPartReader.java new file mode 100644 index 000000000..e9c6d4d2f --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/JacksonPartReader.java @@ -0,0 +1,37 @@ +package org.springframework.graphql.server.webmvc; + +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.core.GenericTypeResolver; + +import javax.servlet.http.Part; +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Type; + +public class JacksonPartReader implements PartReader { + + private final ObjectMapper objectMapper; + + public JacksonPartReader(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + } + + @Override + public T readPart(Part part, Type targetType) { + try(InputStream inputStream = part.getInputStream()) { + try { + JavaType javaType = getJavaType(targetType); + return objectMapper.readValue(inputStream, javaType); + } catch (IOException e) { + throw new RuntimeException(e); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private JavaType getJavaType(Type type) { + return this.objectMapper.constructType(GenericTypeResolver.resolveType(type, (Class)null)); + } +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/PartReader.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/PartReader.java new file mode 100644 index 000000000..5cff8a981 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/PartReader.java @@ -0,0 +1,8 @@ +package org.springframework.graphql.server.webmvc; + +import javax.servlet.http.Part; +import java.lang.reflect.Type; + +public interface PartReader { + T readPart(Part part, Type targetType); +} diff --git a/spring-graphql/src/test/java/org/springframework/graphql/client/GraphQlClientTestSupport.java b/spring-graphql/src/test/java/org/springframework/graphql/client/GraphQlClientTestSupport.java index 9c0173635..e612a8a00 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/client/GraphQlClientTestSupport.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/client/GraphQlClientTestSupport.java @@ -86,6 +86,11 @@ public Flux executeSubscription(GraphQlRequest request) { return Flux.error(new UnsupportedOperationException()); } + @Override + public Mono executeFileUpload(GraphQlRequest request) { + throw new UnsupportedOperationException("File upload is not supported"); + } + } } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/client/HttpGraphQlClientTests.java b/spring-graphql/src/test/java/org/springframework/graphql/client/HttpGraphQlClientTests.java new file mode 100644 index 000000000..e3531c265 --- /dev/null +++ b/spring-graphql/src/test/java/org/springframework/graphql/client/HttpGraphQlClientTests.java @@ -0,0 +1,81 @@ +package org.springframework.graphql.client; + +import org.junit.jupiter.api.Test; +import org.springframework.core.io.ClassPathResource; +import org.springframework.graphql.server.webflux.GraphQlHttpHandler; +import org.springframework.http.codec.multipart.FilePart; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.test.web.reactive.server.HttpHandlerConnector; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.HandlerStrategies; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerResponse; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.web.reactive.function.server.RouterFunctions.route; + +public class HttpGraphQlClientTests { + + private static final String DOCUMENT = "{ Mutation }"; + + private static final Duration TIMEOUT = Duration.ofSeconds(5); + + @Test + void shouldSendOneFile() { + MultipartHttpBuilderSetup clientSetup = new MultipartHttpBuilderSetup(); + + // Original header value + HttpGraphQlClient.Builder builder = clientSetup.initBuilder(); + + HttpGraphQlClient client = builder.build(); + client.document(DOCUMENT) + .variable("existingVar", "itsValue") + .fileVariable("fileInput", new ClassPathResource("/foo.txt")) + .executeFileUpload().block(TIMEOUT); + assertThat(clientSetup.getActualRequest().getVariables().get("existingVar")).isEqualTo("itsValue"); + assertThat(clientSetup.getActualRequest().getVariables().get("fileInput")).isNotNull(); + assertThat(((FilePart)clientSetup.getActualRequest().getVariables().get("fileInput")).filename()).isEqualTo("foo.txt"); + } + + @Test + void shouldSendOneCollectionOfFiles() { + MultipartHttpBuilderSetup clientSetup = new MultipartHttpBuilderSetup(); + + // Original header value + HttpGraphQlClient.Builder builder = clientSetup.initBuilder(); + + HttpGraphQlClient client = builder.build(); + List resources = new ArrayList<>(); + resources.add(new ClassPathResource("/foo.txt")); + resources.add(new ClassPathResource("/bar.txt")); + + client.document(DOCUMENT) + .variable("existingVar", "itsValue") + .fileVariable("filesInput", resources) + .executeFileUpload().block(TIMEOUT); + assertThat(clientSetup.getActualRequest().getVariables().get("existingVar")).isEqualTo("itsValue"); + assertThat(clientSetup.getActualRequest().getVariables().get("filesInput")).isNotNull(); + assertThat(((Collection)clientSetup.getActualRequest().getVariables().get("filesInput")).size()).isEqualTo(2); + assertThat(((Collection)clientSetup.getActualRequest().getVariables().get("filesInput")).stream().map(filePart -> filePart.filename()).collect(Collectors.toSet())).contains("foo.txt", "bar.txt"); + } + + private static class MultipartHttpBuilderSetup extends WebGraphQlClientBuilderTests.AbstractBuilderSetup { + + @Override + public HttpGraphQlClient.Builder initBuilder() { + GraphQlHttpHandler handler = new GraphQlHttpHandler(webGraphQlHandler()); + RouterFunction routerFunction = route().POST("/**", handler::handleMultipartRequest).build(); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction, HandlerStrategies.withDefaults()); + HttpHandlerConnector connector = new HttpHandlerConnector(httpHandler); + return HttpGraphQlClient.builder(WebClient.builder().clientConnector(connector)); + } + + } +} diff --git a/spring-graphql/src/test/java/org/springframework/graphql/client/MultipartBodyCreatorTests.java b/spring-graphql/src/test/java/org/springframework/graphql/client/MultipartBodyCreatorTests.java new file mode 100644 index 000000000..bff9798ee --- /dev/null +++ b/spring-graphql/src/test/java/org/springframework/graphql/client/MultipartBodyCreatorTests.java @@ -0,0 +1,88 @@ +package org.springframework.graphql.client; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.HttpEntity; +import org.springframework.util.MultiValueMap; + +import java.util.*; + +public class MultipartBodyCreatorTests { + + @Test + public void shouldGenerateVariableForOneFile() { + Map variables = new HashMap<>(); + variables.put("existingVar", "itsValue"); + MultipartClientGraphQlRequest multipartClientGraphQlRequest = new MultipartClientGraphQlRequest( + "mockDoc", + "opName", + variables, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.singletonMap("fileInput", new ClassPathResource("/foo.txt")) + ); + MultiValueMap stringMultiValueMap = MultipartBodyCreator.convertRequestToMultipartData(multipartClientGraphQlRequest); + + HttpEntity operations = (HttpEntity) stringMultiValueMap.get("operations").get(0); + Map operationsBody = (Map) operations.getBody(); + Assertions.assertEquals("mockDoc", operationsBody.get("query")); + Assertions.assertEquals("opName", operationsBody.get("operationName")); + Map resultVariables = (Map) operationsBody.get("variables"); + Assertions.assertTrue(resultVariables.containsKey("fileInput")); + Assertions.assertNull(resultVariables.get("fileInput")); + Assertions.assertEquals("itsValue", resultVariables.get("existingVar")); + + HttpEntity mappings = (HttpEntity) stringMultiValueMap.get("map").get(0); + Map mappingsBody = (Map) mappings.getBody(); + Assertions.assertTrue((((List)mappingsBody.get("uploadPart0")).containsAll(Collections.singletonList("variables.fileInput")))); + + HttpEntity filePart = (HttpEntity) stringMultiValueMap.get("uploadPart0").get(0); + Assertions.assertTrue(filePart.getBody() instanceof ClassPathResource); + } + + @Test + public void shouldGenerateVariableForCollectionOfFiles() { + Map variables = new HashMap<>(); + variables.put("existingVar", "itsValue"); + List resources = new ArrayList<>(); + resources.add(new ClassPathResource("/foo.txt")); + resources.add(new ClassPathResource("/bar.txt")); + + MultipartClientGraphQlRequest multipartClientGraphQlRequest = new MultipartClientGraphQlRequest( + "mockDoc", + "opName", + variables, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.singletonMap("fileInput", resources) + ); + MultiValueMap stringMultiValueMap = MultipartBodyCreator.convertRequestToMultipartData(multipartClientGraphQlRequest); + + HttpEntity operations = (HttpEntity) stringMultiValueMap.get("operations").get(0); + Map operationsBody = (Map) operations.getBody(); + Assertions.assertEquals("mockDoc", operationsBody.get("query")); + Assertions.assertEquals("opName", operationsBody.get("operationName")); + Map resultVariables = (Map) operationsBody.get("variables"); + Assertions.assertTrue(resultVariables.containsKey("fileInput")); + List fileInputValues = (List) resultVariables.get("fileInput"); + Assertions.assertNotNull(fileInputValues); + Assertions.assertEquals(2, fileInputValues.size()); + Assertions.assertNull(fileInputValues.get(0)); + Assertions.assertNull(fileInputValues.get(1)); + + Assertions.assertEquals("itsValue", resultVariables.get("existingVar")); + + HttpEntity mappings = (HttpEntity) stringMultiValueMap.get("map").get(0); + Map mappingsBody = (Map) mappings.getBody(); + Assertions.assertTrue((((List)mappingsBody.get("uploadPart0")).containsAll(Collections.singletonList("variables.fileInput.0")))); + Assertions.assertTrue((((List)mappingsBody.get("uploadPart1")).containsAll(Collections.singletonList("variables.fileInput.1")))); + + HttpEntity filePart0 = (HttpEntity) stringMultiValueMap.get("uploadPart0").get(0); + Assertions.assertTrue(filePart0.getBody() instanceof ClassPathResource); + + HttpEntity filePart1 = (HttpEntity) stringMultiValueMap.get("uploadPart1").get(0); + Assertions.assertTrue(filePart1.getBody() instanceof ClassPathResource); + + } +} diff --git a/spring-graphql/src/test/java/org/springframework/graphql/client/WebGraphQlClientBuilderTests.java b/spring-graphql/src/test/java/org/springframework/graphql/client/WebGraphQlClientBuilderTests.java index 447f3532f..67827bcc6 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/client/WebGraphQlClientBuilderTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/client/WebGraphQlClientBuilderTests.java @@ -238,7 +238,7 @@ private interface ClientBuilderSetup { } - private abstract static class AbstractBuilderSetup implements ClientBuilderSetup { + abstract static class AbstractBuilderSetup implements ClientBuilderSetup { @Nullable private WebGraphQlRequest graphQlRequest; diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlHttpHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlHttpHandlerTests.java index 83484b660..a5a20e64f 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlHttpHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlHttpHandlerTests.java @@ -15,14 +15,26 @@ */ package org.springframework.graphql.server.webflux; -import java.util.Collections; -import java.util.List; -import java.util.Locale; -import java.util.Map; +import java.nio.file.Path; +import java.util.*; +import java.util.stream.Collectors; import com.jayway.jsonpath.DocumentContext; import com.jayway.jsonpath.JsonPath; +import graphql.schema.GraphQLScalarType; import org.junit.jupiter.api.Test; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.graphql.coercing.webflux.UploadCoercing; +import org.springframework.http.HttpHeaders; +import org.springframework.http.codec.multipart.FilePart; +import org.springframework.http.codec.multipart.Part; +import org.springframework.util.LinkedMultiValueMap; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.graphql.GraphQlSetup; @@ -39,6 +51,7 @@ import org.springframework.web.server.ServerWebExchange; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.graphql.client.MultipartBodyCreator.createFilePartsAndMapping; /** * Tests for {@link GraphQlHttpHandler}. @@ -49,8 +62,9 @@ public class GraphQlHttpHandlerTests { private final GraphQlHttpHandler greetingHandler = GraphQlSetup.schemaContent("type Query { greeting: String }") .queryFetcher("greeting", (env) -> "Hello").toHttpHandlerWebFlux(); + private final Jackson2JsonEncoder jackson2JsonEncoder = new Jackson2JsonEncoder(); - @Test + @Test void shouldProduceApplicationJsonByDefault() { MockServerHttpRequest httpRequest = MockServerHttpRequest.post("/") .contentType(MediaType.APPLICATION_JSON).accept(MediaType.ALL).build(); @@ -61,6 +75,69 @@ void shouldProduceApplicationJsonByDefault() { assertThat(httpResponse.getHeaders().getContentType()).isEqualTo(MediaType.APPLICATION_JSON); } + @Test + void shouldPassFile() { + GraphQlHttpHandler handler = GraphQlSetup.schemaContent( + "type Query { ping: String } \n" + + "scalar Upload\n" + + "type Mutation {\n" + + " fileUpload(fileInput: Upload!): String!\n" + + "}") + .mutationFetcher("fileUpload", (env) -> ((FilePart) env.getVariables().get("fileInput")).filename()) + .runtimeWiring(builder -> builder.scalar(GraphQLScalarType.newScalar() + .name("Upload") + .coercing(new UploadCoercing()) + .build())) + .toHttpHandlerWebFlux(); + + MockServerHttpRequest httpRequest = MockServerHttpRequest.post("/") + .contentType(MediaType.MULTIPART_FORM_DATA).accept(MediaType.ALL) + .build(); + + MockServerHttpResponse httpResponse = handleMultipartRequest( + httpRequest, handler, "mutation FileUpload($fileInput: Upload!) " + + "{fileUpload(fileInput: $fileInput) }", + Collections.emptyMap(), + Collections.singletonMap("fileInput", new ClassPathResource("/foo.txt")) + ); + + assertThat(httpResponse.getBodyAsString().block()) + .isEqualTo("{\"data\":{\"fileUpload\":\"foo.txt\"}}"); + } + + @Test + void shouldPassListOfFiles() { + GraphQlHttpHandler handler = GraphQlSetup.schemaContent( + "type Query { ping: String } \n" + + "scalar Upload\n" + + "type Mutation {\n" + + " multipleFilesUpload(multipleFileInputs: [Upload!]!): [String!]!\n" + + "}") + .mutationFetcher("multipleFilesUpload", (env) -> ((Collection) env.getVariables().get("multipleFileInputs")).stream().map(FilePart::filename).collect(Collectors.toList())) + .runtimeWiring(builder -> builder.scalar(GraphQLScalarType.newScalar() + .name("Upload") + .coercing(new UploadCoercing()) + .build())) + .toHttpHandlerWebFlux(); + + MockServerHttpRequest httpRequest = MockServerHttpRequest.post("/") + .contentType(MediaType.MULTIPART_FORM_DATA).accept(MediaType.ALL) + .build(); + + Collection resources = new ArrayList<>(); + resources.add(new ClassPathResource("/foo.txt")); + resources.add(new ClassPathResource("/bar.txt")); + MockServerHttpResponse httpResponse = handleMultipartRequest( + httpRequest, handler, "mutation MultipleFilesUpload($multipleFileInputs: [Upload!]!) " + + "{multipleFilesUpload(multipleFileInputs: $multipleFileInputs) }", + Collections.emptyMap(), + Collections.singletonMap("multipleFileInputs", resources) + ); + + assertThat(httpResponse.getBodyAsString().block()) + .isEqualTo("{\"data\":{\"multipleFilesUpload\":[\"foo.txt\",\"bar.txt\"]}}"); + } + @Test void shouldProduceApplicationGraphQl() { MockServerHttpRequest httpRequest = MockServerHttpRequest.post("/") @@ -135,6 +212,56 @@ private MockServerHttpResponse handleRequest( return exchange.getResponse(); } + private MockServerHttpResponse handleMultipartRequest( + MockServerHttpRequest httpRequest, GraphQlHttpHandler handler, String body, + Map requestVariables, Map files) { + + MockServerWebExchange exchange = MockServerWebExchange.from(httpRequest); + + LinkedMultiValueMap parts = new LinkedMultiValueMap<>(); + + Map> partMappings = new HashMap<>(); + Map operations = new HashMap<>(); + operations.put("query", body); + Map variables = new HashMap<>(requestVariables); + createFilePartsAndMapping(files, variables, partMappings, (partName, resource) -> addFilePart(parts, partName, (Resource) resource)); + operations.put("variables", variables); + addJsonEncodedPart(parts, "operations", operations); + + addJsonEncodedPart(parts, "map", partMappings); + + MockServerRequest serverRequest = MockServerRequest.builder() + .exchange(exchange) + .uri(((ServerWebExchange) exchange).getRequest().getURI()) + .method(((ServerWebExchange) exchange).getRequest().getMethod()) + .headers(((ServerWebExchange) exchange).getRequest().getHeaders()) + .body(Mono.just(parts)); + + handler.handleMultipartRequest(serverRequest) + .flatMap(response -> response.writeTo(exchange, new DefaultContext())) + .block(); + + return exchange.getResponse(); + } + + private void addJsonEncodedPart(LinkedMultiValueMap parts, String name, Object toSerialize) { + ResolvableType resolvableType = ResolvableType.forClass(HashMap.class); + Flux bufferFlux = jackson2JsonEncoder.encode( + Mono.just(toSerialize), + DefaultDataBufferFactory.sharedInstance, + resolvableType, + MediaType.APPLICATION_JSON, + null + ); + TestPart part = new TestPart(name, bufferFlux); + parts.add(name, part); + } + + private void addFilePart(LinkedMultiValueMap parts, String name, Resource resource) { + Flux dataBufferFlux = DataBufferUtils.read(resource, DefaultDataBufferFactory.sharedInstance, 1024); + TestFilePart filePart = new TestFilePart(name, resource.getFilename(), dataBufferFlux); + parts.add(name, filePart); + } private static class DefaultContext implements ServerResponse.Context { @@ -150,4 +277,72 @@ public List viewResolvers() { } + private static class TestPart implements Part { + + private final String name; + + + private final Flux content; + + private TestPart(String name, Flux content) { + this.name = name; + this.content = content; + } + + @Override + public String name() { + return name; + } + + @Override + public HttpHeaders headers() { + return new HttpHeaders(); + } + + @Override + public Flux content() { + return content; + } + } + + private static class TestFilePart implements FilePart { + + private final String name; + + private final String filename; + + private final Flux content; + + private TestFilePart(String name, String filename, Flux content) { + this.name = name; + this.filename = filename; + this.content = content; + } + + @Override + public String name() { + return name; + } + + @Override + public HttpHeaders headers() { + return new HttpHeaders(); + } + + @Override + public Flux content() { + return content; + } + + @Override + public String filename() { + return filename; + } + + @Override + public Mono transferTo(Path dest) { + return Mono.error(new RuntimeException("Not implemented")); + } + } + } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlHttpHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlHttpHandlerTests.java index ef9b1a05e..9634af003 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlHttpHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlHttpHandlerTests.java @@ -16,31 +16,37 @@ package org.springframework.graphql.server.webmvc; import java.io.IOException; +import java.io.InputStream; import java.nio.charset.StandardCharsets; -import java.util.Collections; -import java.util.List; -import java.util.Locale; -import java.util.UUID; +import java.util.*; +import java.util.stream.Collectors; import javax.servlet.ServletException; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.jayway.jsonpath.DocumentContext; import com.jayway.jsonpath.JsonPath; +import graphql.schema.GraphQLScalarType; import org.junit.jupiter.api.Test; import org.springframework.context.i18n.LocaleContextHolder; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; import org.springframework.graphql.GraphQlSetup; +import org.springframework.graphql.coercing.webmvc.UploadCoercing; import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.*; +import org.springframework.web.multipart.MultipartFile; import org.springframework.web.servlet.function.AsyncServerResponse; import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.function.ServerResponse; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.springframework.graphql.client.MultipartBodyCreator.createFilePartsAndMapping; /** * Tests for {@link GraphQlHttpHandler}. @@ -55,6 +61,8 @@ public class GraphQlHttpHandlerTests { private final GraphQlHttpHandler greetingHandler = GraphQlSetup.schemaContent("type Query { greeting: String }") .queryFetcher("greeting", (env) -> "Hello").toHttpHandler(); + private final ObjectMapper objectMapper = new ObjectMapper(); + @Test void shouldProduceApplicationJsonByDefault() throws Exception { MockHttpServletRequest servletRequest = createServletRequest("{\"query\":\"{ greeting }\"}", "*/*"); @@ -95,6 +103,65 @@ void locale() throws Exception { } } + @Test + void shouldPassFile() throws Exception { + GraphQlHttpHandler handler = GraphQlSetup.schemaContent( + "type Query { ping: String } \n" + + "scalar Upload\n" + + "type Mutation {\n" + + " fileUpload(fileInput: Upload!): String!\n" + + "}") + .mutationFetcher("fileUpload", (env) -> ((MultipartFile) env.getVariables().get("fileInput")).getOriginalFilename()) + .runtimeWiring(builder -> builder.scalar(GraphQLScalarType.newScalar() + .name("Upload") + .coercing(new UploadCoercing()) + .build())) + .toHttpHandler(); + MockHttpServletRequest servletRequest = createMultipartServletRequest( + "mutation FileUpload($fileInput: Upload!) " + + "{fileUpload(fileInput: $fileInput) }", + MediaType.APPLICATION_GRAPHQL_VALUE, + Collections.singletonMap("fileInput", new ClassPathResource("/foo.txt")) + ); + + MockHttpServletResponse servletResponse = handleMultipartRequest(servletRequest, handler); + + assertThat(servletResponse.getContentAsString()) + .isEqualTo("{\"data\":{\"fileUpload\":\"foo.txt\"}}"); + } + + @Test + void shouldPassListOfFiles() throws Exception { + GraphQlHttpHandler handler = GraphQlSetup.schemaContent( + "type Query { ping: String } \n" + + "scalar Upload\n" + + "type Mutation {\n" + + " multipleFilesUpload(multipleFileInputs: [Upload!]!): [String!]!\n" + + "}") + .mutationFetcher("multipleFilesUpload", (env) -> ((Collection) env.getVariables().get("multipleFileInputs")).stream().map(multipartFile -> multipartFile.getOriginalFilename()).collect(Collectors.toList())) + .runtimeWiring(builder -> builder.scalar(GraphQLScalarType.newScalar() + .name("Upload") + .coercing(new UploadCoercing()) + .build())) + .toHttpHandler(); + + Collection resources = new ArrayList<>(); + resources.add(new ClassPathResource("/foo.txt")); + resources.add(new ClassPathResource("/bar.txt")); + + MockHttpServletRequest servletRequest = createMultipartServletRequest( + "mutation MultipleFilesUpload($multipleFileInputs: [Upload!]!) " + + "{multipleFilesUpload(multipleFileInputs: $multipleFileInputs) }", + MediaType.APPLICATION_GRAPHQL_VALUE, + Collections.singletonMap("multipleFileInputs", resources) + ); + + MockHttpServletResponse servletResponse = handleMultipartRequest(servletRequest, handler); + + assertThat(servletResponse.getContentAsString()) + .isEqualTo("{\"data\":{\"multipleFilesUpload\":[\"foo.txt\",\"bar.txt\"]}}"); + } + @Test void shouldSetExecutionId() throws Exception { GraphQlHttpHandler handler = GraphQlSetup.schemaContent("type Query { showId: ID! }") @@ -118,7 +185,58 @@ private MockHttpServletRequest createServletRequest(String query, String accept) return servletRequest; } - private MockHttpServletResponse handleRequest( + private MockHttpServletRequest createMultipartServletRequest(String query, String accept, Map files) { + MockMultipartHttpServletRequest servletRequest = new MockMultipartHttpServletRequest(); + servletRequest.addHeader("Accept", accept); + servletRequest.setAsyncSupported(true); + + Map> partMappings = new HashMap<>(); + Map operations = new HashMap<>(); + operations.put("query", query); + Map variables = new HashMap<>(); + createFilePartsAndMapping(files, variables, partMappings, + (partName, objectResource) -> servletRequest.addFile(getMultipartFile(partName, objectResource)) + ); + operations.put("variables", variables); + + servletRequest.addPart(new MockPart("operations", getJsonArray(operations))); + servletRequest.addPart(new MockPart("map", getJsonArray(partMappings))); + + return servletRequest; + } + + private MockMultipartFile getMultipartFile(String partName, Object objectResource) { + Resource resource = (Resource) objectResource; + try { + return new MockMultipartFile(partName, resource.getFilename(), null, resource.getInputStream()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private byte[] getFileByteArray(Resource resource) { + try { + byte[] targetArray = new byte[(int)resource.getFile().length()]; + try(InputStream inputStream = resource.getInputStream()) { + inputStream.read(targetArray); + return targetArray; + } catch (IOException e) { + throw new RuntimeException(e); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private byte[] getJsonArray(Object o) { + try { + return objectMapper.writeValueAsBytes(o); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private MockHttpServletResponse handleRequest( MockHttpServletRequest servletRequest, GraphQlHttpHandler handler) throws ServletException, IOException { ServerRequest request = ServerRequest.create(servletRequest, MESSAGE_READERS); @@ -129,6 +247,16 @@ private MockHttpServletResponse handleRequest( return servletResponse; } + private MockHttpServletResponse handleMultipartRequest( + MockHttpServletRequest servletRequest, GraphQlHttpHandler handler) throws ServletException, IOException { + + ServerRequest request = ServerRequest.create(servletRequest, MESSAGE_READERS); + ServerResponse response = ((AsyncServerResponse) handler.handleMultipartRequest(request)).block(); + + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + response.writeTo(servletRequest, servletResponse, new DefaultContext()); + return servletResponse; + } private static class DefaultContext implements ServerResponse.Context { diff --git a/spring-graphql/src/test/resources/bar.txt b/spring-graphql/src/test/resources/bar.txt new file mode 100644 index 000000000..e00821ec8 --- /dev/null +++ b/spring-graphql/src/test/resources/bar.txt @@ -0,0 +1 @@ +hello from bar here! \ No newline at end of file diff --git a/spring-graphql/src/test/resources/foo.txt b/spring-graphql/src/test/resources/foo.txt new file mode 100644 index 000000000..6e0f704fe --- /dev/null +++ b/spring-graphql/src/test/resources/foo.txt @@ -0,0 +1 @@ +hello here! \ No newline at end of file