Skip to content

Commit c2bc41b

Browse files
committed
Implement strict typing for bytes.
Fix #1641.
1 parent 2fdf352 commit c2bc41b

File tree

15 files changed

+84
-64
lines changed

15 files changed

+84
-64
lines changed

docs/reference/types.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ Types
55

66
.. autodata:: Data
77

8+
.. autodata:: BytesLike
9+
10+
.. autodata:: DataLike
11+
812
.. autodata:: LoggerLike
913

1014
.. autodata:: StatusLike

src/websockets/asyncio/connection.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
ConnectionClosedOK,
2020
ProtocolError,
2121
)
22-
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode
22+
from ..frames import DATA_OPCODES, CloseCode, Frame, Opcode
2323
from ..http11 import Request, Response
2424
from ..protocol import CLOSED, OPEN, Event, Protocol, State
25-
from ..typing import Data, LoggerLike, Subprotocol
25+
from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol
2626
from .compatibility import (
2727
TimeoutError,
2828
aiter,
@@ -402,7 +402,7 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data
402402

403403
async def send(
404404
self,
405-
message: Data | Iterable[Data] | AsyncIterable[Data],
405+
message: DataLike | Iterable[DataLike] | AsyncIterable[DataLike],
406406
text: bool | None = None,
407407
) -> None:
408408
"""
@@ -657,7 +657,7 @@ async def wait_closed(self) -> None:
657657
"""
658658
await asyncio.shield(self.connection_lost_waiter)
659659

660-
async def ping(self, data: Data | None = None) -> Awaitable[float]:
660+
async def ping(self, data: DataLike | None = None) -> Awaitable[float]:
661661
"""
662662
Send a Ping_.
663663
@@ -710,7 +710,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]:
710710
self.protocol.send_ping(data)
711711
return pong_waiter
712712

713-
async def pong(self, data: Data = b"") -> None:
713+
async def pong(self, data: DataLike = b"") -> None:
714714
"""
715715
Send a Pong_.
716716
@@ -1134,7 +1134,7 @@ def eof_received(self) -> None:
11341134

11351135
def broadcast(
11361136
connections: Iterable[Connection],
1137-
message: Data,
1137+
message: DataLike,
11381138
raise_exceptions: bool = False,
11391139
) -> None:
11401140
"""

src/websockets/extensions/permessage_deflate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
PayloadTooBig,
1414
ProtocolError,
1515
)
16-
from ..typing import ExtensionName, ExtensionParameter
16+
from ..typing import BytesLike, ExtensionName, ExtensionParameter
1717
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory
1818

1919

@@ -129,6 +129,7 @@ def decode(
129129
# Uncompress data. Protect against zip bombs by preventing zlib from
130130
# decompressing more than max_length bytes (except when the limit is
131131
# disabled with max_size = None).
132+
data: BytesLike
132133
if frame.fin and len(frame.data) < 2044:
133134
# Profiling shows that appending four bytes, which makes a copy, is
134135
# faster than calling decompress() again when data is less than 2kB.
@@ -182,6 +183,7 @@ def encode(self, frame: frames.Frame) -> frames.Frame:
182183
)
183184

184185
# Compress data.
186+
data: BytesLike
185187
data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH)
186188
if frame.fin:
187189
# Sync flush generates between 5 or 6 bytes, ending with the bytes

src/websockets/frames.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Callable
1111

1212
from .exceptions import PayloadTooBig, ProtocolError
13+
from .typing import BytesLike
1314

1415

1516
try:
@@ -118,9 +119,6 @@ class CloseCode(enum.IntEnum):
118119
}
119120

120121

121-
BytesLike = bytes, bytearray, memoryview
122-
123-
124122
@dataclasses.dataclass
125123
class Frame:
126124
"""
@@ -140,7 +138,7 @@ class Frame:
140138
"""
141139

142140
opcode: Opcode
143-
data: bytes | bytearray | memoryview
141+
data: BytesLike
144142
fin: bool = True
145143
rsv1: bool = False
146144
rsv2: bool = False
@@ -202,7 +200,7 @@ def __str__(self) -> str:
202200
@classmethod
203201
def parse(
204202
cls,
205-
read_exact: Callable[[int], Generator[None, None, bytes]],
203+
read_exact: Callable[[int], Generator[None, None, bytes | bytearray]],
206204
*,
207205
mask: bool,
208206
max_size: int | None = None,
@@ -324,6 +322,7 @@ def serialize(
324322
output.write(mask_bytes)
325323

326324
# Prepare the data.
325+
data: BytesLike
327326
if mask:
328327
data = apply_mask(self.data, mask_bytes)
329328
else:
@@ -383,7 +382,7 @@ def __str__(self) -> str:
383382
return result
384383

385384
@classmethod
386-
def parse(cls, data: bytes) -> Close:
385+
def parse(cls, data: BytesLike) -> Close:
387386
"""
388387
Parse the payload of a close frame.
389388
@@ -395,6 +394,8 @@ def parse(cls, data: bytes) -> Close:
395394
UnicodeDecodeError: If the reason isn't valid UTF-8.
396395
397396
"""
397+
if isinstance(data, memoryview):
398+
raise AssertionError("only compressed outgoing frames use memoryview")
398399
if len(data) >= 2:
399400
(code,) = struct.unpack("!H", data[:2])
400401
reason = data[2:].decode()

src/websockets/http11.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576")) # 1 MiB
4848

4949

50-
def d(value: bytes) -> str:
50+
def d(value: bytes | bytearray) -> str:
5151
"""
5252
Decode a bytestring for interpolating into an error message.
5353
@@ -102,7 +102,7 @@ def exception(self) -> Exception | None: # pragma: no cover
102102
@classmethod
103103
def parse(
104104
cls,
105-
read_line: Callable[[int], Generator[None, None, bytes]],
105+
read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
106106
) -> Generator[None, None, Request]:
107107
"""
108108
Parse a WebSocket handshake request.
@@ -194,7 +194,7 @@ class Response:
194194
status_code: int
195195
reason_phrase: str
196196
headers: Headers
197-
body: bytes = b""
197+
body: bytes | bytearray = b""
198198

199199
_exception: Exception | None = None
200200

@@ -210,9 +210,9 @@ def exception(self) -> Exception | None: # pragma: no cover
210210
@classmethod
211211
def parse(
212212
cls,
213-
read_line: Callable[[int], Generator[None, None, bytes]],
214-
read_exact: Callable[[int], Generator[None, None, bytes]],
215-
read_to_eof: Callable[[int], Generator[None, None, bytes]],
213+
read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
214+
read_exact: Callable[[int], Generator[None, None, bytes | bytearray]],
215+
read_to_eof: Callable[[int], Generator[None, None, bytes | bytearray]],
216216
proxy: bool = False,
217217
) -> Generator[None, None, Response]:
218218
"""
@@ -276,6 +276,7 @@ def parse(
276276

277277
headers = yield from parse_headers(read_line)
278278

279+
body: bytes | bytearray
279280
if proxy:
280281
body = b""
281282
else:
@@ -299,8 +300,8 @@ def serialize(self) -> bytes:
299300

300301

301302
def parse_line(
302-
read_line: Callable[[int], Generator[None, None, bytes]],
303-
) -> Generator[None, None, bytes]:
303+
read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
304+
) -> Generator[None, None, bytes | bytearray]:
304305
"""
305306
Parse a single line.
306307
@@ -326,7 +327,7 @@ def parse_line(
326327

327328

328329
def parse_headers(
329-
read_line: Callable[[int], Generator[None, None, bytes]],
330+
read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
330331
) -> Generator[None, None, Headers]:
331332
"""
332333
Parse HTTP headers.
@@ -379,10 +380,10 @@ def parse_headers(
379380
def read_body(
380381
status_code: int,
381382
headers: Headers,
382-
read_line: Callable[[int], Generator[None, None, bytes]],
383-
read_exact: Callable[[int], Generator[None, None, bytes]],
384-
read_to_eof: Callable[[int], Generator[None, None, bytes]],
385-
) -> Generator[None, None, bytes]:
383+
read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
384+
read_exact: Callable[[int], Generator[None, None, bytes | bytearray]],
385+
read_to_eof: Callable[[int], Generator[None, None, bytes | bytearray]],
386+
) -> Generator[None, None, bytes | bytearray]:
386387
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3
387388

388389
# Since websockets only does GET requests (no HEAD, no CONNECT), all

src/websockets/legacy/framing.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
from .. import extensions, frames
88
from ..exceptions import PayloadTooBig, ProtocolError
9-
from ..frames import BytesLike
10-
from ..typing import Data
9+
from ..typing import BytesLike, DataLike
1110

1211

1312
try:
@@ -19,7 +18,7 @@
1918
class Frame(NamedTuple):
2019
fin: bool
2120
opcode: frames.Opcode
22-
data: bytes
21+
data: BytesLike
2322
rsv1: bool = False
2423
rsv2: bool = False
2524
rsv3: bool = False
@@ -147,7 +146,7 @@ def write(
147146
write(self.new_frame.serialize(mask=mask, extensions=extensions))
148147

149148

150-
def prepare_data(data: Data) -> tuple[int, bytes]:
149+
def prepare_data(data: DataLike) -> tuple[int, BytesLike]:
151150
"""
152151
Convert a string or byte-like object to an opcode and a bytes-like object.
153152
@@ -171,7 +170,7 @@ def prepare_data(data: Data) -> tuple[int, bytes]:
171170
raise TypeError("data must be str or bytes-like")
172171

173172

174-
def prepare_ctrl(data: Data) -> bytes:
173+
def prepare_ctrl(data: DataLike) -> bytes:
175174
"""
176175
Convert a string or byte-like object to bytes.
177176

src/websockets/legacy/protocol.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
Opcode,
4040
)
4141
from ..protocol import State
42-
from ..typing import Data, LoggerLike, Subprotocol
42+
from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol
4343
from .framing import Frame, prepare_ctrl, prepare_data
4444

4545

@@ -563,7 +563,7 @@ async def recv(self) -> Data:
563563

564564
async def send(
565565
self,
566-
message: Data | Iterable[Data] | AsyncIterable[Data],
566+
message: DataLike | Iterable[DataLike] | AsyncIterable[DataLike],
567567
) -> None:
568568
"""
569569
Send a message.
@@ -638,7 +638,7 @@ async def send(
638638

639639
elif isinstance(message, Iterable):
640640
# Work around https://github.com/python/mypy/issues/6227
641-
message = cast(Iterable[Data], message)
641+
message = cast(Iterable[DataLike], message)
642642

643643
iter_message = iter(message)
644644
try:
@@ -678,14 +678,14 @@ async def send(
678678
# Implement aiter_message = aiter(message) without aiter
679679
# Work around https://github.com/python/mypy/issues/5738
680680
aiter_message = cast(
681-
Callable[[AsyncIterable[Data]], AsyncIterator[Data]],
681+
Callable[[AsyncIterable[DataLike]], AsyncIterator[DataLike]],
682682
type(message).__aiter__,
683683
)(message)
684684
try:
685685
# Implement fragment = anext(aiter_message) without anext
686686
# Work around https://github.com/python/mypy/issues/5738
687687
fragment = await cast(
688-
Callable[[AsyncIterator[Data]], Awaitable[Data]],
688+
Callable[[AsyncIterator[DataLike]], Awaitable[DataLike]],
689689
type(aiter_message).__anext__,
690690
)(aiter_message)
691691
except StopAsyncIteration:
@@ -788,7 +788,7 @@ async def wait_closed(self) -> None:
788788
"""
789789
await asyncio.shield(self.connection_lost_waiter)
790790

791-
async def ping(self, data: Data | None = None) -> Awaitable[float]:
791+
async def ping(self, data: DataLike | None = None) -> Awaitable[float]:
792792
"""
793793
Send a Ping_.
794794
@@ -847,7 +847,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]:
847847

848848
return asyncio.shield(pong_waiter)
849849

850-
async def pong(self, data: Data = b"") -> None:
850+
async def pong(self, data: DataLike = b"") -> None:
851851
"""
852852
Send a Pong_.
853853
@@ -1025,10 +1025,12 @@ async def read_message(self) -> Data | None:
10251025

10261026
# Shortcut for the common case - no fragmentation
10271027
if frame.fin:
1028+
if isinstance(frame.data, memoryview):
1029+
raise AssertionError("only compressed outgoing frames use memoryview")
10281030
return frame.data.decode() if text else bytes(frame.data)
10291031

10301032
# 5.4. Fragmentation
1031-
fragments: list[Data] = []
1033+
fragments: list[DataLike] = []
10321034
max_size = self.max_size
10331035
if text:
10341036
decoder_factory = codecs.getincrementaldecoder("utf-8")
@@ -1152,7 +1154,7 @@ async def read_frame(self, max_size: int | None) -> Frame:
11521154
self.logger.debug("< %s", frame)
11531155
return frame
11541156

1155-
def write_frame_sync(self, fin: bool, opcode: int, data: bytes) -> None:
1157+
def write_frame_sync(self, fin: bool, opcode: int, data: BytesLike) -> None:
11561158
frame = Frame(fin, Opcode(opcode), data)
11571159
if self.debug:
11581160
self.logger.debug("> %s", frame)
@@ -1174,7 +1176,7 @@ async def drain(self) -> None:
11741176
await self.ensure_open()
11751177

11761178
async def write_frame(
1177-
self, fin: bool, opcode: int, data: bytes, *, _state: int = State.OPEN
1179+
self, fin: bool, opcode: int, data: BytesLike, *, _state: int = State.OPEN
11781180
) -> None:
11791181
# Defensive assertion for protocol compliance.
11801182
if self.state is not _state: # pragma: no cover
@@ -1184,7 +1186,9 @@ async def write_frame(
11841186
self.write_frame_sync(fin, opcode, data)
11851187
await self.drain()
11861188

1187-
async def write_close_frame(self, close: Close, data: bytes | None = None) -> None:
1189+
async def write_close_frame(
1190+
self, close: Close, data: BytesLike | None = None
1191+
) -> None:
11881192
"""
11891193
Write a close frame if and only if the connection state is OPEN.
11901194
@@ -1538,7 +1542,7 @@ def eof_received(self) -> None:
15381542

15391543
def broadcast(
15401544
websockets: Iterable[WebSocketCommonProtocol],
1541-
message: Data,
1545+
message: DataLike,
15421546
raise_exceptions: bool = False,
15431547
) -> None:
15441548
"""

0 commit comments

Comments
 (0)