Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/sphinx/serializers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ Serializers
.. autoclass:: JsonSerializer
:members:

.. autoclass:: OrjsonSerializer
:members:

.. autoclass:: TextSerializer
:members:

Expand Down
7 changes: 7 additions & 0 deletions elastic_transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@
"Urllib3HttpNode",
]

try:
from elastic_transport._serializer import OrjsonSerializer # noqa: F401

__all__.append("OrjsonSerializer")
except ModuleNotFoundError:
pass

_logger = logging.getLogger("elastic_transport")
_logger.addHandler(logging.NullHandler())
del _logger
Expand Down
35 changes: 33 additions & 2 deletions elastic_transport/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,15 @@

from ._exceptions import SerializationError

try:
import orjson
except ModuleNotFoundError:
orjson = None # type: ignore[assignment]


class Serializer:
"""Serializer interface."""

mimetype: ClassVar[str]

def loads(self, data: bytes) -> Any: # pragma: nocover
Expand All @@ -36,6 +43,8 @@ def dumps(self, data: Any) -> bytes: # pragma: nocover


class TextSerializer(Serializer):
"""Text serializer to and from UTF-8."""

mimetype: ClassVar[str] = "text/*"

def loads(self, data: bytes) -> str:
Expand All @@ -62,6 +71,8 @@ def dumps(self, data: str) -> bytes:


class JsonSerializer(Serializer):
"""JSON serializer relying on the standard library json module."""

mimetype: ClassVar[str] = "application/json"

def default(self, data: Any) -> Any:
Expand All @@ -81,14 +92,15 @@ def json_dumps(self, data: Any) -> bytes:
).encode("utf-8", "surrogatepass")

def json_loads(self, data: bytes) -> Any:
return json.loads(data)

def loads(self, data: bytes) -> Any:
# Sometimes responses use Content-Type: json but actually
# don't contain any data. We should return something instead
# of erroring in these cases.
if data == b"":
return None
return json.loads(data)

def loads(self, data: bytes) -> Any:
try:
return self.json_loads(data)
except (ValueError, TypeError) as e:
Expand All @@ -115,7 +127,26 @@ def dumps(self, data: Any) -> bytes:
)


if orjson is not None:

class OrjsonSerializer(JsonSerializer):
"""JSON serializer relying on the orjson package.

Only available if orjson if installed. It is faster, especially for vectors, but is also stricter.
"""

def json_dumps(self, data: Any) -> bytes:
return orjson.dumps(
data, default=self.default, option=orjson.OPT_SERIALIZE_NUMPY
)

def json_loads(self, data: bytes) -> Any:
return orjson.loads(data)


class NdjsonSerializer(JsonSerializer):
"""Newline delimited JSON (NDJSON) serializer relying on the standard library json module."""

mimetype: ClassVar[str] = "application/x-ndjson"

def loads(self, data: bytes) -> Any:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"respx",
"opentelemetry-api",
"opentelemetry-sdk",
"orjson",
# Override Read the Docs default (sphinx<2)
"sphinx>2",
"furo",
Expand Down
7 changes: 5 additions & 2 deletions tests/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@

@modules
def test__all__sorted(module):
print(sorted(module.__all__))
assert module.__all__ == sorted(module.__all__)
module_all = module.__all__.copy()
# Optional dependencies are added at the end
if "OrjsonSerializer" in module_all:
module_all.remove("OrjsonSerializer")
assert module_all == sorted(module_all)


@modules
Expand Down
79 changes: 50 additions & 29 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from elastic_transport import (
JsonSerializer,
NdjsonSerializer,
OrjsonSerializer,
SerializationError,
SerializerCollection,
TextSerializer,
Expand All @@ -33,66 +34,86 @@
serializers = SerializerCollection(DEFAULT_SERIALIZERS)


def test_date_serialization():
assert b'{"d":"2010-10-01"}' == JsonSerializer().dumps({"d": date(2010, 10, 1)})
@pytest.fixture(params=[JsonSerializer, OrjsonSerializer])
def json_serializer(request: pytest.FixtureRequest):
yield request.param()


def test_decimal_serialization():
assert b'{"d":3.8}' == JsonSerializer().dumps({"d": Decimal("3.8")})
def test_date_serialization(json_serializer):
assert b'{"d":"2010-10-01"}' == json_serializer.dumps({"d": date(2010, 10, 1)})


def test_uuid_serialization():
assert b'{"d":"00000000-0000-0000-0000-000000000003"}' == JsonSerializer().dumps(
def test_decimal_serialization(json_serializer):
assert b'{"d":3.8}' == json_serializer.dumps({"d": Decimal("3.8")})


def test_uuid_serialization(json_serializer):
assert b'{"d":"00000000-0000-0000-0000-000000000003"}' == json_serializer.dumps(
{"d": uuid.UUID("00000000-0000-0000-0000-000000000003")}
)


def test_serializes_nan():
assert b'{"d":NaN}' == JsonSerializer().dumps({"d": float("NaN")})
# NaN is invalid JSON, and orjson silently converts it to null
assert b'{"d":null}' == OrjsonSerializer().dumps({"d": float("NaN")})


def test_raises_serialization_error_on_dump_error():
def test_raises_serialization_error_on_dump_error(json_serializer):
with pytest.raises(SerializationError):
JsonSerializer().dumps(object())
json_serializer.dumps(object())
with pytest.raises(SerializationError):
TextSerializer().dumps({})


def test_raises_serialization_error_on_load_error():
def test_raises_serialization_error_on_load_error(json_serializer):
with pytest.raises(SerializationError):
JsonSerializer().loads(object())
json_serializer.loads(object())
with pytest.raises(SerializationError):
JsonSerializer().loads(b"{{")
json_serializer.loads(b"{{")


def test_unicode_is_handled():
j = JsonSerializer()
def test_json_unicode_is_handled(json_serializer):
assert (
j.dumps({"你好": "你好"})
json_serializer.dumps({"你好": "你好"})
== b'{"\xe4\xbd\xa0\xe5\xa5\xbd":"\xe4\xbd\xa0\xe5\xa5\xbd"}'
)
assert j.loads(b'{"\xe4\xbd\xa0\xe5\xa5\xbd":"\xe4\xbd\xa0\xe5\xa5\xbd"}') == {
"你好": "你好"
}
assert json_serializer.loads(
b'{"\xe4\xbd\xa0\xe5\xa5\xbd":"\xe4\xbd\xa0\xe5\xa5\xbd"}'
) == {"你好": "你好"}


t = TextSerializer()
assert t.dumps("你好") == b"\xe4\xbd\xa0\xe5\xa5\xbd"
assert t.loads(b"\xe4\xbd\xa0\xe5\xa5\xbd") == "你好"
def test_text_unicode_is_handled():
text_serializer = TextSerializer()
assert text_serializer.dumps("你好") == b"\xe4\xbd\xa0\xe5\xa5\xbd"
assert text_serializer.loads(b"\xe4\xbd\xa0\xe5\xa5\xbd") == "你好"


def test_unicode_surrogates_handled():
j = JsonSerializer()
def test_json_unicode_surrogates_handled():
assert (
j.dumps({"key": "你好\uda6a"})
JsonSerializer().dumps({"key": "你好\uda6a"})
== b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"}'
)
assert j.loads(b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"}') == {
"key": "你好\uda6a"
}
assert JsonSerializer().loads(
b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"}'
) == {"key": "你好\uda6a"}

# orjson is strict about UTF-8
with pytest.raises(SerializationError):
OrjsonSerializer().dumps({"key": "你好\uda6a"})

with pytest.raises(SerializationError):
OrjsonSerializer().loads(b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"}')


t = TextSerializer()
assert t.dumps("你好\uda6a") == b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
assert t.loads(b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa") == "你好\uda6a"
def test_text_unicode_surrogates_handled(json_serializer):
text_serializer = TextSerializer()
assert (
text_serializer.dumps("你好\uda6a") == b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
)
assert (
text_serializer.loads(b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa") == "你好\uda6a"
)


def test_deserializes_json_by_default():
Expand Down