Skip to content

Commit 43e274e

Browse files
committed
Use latest feature of GraphQL-core
1 parent ecf62eb commit 43e274e

File tree

8 files changed

+91
-300
lines changed

8 files changed

+91
-300
lines changed

strawberry/http/__init__.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
from __future__ import annotations
22

3-
import json
43
from dataclasses import dataclass
54
from typing import TYPE_CHECKING, Any, Optional
65
from typing_extensions import Literal, TypedDict
76

87
if TYPE_CHECKING:
9-
from collections.abc import Mapping
10-
118
from strawberry.types import ExecutionResult
129

1310

@@ -39,22 +36,6 @@ class GraphQLRequestData:
3936
protocol: Literal["http", "multipart-subscription"] = "http"
4037

4138

42-
def parse_query_params(params: dict[str, str]) -> dict[str, Any]:
43-
if "variables" in params:
44-
params["variables"] = json.loads(params["variables"])
45-
46-
return params
47-
48-
49-
def parse_request_data(data: Mapping[str, Any]) -> GraphQLRequestData:
50-
return GraphQLRequestData(
51-
query=data.get("query"),
52-
variables=data.get("variables"),
53-
operation_name=data.get("operationName"),
54-
extensions=data.get("extensions"),
55-
)
56-
57-
5839
__all__ = [
5940
"GraphQLHTTPResponse",
6041
"GraphQLRequestData",

strawberry/http/async_base_view.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
)
3434
from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
3535
from strawberry.types import ExecutionResult, SubscriptionExecutionResult
36-
from strawberry.types.context_wrapper import ContextWrapper
3736
from strawberry.types.graphql import OperationType
3837
from strawberry.types.unset import UNSET, UnsetType
3938

@@ -198,26 +197,24 @@ async def execute_operation(
198197

199198
assert self.schema
200199

201-
context_wrapper = ContextWrapper(
202-
context=context, extensions=request_data.extensions
203-
)
204-
205200
if request_data.protocol == "multipart-subscription":
206201
return await self.schema.subscribe(
207202
request_data.query, # type: ignore
208203
variable_values=request_data.variables,
209204
context_value=context,
210205
root_value=root_value,
211206
operation_name=request_data.operation_name,
207+
operation_extensions=request_data.extensions,
212208
)
213209

214210
return await self.schema.execute(
215211
request_data.query,
216212
root_value=root_value,
217213
variable_values=request_data.variables,
218-
context_value=context_wrapper,
214+
context_value=context,
219215
operation_name=request_data.operation_name,
220216
allowed_operation_types=allowed_operation_types,
217+
operation_extensions=request_data.extensions,
221218
)
222219

223220
async def parse_multipart(self, request: AsyncHTTPRequestAdapter) -> dict[str, str]:

strawberry/http/sync_base_view.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from strawberry.schema import BaseSchema
2424
from strawberry.schema.exceptions import InvalidOperationTypeError
2525
from strawberry.types import ExecutionResult
26-
from strawberry.types.context_wrapper import ContextWrapper
2726
from strawberry.types.graphql import OperationType
2827

2928
from .base import BaseView
@@ -116,17 +115,14 @@ def execute_operation(
116115

117116
assert self.schema
118117

119-
context_wrapper = ContextWrapper(
120-
context=context, extensions=request_data.extensions
121-
)
122-
123118
return self.schema.execute_sync(
124119
request_data.query,
125120
root_value=root_value,
126121
variable_values=request_data.variables,
127-
context_value=context_wrapper,
122+
context_value=context,
128123
operation_name=request_data.operation_name,
129124
allowed_operation_types=allowed_operation_types,
125+
operation_extensions=request_data.extensions,
130126
)
131127

132128
def parse_multipart(self, request: SyncHTTPRequestAdapter) -> dict[str, str]:

strawberry/schema/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ async def execute(
4747
root_value: Optional[Any] = None,
4848
operation_name: Optional[str] = None,
4949
allowed_operation_types: Optional[Iterable[OperationType]] = None,
50+
operation_extensions: Optional[dict[str, Any]] = None,
5051
) -> ExecutionResult:
5152
raise NotImplementedError
5253

@@ -59,6 +60,7 @@ def execute_sync(
5960
root_value: Optional[Any] = None,
6061
operation_name: Optional[str] = None,
6162
allowed_operation_types: Optional[Iterable[OperationType]] = None,
63+
operation_extensions: Optional[dict[str, Any]] = None,
6264
) -> ExecutionResult:
6365
raise NotImplementedError
6466

@@ -70,6 +72,7 @@ async def subscribe(
7072
context_value: Optional[Any] = None,
7173
root_value: Optional[Any] = None,
7274
operation_name: Optional[str] = None,
75+
operation_extensions: Optional[dict[str, Any]] = None,
7376
) -> SubscriptionResult:
7477
raise NotImplementedError
7578

strawberry/schema/schema.py

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,33 @@
99
TYPE_CHECKING,
1010
Any,
1111
Callable,
12+
NamedTuple,
1213
Optional,
1314
Union,
1415
cast,
1516
)
1617

18+
from graphql import ExecutionContext as GraphQLExecutionContext
1719
from graphql import ExecutionResult as GraphQLExecutionResult
1820
from graphql import (
1921
ExecutionResult as OriginalExecutionResult,
2022
)
2123
from graphql import (
24+
FieldNode,
25+
FragmentDefinitionNode,
2226
GraphQLBoolean,
2327
GraphQLError,
2428
GraphQLField,
2529
GraphQLNamedType,
2630
GraphQLNonNull,
31+
GraphQLObjectType,
32+
GraphQLOutputType,
2733
GraphQLSchema,
34+
OperationDefinitionNode,
2835
get_introspection_query,
2936
parse,
3037
validate_schema,
3138
)
32-
from graphql.execution import ExecutionContext as GraphQLExecutionContext
3339
from graphql.execution import execute, subscribe
3440
from graphql.execution.middleware import MiddlewareManager
3541
from graphql.type.directives import specified_directives
@@ -58,7 +64,7 @@
5864
PreExecutionError,
5965
)
6066
from strawberry.types.graphql import OperationType
61-
from strawberry.utils import IS_GQL_32
67+
from strawberry.utils import IS_GQL_32, IS_GQL_33
6268
from strawberry.utils.aio import aclosing
6369
from strawberry.utils.await_maybe import await_maybe
6470

@@ -71,8 +77,10 @@
7177
from collections.abc import Iterable, Mapping
7278
from typing_extensions import TypeAlias
7379

74-
from graphql import ExecutionContext as GraphQLExecutionContext
80+
from graphql.execution.collect_fields import FieldGroup
7581
from graphql.language import DocumentNode
82+
from graphql.pyutils import Path
83+
from graphql.type import GraphQLResolveInfo
7684
from graphql.validation import ASTValidationRule
7785

7886
from strawberry.directive import StrawberryDirective
@@ -136,6 +144,54 @@ def _coerce_error(error: Union[GraphQLError, Exception]) -> GraphQLError:
136144
return GraphQLError(str(error), original_error=error)
137145

138146

147+
class OperationContextAwareGraphQLResolveInfo(NamedTuple): # pyright: ignore
148+
field_name: str
149+
field_nodes: list[FieldNode]
150+
return_type: GraphQLOutputType
151+
parent_type: GraphQLObjectType
152+
path: Path
153+
schema: GraphQLSchema
154+
fragments: dict[str, FragmentDefinitionNode]
155+
root_value: Any
156+
operation: OperationDefinitionNode
157+
variable_values: dict[str, Any]
158+
context: Any
159+
is_awaitable: Callable[[Any], bool]
160+
operation_extensions: dict[str, Any]
161+
162+
163+
class OperationContextAwareGraphQLExecutionContext(GraphQLExecutionContext):
164+
def __init__(self, *args: Any, **kwargs: Any) -> None:
165+
operation_extensions = kwargs.pop("operation_extensions", None)
166+
167+
super().__init__(*args, **kwargs)
168+
169+
self.operation_extensions = operation_extensions
170+
171+
def build_resolve_info(
172+
self,
173+
field_def: GraphQLField,
174+
field_group: FieldGroup,
175+
parent_type: GraphQLObjectType,
176+
path: Path,
177+
) -> GraphQLResolveInfo:
178+
return OperationContextAwareGraphQLResolveInfo(
179+
field_group.fields[0].node.name.value,
180+
field_group.to_nodes(),
181+
field_def.type,
182+
parent_type,
183+
path,
184+
self.schema,
185+
self.fragments,
186+
self.root_value,
187+
self.operation,
188+
self.variable_values,
189+
self.context_value,
190+
self.is_awaitable,
191+
self.operation_extensions,
192+
)
193+
194+
139195
class Schema(BaseSchema):
140196
def __init__(
141197
self,
@@ -195,7 +251,9 @@ class Query:
195251

196252
self.extensions = extensions
197253
self._cached_middleware_manager: MiddlewareManager | None = None
198-
self.execution_context_class = execution_context_class
254+
self.execution_context_class = (
255+
execution_context_class or OperationContextAwareGraphQLExecutionContext
256+
)
199257
self.config = config or StrawberryConfig()
200258

201259
self.schema_converter = GraphQLCoreConverter(
@@ -320,6 +378,14 @@ def create_extensions_runner(
320378
extensions=extensions,
321379
)
322380

381+
def _get_custom_context_kwargs(
382+
self, operation_extensions: Optional[dict[str, Any]] = None
383+
) -> dict[str, Any]:
384+
if not IS_GQL_33:
385+
return {}
386+
387+
return {"operation_extensions": operation_extensions}
388+
323389
def _get_middleware_manager(
324390
self, extensions: list[SchemaExtension]
325391
) -> MiddlewareManager:
@@ -463,6 +529,7 @@ async def execute(
463529
root_value: Optional[Any] = None,
464530
operation_name: Optional[str] = None,
465531
allowed_operation_types: Optional[Iterable[OperationType]] = None,
532+
operation_extensions: Optional[dict[str, Any]] = None,
466533
) -> ExecutionResult:
467534
if allowed_operation_types is None:
468535
allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES
@@ -483,6 +550,8 @@ async def execute(
483550
extensions_runner = self.create_extensions_runner(execution_context, extensions)
484551
middleware_manager = self._get_middleware_manager(extensions)
485552

553+
custom_context_kwargs = self._get_custom_context_kwargs(operation_extensions)
554+
486555
try:
487556
async with extensions_runner.operation():
488557
# Note: In graphql-core the schema would be validated here but in
@@ -510,6 +579,7 @@ async def execute(
510579
operation_name=execution_context.operation_name,
511580
context_value=execution_context.context,
512581
execution_context_class=self.execution_context_class,
582+
**custom_context_kwargs,
513583
)
514584
)
515585
execution_context.result = result
@@ -547,6 +617,7 @@ def execute_sync(
547617
root_value: Optional[Any] = None,
548618
operation_name: Optional[str] = None,
549619
allowed_operation_types: Optional[Iterable[OperationType]] = None,
620+
operation_extensions: Optional[dict[str, Any]] = None,
550621
) -> ExecutionResult:
551622
if allowed_operation_types is None:
552623
allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES
@@ -567,6 +638,8 @@ def execute_sync(
567638
extensions_runner = self.create_extensions_runner(execution_context, extensions)
568639
middleware_manager = self._get_middleware_manager(extensions)
569640

641+
custom_context_kwargs = self._get_custom_context_kwargs(operation_extensions)
642+
570643
try:
571644
with extensions_runner.operation():
572645
# Note: In graphql-core the schema would be validated here but in
@@ -617,6 +690,7 @@ def execute_sync(
617690
operation_name=execution_context.operation_name,
618691
context_value=execution_context.context,
619692
execution_context_class=self.execution_context_class,
693+
**custom_context_kwargs,
620694
)
621695

622696
if isawaitable(result):
@@ -661,6 +735,7 @@ async def _subscribe(
661735
extensions_runner: SchemaExtensionsRunner,
662736
middleware_manager: MiddlewareManager,
663737
execution_context_class: type[GraphQLExecutionContext] | None = None,
738+
operation_extensions: Optional[dict[str, Any]] = None,
664739
) -> AsyncGenerator[ExecutionResult, None]:
665740
async with extensions_runner.operation():
666741
if initial_error := await self._parse_and_validate_async(
@@ -679,6 +754,7 @@ async def _subscribe(
679754
gql_33_kwargs = {
680755
"middleware": middleware_manager,
681756
"execution_context_class": execution_context_class,
757+
"operation_extensions": operation_extensions,
682758
}
683759
try:
684760
# Might not be awaitable for pre-execution errors.
@@ -743,6 +819,7 @@ async def subscribe(
743819
context_value: Optional[Any] = None,
744820
root_value: Optional[Any] = None,
745821
operation_name: Optional[str] = None,
822+
operation_extensions: Optional[dict[str, Any]] = None,
746823
) -> SubscriptionResult:
747824
execution_context = self._create_execution_context(
748825
query=query,
@@ -764,6 +841,7 @@ async def subscribe(
764841
),
765842
middleware_manager=self._get_middleware_manager(extensions),
766843
execution_context_class=self.execution_context_class,
844+
operation_extensions=operation_extensions,
767845
)
768846

769847
def _resolve_node_ids(self) -> None:

strawberry/types/context_wrapper.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

strawberry/types/info.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
)
1313
from typing_extensions import TypeVar
1414

15-
from .context_wrapper import ContextWrapper
1615
from .nodes import convert_selections
1716

1817
if TYPE_CHECKING:
@@ -112,18 +111,12 @@ def selected_fields(self) -> list[Selection]:
112111
@property
113112
def context(self) -> ContextType:
114113
"""The context passed to the query execution."""
115-
if isinstance(self._raw_info.context, ContextWrapper):
116-
return self._raw_info.context.context
117-
118114
return self._raw_info.context
119115

120116
@property
121117
def input_extensions(self) -> dict[str, Any]:
122118
"""The input extensions passed to the query execution."""
123-
if isinstance(self._raw_info.context, ContextWrapper):
124-
return self._raw_info.context.extensions
125-
126-
return {}
119+
return self._raw_info.operation_extensions
127120

128121
@property
129122
def root_value(self) -> RootValueType:

0 commit comments

Comments
 (0)