diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java index 6153a45f6..6e6ca1aaa 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java @@ -92,11 +92,7 @@ public Object get(DataFetchingEnvironment env) throws Exception { Object value = snapshot.wrap(() -> this.delegate.get(env)).call(); if (value instanceof DataFetcherResult dataFetcherResult) { - Object adapted = updateValue(dataFetcherResult.getData(), snapshot, graphQlContext); - value = DataFetcherResult.newResult() - .data(adapted) - .errors(dataFetcherResult.getErrors()) - .localContext(dataFetcherResult.getLocalContext()).build(); + value = dataFetcherResult.map(data -> updateValue(data, snapshot, graphQlContext)); } else { value = updateValue(value, snapshot, graphQlContext); diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java index 9424d6747..abd3671d6 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java @@ -19,6 +19,7 @@ import java.time.Duration; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; @@ -366,4 +367,20 @@ void cancelFluxDataFetcherSubscriptionWhenRequestCancelled() throws Exception { assertThat(dataFetcherCancelled).isTrue(); } + @Test + public void testExtensionsAreRetained() throws Exception { + GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT) + .queryFetcher("greeting", (env) -> + DataFetcherResult.newResult().data("Hello") + .extensions(Map.of("foo", "bar")).build()) + .toGraphQl(); + + ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build(); + ExecutionResult executionResult = graphQl.executeAsync(input).get(); + + String greeting = ResponseHelper.forResult(executionResult).toEntity("greeting", String.class); + assertThat(greeting).isEqualTo("Hello"); + + assertThat(executionResult.getExtensions()).containsEntry("foo", "bar"); + } }