Skip to content

Commit e8d8b8d

Browse files
authored
Infer basic json schema (#28)
* Infer basic json schema on primitive types
1 parent 085ed63 commit e8d8b8d

File tree

7 files changed

+111
-24
lines changed

7 files changed

+111
-24
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/example.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from greeter import greeter
1616
from virtual_object import counter
1717
from workflow import payment
18+
from pydantic_greeter import pydantic_greeter
1819

1920
import restate
2021

21-
app = restate.app(services=[greeter, counter, payment])
22+
app = restate.app(services=[greeter, counter, payment, pydantic_greeter])

examples/pydantic_greeter.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#
2+
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
3+
#
4+
# This file is part of the Restate SDK for Python,
5+
# which is released under the MIT license.
6+
#
7+
# You can find a copy of the license in file LICENSE in the root
8+
# directory of this repository or package, or at
9+
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10+
#
11+
"""greeter.py"""
12+
# pylint: disable=C0116
13+
# pylint: disable=W0613
14+
# pylint: disable=C0115
15+
# pylint: disable=R0903
16+
17+
from pydantic import BaseModel
18+
from restate import Service, Context
19+
20+
# models
21+
class GreetingRequest(BaseModel):
22+
name: str
23+
24+
class Greeting(BaseModel):
25+
message: str
26+
27+
# service
28+
29+
pydantic_greeter = Service("pydantic_greeter")
30+
31+
@pydantic_greeter.handler()
32+
async def greet(ctx: Context, req: GreetingRequest) -> Greeting:
33+
return Greeting(message=f"Hello {req.name}!")

examples/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
hypercorn
2-
restate_sdk
2+
restate_sdk
3+
pydantic

python/restate/discovery.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525
import json
2626
import typing
2727
from enum import Enum
28-
from typing import Optional, Any, List
28+
from typing import Optional, Any, List, get_args, get_origin
2929

3030

3131
from restate.endpoint import Endpoint as RestateEndpoint
32+
from restate.handler import TypeHint
3233

3334
class ProtocolMode(Enum):
3435
BIDI_STREAM = "BIDI_STREAM"
@@ -99,6 +100,49 @@ def default(self, o):
99100
return o.value
100101
return {key: value for key, value in o.__dict__.items() if value is not None}
101102

103+
104+
# pylint: disable=R0911
105+
def type_hint_to_json_schema(type_hint: Any) -> Any:
106+
"""
107+
Convert a Python type hint to a JSON schema.
108+
109+
"""
110+
origin = get_origin(type_hint) or type_hint
111+
args = get_args(type_hint)
112+
if origin is str:
113+
return {"type": "string"}
114+
if origin is int:
115+
return {"type": "integer"}
116+
if origin is float:
117+
return {"type": "number"}
118+
if origin is bool:
119+
return {"type": "boolean"}
120+
if origin is list:
121+
items = type_hint_to_json_schema(args[0] if args else Any)
122+
return {"type": "array", "items": items}
123+
if origin is dict:
124+
return {
125+
"type": "object"
126+
}
127+
if origin is None:
128+
return {"type": "null"}
129+
# Default to all valid schema
130+
return True
131+
132+
def json_schema_from_type_hint(type_hint: Optional[TypeHint[Any]]) -> Any:
133+
"""
134+
Convert a type hint to a JSON schema.
135+
"""
136+
if not type_hint:
137+
return None
138+
if not type_hint.annotation:
139+
return None
140+
if type_hint.is_pydantic:
141+
return type_hint.annotation.model_json_schema(mode='serialization') # type: ignore
142+
return type_hint_to_json_schema(type_hint.annotation)
143+
144+
145+
102146
def compute_discovery_json(endpoint: RestateEndpoint,
103147
version: int,
104148
discovered_as: typing.Literal["bidi", "request_response"]) -> typing.Tuple[typing.Dict[str, str] ,str]:
@@ -113,13 +157,6 @@ def compute_discovery_json(endpoint: RestateEndpoint,
113157
headers = {"content-type": "application/vnd.restate.endpointmanifest.v1+json"}
114158
return (headers, json_str)
115159

116-
def try_extract_json_schema(model: Any) -> typing.Optional[typing.Any]:
117-
"""
118-
Try to extract the JSON schema from a schema object
119-
"""
120-
if model:
121-
return model.model_json_schema(mode='serialization')
122-
return None
123160

124161
def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal["bidi", "request_response"]) -> Endpoint:
125162
"""
@@ -139,11 +176,11 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal[
139176
# input
140177
inp = InputPayload(required=False,
141178
contentType=handler.handler_io.accept,
142-
jsonSchema=try_extract_json_schema(handler.handler_io.pydantic_input_model))
179+
jsonSchema=json_schema_from_type_hint(handler.handler_io.input_type))
143180
# output
144181
out = OutputPayload(setContentTypeIfEmpty=False,
145182
contentType=handler.handler_io.content_type,
146-
jsonSchema=try_extract_json_schema(handler.handler_io.pydantic_output_model))
183+
jsonSchema=json_schema_from_type_hint(handler.handler_io.output_type))
147184
# add the handler
148185
service_handlers.append(Handler(name=handler.name, ty=ty, input=inp, output=out))
149186

python/restate/handler.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
I = TypeVar('I')
2626
O = TypeVar('O')
27+
T = TypeVar('T')
2728

2829
# we will use this symbol to store the handler in the function
2930
RESTATE_UNIQUE_HANDLER_SYMBOL = str(object())
@@ -42,7 +43,7 @@ class Dummy: # pylint: disable=too-few-public-methods
4243

4344
return Dummy
4445

45-
PYDANTIC_BASE_MODEL = try_import_pydantic_base_model()
46+
PydanticBaseModel = try_import_pydantic_base_model()
4647

4748
@dataclass
4849
class ServiceTag:
@@ -52,6 +53,14 @@ class ServiceTag:
5253
kind: Literal["object", "service", "workflow"]
5354
name: str
5455

56+
@dataclass
57+
class TypeHint(Generic[T]):
58+
"""
59+
Represents a type hint.
60+
"""
61+
annotation: Optional[T] = None
62+
is_pydantic: bool = False
63+
5564
@dataclass
5665
class HandlerIO(Generic[I, O]):
5766
"""
@@ -65,38 +74,43 @@ class HandlerIO(Generic[I, O]):
6574
content_type: str
6675
input_serde: Serde[I]
6776
output_serde: Serde[O]
68-
pydantic_input_model: Optional[I] = None
69-
pydantic_output_model: Optional[O] = None
77+
input_type: Optional[TypeHint[I]] = None
78+
output_type: Optional[TypeHint[O]] = None
7079

7180
def is_pydantic(annotation) -> bool:
7281
"""
7382
Check if an object is a Pydantic model.
7483
"""
7584
try:
76-
return issubclass(annotation, PYDANTIC_BASE_MODEL)
85+
return issubclass(annotation, PydanticBaseModel)
7786
except TypeError:
7887
# annotation is not a class or a type
7988
return False
8089

8190

82-
def infer_pydantic_io(handler_io: HandlerIO[I, O], signature: Signature):
91+
def extract_io_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
8392
"""
84-
Augment handler_io with Pydantic models when these are provided.
93+
Augment handler_io with additional information about the input and output types.
94+
95+
This function has a special check for Pydantic models when these are provided.
8596
This method will inspect the signature of an handler and will look for
8697
the input and the return types of a function, and will:
8798
* capture any Pydantic models (to be used later at discovery)
8899
* replace the default json serializer (is unchanged by a user) with a Pydantic serde
89100
"""
90-
# check if the handlers I/O is a PydanticBaseModel
91101
annotation = list(signature.parameters.values())[-1].annotation
102+
handler_io.input_type = TypeHint(annotation=annotation, is_pydantic=False)
103+
92104
if is_pydantic(annotation):
93-
handler_io.pydantic_input_model = annotation
105+
handler_io.input_type.is_pydantic = True
94106
if isinstance(handler_io.input_serde, JsonSerde): # type: ignore
95107
handler_io.input_serde = PydanticJsonSerde(annotation)
96108

97109
annotation = signature.return_annotation
110+
handler_io.output_type = TypeHint(annotation=annotation, is_pydantic=False)
111+
98112
if is_pydantic(annotation):
99-
handler_io.pydantic_output_model = annotation
113+
handler_io.output_type.is_pydantic=True
100114
if isinstance(handler_io.output_serde, JsonSerde): # type: ignore
101115
handler_io.output_serde = PydanticJsonSerde(annotation)
102116

@@ -136,7 +150,7 @@ def make_handler(service_tag: ServiceTag,
136150
raise ValueError("Handler must have at least one parameter")
137151

138152
arity = len(signature.parameters)
139-
infer_pydantic_io(handler_io, signature)
153+
extract_io_type_hints(handler_io, signature)
140154

141155
handler = Handler[I, O](service_tag,
142156
handler_io,

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ mypy
22
pylint
33
hypercorn
44
maturin
5-
pytest
5+
pytest
6+
pydantic

0 commit comments

Comments
 (0)