9
9
TYPE_CHECKING ,
10
10
Any ,
11
11
Callable ,
12
+ NamedTuple ,
12
13
Optional ,
13
14
Union ,
14
15
cast ,
15
16
)
16
17
18
+ from graphql import ExecutionContext as GraphQLExecutionContext
17
19
from graphql import ExecutionResult as GraphQLExecutionResult
18
20
from graphql import (
19
21
ExecutionResult as OriginalExecutionResult ,
20
22
)
21
23
from graphql import (
24
+ FieldNode ,
25
+ FragmentDefinitionNode ,
22
26
GraphQLBoolean ,
23
27
GraphQLError ,
24
28
GraphQLField ,
25
29
GraphQLNamedType ,
26
30
GraphQLNonNull ,
31
+ GraphQLObjectType ,
32
+ GraphQLOutputType ,
27
33
GraphQLSchema ,
34
+ OperationDefinitionNode ,
28
35
get_introspection_query ,
29
36
parse ,
30
37
validate_schema ,
31
38
)
32
- from graphql .execution import ExecutionContext as GraphQLExecutionContext
33
39
from graphql .execution import execute , subscribe
34
40
from graphql .execution .middleware import MiddlewareManager
35
41
from graphql .type .directives import specified_directives
58
64
PreExecutionError ,
59
65
)
60
66
from strawberry .types .graphql import OperationType
61
- from strawberry .utils import IS_GQL_32
67
+ from strawberry .utils import IS_GQL_32 , IS_GQL_33
62
68
from strawberry .utils .aio import aclosing
63
69
from strawberry .utils .await_maybe import await_maybe
64
70
71
77
from collections .abc import Iterable , Mapping
72
78
from typing_extensions import TypeAlias
73
79
74
- from graphql import ExecutionContext as GraphQLExecutionContext
80
+ from graphql . execution . collect_fields import FieldGroup
75
81
from graphql .language import DocumentNode
82
+ from graphql .pyutils import Path
83
+ from graphql .type import GraphQLResolveInfo
76
84
from graphql .validation import ASTValidationRule
77
85
78
86
from strawberry .directive import StrawberryDirective
@@ -136,6 +144,54 @@ def _coerce_error(error: Union[GraphQLError, Exception]) -> GraphQLError:
136
144
return GraphQLError (str (error ), original_error = error )
137
145
138
146
147
+ class OperationContextAwareGraphQLResolveInfo (NamedTuple ): # pyright: ignore
148
+ field_name : str
149
+ field_nodes : list [FieldNode ]
150
+ return_type : GraphQLOutputType
151
+ parent_type : GraphQLObjectType
152
+ path : Path
153
+ schema : GraphQLSchema
154
+ fragments : dict [str , FragmentDefinitionNode ]
155
+ root_value : Any
156
+ operation : OperationDefinitionNode
157
+ variable_values : dict [str , Any ]
158
+ context : Any
159
+ is_awaitable : Callable [[Any ], bool ]
160
+ operation_extensions : dict [str , Any ]
161
+
162
+
163
+ class OperationContextAwareGraphQLExecutionContext (GraphQLExecutionContext ):
164
+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
165
+ operation_extensions = kwargs .pop ("operation_extensions" , None )
166
+
167
+ super ().__init__ (* args , ** kwargs )
168
+
169
+ self .operation_extensions = operation_extensions
170
+
171
+ def build_resolve_info (
172
+ self ,
173
+ field_def : GraphQLField ,
174
+ field_group : FieldGroup ,
175
+ parent_type : GraphQLObjectType ,
176
+ path : Path ,
177
+ ) -> GraphQLResolveInfo :
178
+ return OperationContextAwareGraphQLResolveInfo (
179
+ field_group .fields [0 ].node .name .value ,
180
+ field_group .to_nodes (),
181
+ field_def .type ,
182
+ parent_type ,
183
+ path ,
184
+ self .schema ,
185
+ self .fragments ,
186
+ self .root_value ,
187
+ self .operation ,
188
+ self .variable_values ,
189
+ self .context_value ,
190
+ self .is_awaitable ,
191
+ self .operation_extensions ,
192
+ )
193
+
194
+
139
195
class Schema (BaseSchema ):
140
196
def __init__ (
141
197
self ,
@@ -195,7 +251,9 @@ class Query:
195
251
196
252
self .extensions = extensions
197
253
self ._cached_middleware_manager : MiddlewareManager | None = None
198
- self .execution_context_class = execution_context_class
254
+ self .execution_context_class = (
255
+ execution_context_class or OperationContextAwareGraphQLExecutionContext
256
+ )
199
257
self .config = config or StrawberryConfig ()
200
258
201
259
self .schema_converter = GraphQLCoreConverter (
@@ -320,6 +378,14 @@ def create_extensions_runner(
320
378
extensions = extensions ,
321
379
)
322
380
381
+ def _get_custom_context_kwargs (
382
+ self , operation_extensions : Optional [dict [str , Any ]] = None
383
+ ) -> dict [str , Any ]:
384
+ if not IS_GQL_33 :
385
+ return {}
386
+
387
+ return {"operation_extensions" : operation_extensions }
388
+
323
389
def _get_middleware_manager (
324
390
self , extensions : list [SchemaExtension ]
325
391
) -> MiddlewareManager :
@@ -463,6 +529,7 @@ async def execute(
463
529
root_value : Optional [Any ] = None ,
464
530
operation_name : Optional [str ] = None ,
465
531
allowed_operation_types : Optional [Iterable [OperationType ]] = None ,
532
+ operation_extensions : Optional [dict [str , Any ]] = None ,
466
533
) -> ExecutionResult :
467
534
if allowed_operation_types is None :
468
535
allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES
@@ -483,6 +550,8 @@ async def execute(
483
550
extensions_runner = self .create_extensions_runner (execution_context , extensions )
484
551
middleware_manager = self ._get_middleware_manager (extensions )
485
552
553
+ custom_context_kwargs = self ._get_custom_context_kwargs (operation_extensions )
554
+
486
555
try :
487
556
async with extensions_runner .operation ():
488
557
# Note: In graphql-core the schema would be validated here but in
@@ -510,6 +579,7 @@ async def execute(
510
579
operation_name = execution_context .operation_name ,
511
580
context_value = execution_context .context ,
512
581
execution_context_class = self .execution_context_class ,
582
+ ** custom_context_kwargs ,
513
583
)
514
584
)
515
585
execution_context .result = result
@@ -547,6 +617,7 @@ def execute_sync(
547
617
root_value : Optional [Any ] = None ,
548
618
operation_name : Optional [str ] = None ,
549
619
allowed_operation_types : Optional [Iterable [OperationType ]] = None ,
620
+ operation_extensions : Optional [dict [str , Any ]] = None ,
550
621
) -> ExecutionResult :
551
622
if allowed_operation_types is None :
552
623
allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES
@@ -567,6 +638,8 @@ def execute_sync(
567
638
extensions_runner = self .create_extensions_runner (execution_context , extensions )
568
639
middleware_manager = self ._get_middleware_manager (extensions )
569
640
641
+ custom_context_kwargs = self ._get_custom_context_kwargs (operation_extensions )
642
+
570
643
try :
571
644
with extensions_runner .operation ():
572
645
# Note: In graphql-core the schema would be validated here but in
@@ -617,6 +690,7 @@ def execute_sync(
617
690
operation_name = execution_context .operation_name ,
618
691
context_value = execution_context .context ,
619
692
execution_context_class = self .execution_context_class ,
693
+ ** custom_context_kwargs ,
620
694
)
621
695
622
696
if isawaitable (result ):
@@ -661,6 +735,7 @@ async def _subscribe(
661
735
extensions_runner : SchemaExtensionsRunner ,
662
736
middleware_manager : MiddlewareManager ,
663
737
execution_context_class : type [GraphQLExecutionContext ] | None = None ,
738
+ operation_extensions : Optional [dict [str , Any ]] = None ,
664
739
) -> AsyncGenerator [ExecutionResult , None ]:
665
740
async with extensions_runner .operation ():
666
741
if initial_error := await self ._parse_and_validate_async (
@@ -679,6 +754,7 @@ async def _subscribe(
679
754
gql_33_kwargs = {
680
755
"middleware" : middleware_manager ,
681
756
"execution_context_class" : execution_context_class ,
757
+ "operation_extensions" : operation_extensions ,
682
758
}
683
759
try :
684
760
# Might not be awaitable for pre-execution errors.
@@ -743,6 +819,7 @@ async def subscribe(
743
819
context_value : Optional [Any ] = None ,
744
820
root_value : Optional [Any ] = None ,
745
821
operation_name : Optional [str ] = None ,
822
+ operation_extensions : Optional [dict [str , Any ]] = None ,
746
823
) -> SubscriptionResult :
747
824
execution_context = self ._create_execution_context (
748
825
query = query ,
@@ -764,6 +841,7 @@ async def subscribe(
764
841
),
765
842
middleware_manager = self ._get_middleware_manager (extensions ),
766
843
execution_context_class = self .execution_context_class ,
844
+ operation_extensions = operation_extensions ,
767
845
)
768
846
769
847
def _resolve_node_ids (self ) -> None :
0 commit comments