Skip to content
This repository was archived by the owner on Aug 1, 2023. It is now read-only.

Commit ab3fdd3

Browse files
BrandonTheBuildertekknolagi
authored andcommitted
Accept Byteslike in decoders
Summary: The python builtins were already checking _bytes_like_guard when type checking. This updates the actual decode functions to use Byteslike objects instead of assuming they are Bytes or a Bytearray. Based on Facebook D27862662
1 parent 40f131e commit ab3fdd3

File tree

2 files changed

+75
-34
lines changed

2 files changed

+75
-34
lines changed

library/_codecs_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22
import _codecs
33
import unittest
4+
from array import array
45

56
from test_support import pyro_only
67

@@ -369,6 +370,11 @@ def test_decode_ascii_with_well_formed_ascii_returns_string(self):
369370
self.assertEqual(decoded, "hello")
370371
self.assertEqual(consumed, 5)
371372

373+
def test_decode_ascii_with_well_formed_ascii_array_returns_string(self):
374+
decoded, consumed = _codecs.ascii_decode(array("B", b"hello"))
375+
self.assertEqual(decoded, "hello")
376+
self.assertEqual(consumed, 5)
377+
372378
def test_decode_ascii_with_well_formed_ascii_bytearray_returns_string(self):
373379
decoded, consumed = _codecs.ascii_decode(bytearray(b"hello"))
374380
self.assertEqual(decoded, "hello")
@@ -384,6 +390,11 @@ class B(bytearray):
384390
self.assertEqual(decoded, "hello")
385391
self.assertEqual(consumed, 5)
386392

393+
def test_decode_ascii_with_well_formed_ascii_memoryview_returns_string(self):
394+
decoded, consumed = _codecs.ascii_decode(memoryview(b"hello"))
395+
self.assertEqual(decoded, "hello")
396+
self.assertEqual(consumed, 5)
397+
387398
def test_decode_ascii_with_custom_error_handler_returns_string(self):
388399
_codecs.register_error("test", lambda x: ("-testing-", x.end))
389400
decoded, consumed = _codecs.ascii_decode(b"ab\x90c", "test")
@@ -457,6 +468,11 @@ def test_decode_latin_1_with_ascii_returns_string(self):
457468
self.assertEqual(decoded, "hello")
458469
self.assertEqual(consumed, 5)
459470

471+
def test_decode_latin_1_with_ascii_array_returns_string(self):
472+
decoded, consumed = _codecs.latin_1_decode(array("B", b"hello"))
473+
self.assertEqual(decoded, "hello")
474+
self.assertEqual(consumed, 5)
475+
460476
def test_decode_latin_1_with_ascii_bytearray_returns_string(self):
461477
decoded, consumed = _codecs.latin_1_decode(bytearray(b"hello"))
462478
self.assertEqual(decoded, "hello")
@@ -470,6 +486,11 @@ class B(bytearray):
470486
self.assertEqual(decoded, "hello")
471487
self.assertEqual(consumed, 5)
472488

489+
def test_decode_latin_1_with_ascii_memoryview_returns_string(self):
490+
decoded, consumed = _codecs.latin_1_decode(memoryview(b"hello"))
491+
self.assertEqual(decoded, "hello")
492+
self.assertEqual(consumed, 5)
493+
473494
def test_decode_latin_1_with_latin_1_returns_string(self):
474495
decoded, consumed = _codecs.latin_1_decode(b"\x7D\x7E\x7F\x80\x81\x82")
475496
self.assertEqual(decoded, "\x7D\x7E\x7F\x80\x81\x82")
@@ -495,6 +516,13 @@ def test_decode_unicode_escape_with_well_formed_latin_1_returns_string(self):
495516
self.assertEqual(decoded, "hello\x95")
496517
self.assertEqual(consumed, 6)
497518

519+
def test_decode_unicode_escape_with_well_formed_latin_1_array_returns_string(
520+
self,
521+
):
522+
decoded, consumed = _codecs.unicode_escape_decode(array("B", b"hello\x95"))
523+
self.assertEqual(decoded, "hello\x95")
524+
self.assertEqual(consumed, 6)
525+
498526
def test_decode_unicode_escape_with_well_formed_latin_1_bytearray_returns_string(
499527
self,
500528
):
@@ -510,6 +538,13 @@ class B(bytearray):
510538
self.assertEqual(decoded, "hello\x95")
511539
self.assertEqual(consumed, 6)
512540

541+
def test_decode_unicode_escape_with_well_formed_latin_1_memoryview_returns_string(
542+
self,
543+
):
544+
decoded, consumed = _codecs.unicode_escape_decode(memoryview(b"hello\x95"))
545+
self.assertEqual(decoded, "hello\x95")
546+
self.assertEqual(consumed, 6)
547+
513548
def test_decode_unicode_escape_with_escaped_back_slash_returns_string(self):
514549
decoded, consumed = _codecs.unicode_escape_decode(b"hello\\x95")
515550
self.assertEqual(decoded, "hello\x95")
@@ -614,13 +649,27 @@ def test_decode_raw_unicode_escape_with_escaped_back_slash_returns_string(self):
614649
self.assertEqual(decoded, "hello\\x95")
615650
self.assertEqual(consumed, 9)
616651

652+
def test_decode_raw_unicode_escape_with_well_formed_latin_1_array_returns_string(
653+
self,
654+
):
655+
decoded, consumed = _codecs.raw_unicode_escape_decode(array("B", b"hello\x95"))
656+
self.assertEqual(decoded, "hello\x95")
657+
self.assertEqual(consumed, 6)
658+
617659
def test_decode_raw_unicode_escape_with_well_formed_latin_1_bytearray_returns_string(
618660
self,
619661
):
620662
decoded, consumed = _codecs.raw_unicode_escape_decode(bytearray(b"hello\x95"))
621663
self.assertEqual(decoded, "hello\x95")
622664
self.assertEqual(consumed, 6)
623665

666+
def test_decode_raw_unicode_escape_with_well_formed_latin_1_memoryview_returns_string(
667+
self,
668+
):
669+
decoded, consumed = _codecs.raw_unicode_escape_decode(memoryview(b"hello\x95"))
670+
self.assertEqual(decoded, "hello\x95")
671+
self.assertEqual(consumed, 6)
672+
624673
def test_decode_raw_unicode_escape_with_latin_1_bytearray_subclass_returns_string(
625674
self,
626675
):
@@ -738,6 +787,13 @@ def test_decode_utf_8_with_well_formed_utf_8_returns_string(self):
738787
self.assertEqual(decoded, "\U0001f192h\xe4l\u2cc0")
739788
self.assertEqual(consumed, 11)
740789

790+
def test_decode_utf_8_with_well_formed_utf8_array_returns_string(self):
791+
decoded, consumed = _codecs.utf_8_decode(
792+
array("B", b"\xf0\x9f\x86\x92h\xc3\xa4l\xe2\xb3\x80")
793+
)
794+
self.assertEqual(decoded, "\U0001f192h\xe4l\u2cc0")
795+
self.assertEqual(consumed, 11)
796+
741797
def test_decode_utf_8_with_well_formed_utf8_bytearray_returns_string(self):
742798
decoded, consumed = _codecs.utf_8_decode(
743799
bytearray(b"\xf0\x9f\x86\x92h\xc3\xa4l\xe2\xb3\x80")
@@ -755,6 +811,13 @@ class B(bytearray):
755811
self.assertEqual(decoded, "\U0001f192h\xe4l\u2cc0")
756812
self.assertEqual(consumed, 11)
757813

814+
def test_decode_utf_8_with_well_formed_utf8_memoryview_returns_string(self):
815+
decoded, consumed = _codecs.utf_8_decode(
816+
memoryview(b"\xf0\x9f\x86\x92h\xc3\xa4l\xe2\xb3\x80")
817+
)
818+
self.assertEqual(decoded, "\U0001f192h\xe4l\u2cc0")
819+
self.assertEqual(consumed, 11)
820+
758821
def test_decode_utf_8_with_custom_error_handler_returns_string(self):
759822
_codecs.register_error("test", lambda x: ("-testing-", x.end))
760823
decoded, consumed = _codecs.utf_8_decode(b"ab\x90c", "test")

runtime/under-codecs-module.cpp

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ static SymbolId lookupSymbolForErrorHandler(const Str& error) {
3535
return SymbolId::kInvalid;
3636
}
3737

38-
static int asciiDecode(Thread* thread, const StrArray& dst, const Bytes& src,
39-
word start, word end) {
38+
static int asciiDecode(Thread* thread, const StrArray& dst,
39+
const Byteslike& src, word start, word end) {
4040
// TODO(T41032331): Implement a fastpass to read longs instead of chars
4141
Runtime* runtime = thread->runtime();
4242
for (word i = start; i < end; i++) {
@@ -57,16 +57,8 @@ RawObject FUNC(_codecs, _ascii_decode)(Thread* thread, Arguments args) {
5757
word index = intUnderlying(args.get(2)).asWord();
5858
StrArray dst(&scope, args.get(3));
5959

60-
word length;
61-
Bytes bytes(&scope, Bytes::empty());
62-
if (runtime->isInstanceOfBytearray(*data)) {
63-
Bytearray array(&scope, *data);
64-
bytes = array.items();
65-
length = array.numItems();
66-
} else {
67-
bytes = bytesUnderlying(*data);
68-
length = bytes.length();
69-
}
60+
Byteslike bytes(&scope, thread, *data);
61+
word length = bytes.length();
7062
runtime->strArrayEnsureCapacity(thread, dst, length);
7163
word outpos = asciiDecode(thread, dst, bytes, index, length);
7264
if (outpos == length) {
@@ -176,7 +168,7 @@ RawObject FUNC(_codecs, _ascii_encode)(Thread* thread, Arguments args) {
176168
// -1 if no value should be written, and -2 if an error occurred. Sets the
177169
// iterating variable to where decoding should continue, and sets
178170
// invalid_escape_index if it doesn't recognize the escape sequence.
179-
static int32_t decodeEscaped(const Bytes& bytes, word* i,
171+
static int32_t decodeEscaped(const Byteslike& bytes, word* i,
180172
word* invalid_escape_index) {
181173
word length = bytes.length();
182174
switch (byte ch = bytes.byteAt((*i)++)) {
@@ -264,7 +256,7 @@ RawObject FUNC(_codecs, _escape_decode)(Thread* thread, Arguments args) {
264256
}
265257
DCHECK(runtime->isInstanceOfStr(args.get(2)),
266258
"Third arg to _escape_decode must be str");
267-
Bytes bytes(&scope, bytesUnderlying(*bytes_obj));
259+
Byteslike bytes(&scope, thread, *bytes_obj);
268260
Str errors(&scope, strUnderlying(args.get(1)));
269261

270262
Bytearray dst(&scope, runtime->newBytearray());
@@ -333,15 +325,8 @@ RawObject FUNC(_codecs, _latin_1_decode)(Thread* thread, Arguments args) {
333325
Object data(&scope, args.get(0));
334326
StrArray array(&scope, runtime->newStrArray());
335327
word length;
336-
Bytes bytes(&scope, Bytes::empty());
337-
if (runtime->isInstanceOfBytearray(*data)) {
338-
Bytearray byte_array(&scope, *data);
339-
bytes = byte_array.items();
340-
length = byte_array.numItems();
341-
} else {
342-
bytes = bytesUnderlying(*data);
343-
length = bytes.length();
344-
}
328+
Byteslike bytes(&scope, thread, *data);
329+
length = bytes.length();
345330
runtime->strArrayEnsureCapacity(thread, array, length);
346331
// First, try a quick ASCII decoding
347332
word num_bytes = asciiDecode(thread, array, bytes, 0, length);
@@ -669,7 +654,8 @@ enum Utf8DecoderResult {
669654
// function returns specific values for errors to determine whether they could
670655
// be caused by incremental decoding, or if they would be an error no matter
671656
// what other bytes might be streamed in later.
672-
static Utf8DecoderResult isValidUtf8Codepoint(const Bytes& bytes, word index) {
657+
static Utf8DecoderResult isValidUtf8Codepoint(const Byteslike& bytes,
658+
word index) {
673659
word length = bytes.length();
674660
byte ch = bytes.byteAt(index);
675661
if (ch <= kMaxASCII) {
@@ -781,16 +767,8 @@ RawObject FUNC(_codecs, _utf_8_decode)(Thread* thread, Arguments args) {
781767
StrArray dst(&scope, args.get(3));
782768

783769
word length;
784-
Bytes bytes(&scope, Bytes::empty());
785-
// TODO(T45849551): Handle any bytes-like object
786-
if (runtime->isInstanceOfBytearray(*data)) {
787-
Bytearray array(&scope, *data);
788-
bytes = array.items();
789-
length = array.numItems();
790-
} else {
791-
bytes = bytesUnderlying(*data);
792-
length = bytes.length();
793-
}
770+
Byteslike bytes(&scope, thread, *data);
771+
length = bytes.length();
794772
runtime->strArrayEnsureCapacity(thread, dst, length);
795773
word i = asciiDecode(thread, dst, bytes, index, length);
796774
if (i == length) {

0 commit comments

Comments
 (0)