Skip to content

[mypyc] Implement str.lower() and str.upper() primitive #19375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,8 @@ PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors);
Py_ssize_t CPyStr_Count(PyObject *unicode, PyObject *substring, CPyTagged start);
Py_ssize_t CPyStr_CountFull(PyObject *unicode, PyObject *substring, CPyTagged start, CPyTagged end);
CPyTagged CPyStr_Ord(PyObject *obj);
PyObject *CPyStr_Lower(PyObject *self);
PyObject *CPyStr_Upper(PyObject *self);


// Bytes operations
Expand Down
119 changes: 119 additions & 0 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -546,3 +546,122 @@ CPyTagged CPyStr_Ord(PyObject *obj) {
PyExc_TypeError, "ord() expected a character, but a string of length %zd found", s);
return CPY_INT_TAG;
}

// ASCII lowercase lookup table, indexed by code point (0..127).
// Entries for 'A'..'Z' (0x41..0x5a) hold the matching 'a'..'z' value
// (0x61..0x7a); every other entry maps to itself.
static const unsigned char ascii_lower_table[128] = {
    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
    0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
    0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
    0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
    0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
    0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
    0x40, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
    0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
    0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
    0x78, 0x79, 0x7a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
    0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
    0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
    0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
    0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f
};

// ASCII uppercase lookup table, indexed by code point (0..127).
// Entries for 'a'..'z' (0x61..0x7a) hold the matching 'A'..'Z' value
// (0x41..0x5a); every other entry maps to itself.
static const unsigned char ascii_upper_table[128] = {
    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
    0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
    0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
    0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
    0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
    0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
    0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
    0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
    0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57,
    0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
    0x60, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
    0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
    0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57,
    0x58, 0x59, 0x5a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f
};

// Helper for lower/upper: get the lower/upper code point for a character
static inline Py_UCS4 tolower_ucs4(Py_UCS4 ch) {
if (ch < 128) {
return ascii_lower_table[ch];
}
#ifdef Py_UNICODE_TOLOWER
return Py_UNICODE_TOLOWER(ch);
#else
// fallback: no-op for non-ASCII if macro is unavailable
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to expect that `Py_UNICODE_TOLOWER` is not available? We shouldn't break functionality if a dependency is missing -- it's better to fail compilation. It seems to me that the best option is to remove the #ifdef and assume `Py_UNICODE_TOLOWER` is defined.

return ch;
#endif
}

static inline Py_UCS4 toupper_ucs4(Py_UCS4 ch) {
if (ch < 128) {
return ascii_upper_table[ch];
}
#ifdef Py_UNICODE_TOUPPER
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above.

return Py_UNICODE_TOUPPER(ch);
#else
// fallback: no-op for non-ASCII if macro is unavailable
return ch;
#endif
}

// Implementation of s.lower()
//
// Returns a new reference to the lowercased string, or NULL with an
// exception set on failure.
//
// Only exact, pure-ASCII str objects are converted inline. Everything else
// is delegated to the object's own lower() method, because full Unicode
// lowercasing is not a per-code-point mapping:
//   - U+0130 lowers to two code points ("i" + U+0307), so the result can be
//     longer than the input;
//   - capital sigma (U+03A3) lowers to a different letter at the end of a
//     word than elsewhere;
//   - U+212A (Kelvin sign) lowers to ASCII 'k', so the result may need a
//     narrower storage kind than the input to stay in canonical form.
// Writing Py_UNICODE_TOLOWER(ch) into a same-length, same-maxchar buffer gets
// all of these wrong.
PyObject *CPyStr_Lower(PyObject *self) {
    if (PyUnicode_READY(self) == -1) {
        return NULL;
    }

    // Non-ASCII text and str subclasses take the generic (always correct) path.
    if (!PyUnicode_CheckExact(self) || PyUnicode_MAX_CHAR_VALUE(self) > 127) {
        return PyObject_CallMethod(self, "lower", NULL);
    }

    Py_ssize_t len = PyUnicode_GET_LENGTH(self);
    const Py_UCS1 *data = PyUnicode_1BYTE_DATA(self);

    // Fast path: if no character changes, hand back the same (immutable)
    // object instead of allocating a copy. This also covers the empty string.
    Py_ssize_t first_changed = 0;
    while (first_changed < len
           && ascii_lower_table[data[first_changed]] == data[first_changed]) {
        first_changed++;
    }
    if (first_changed == len) {
        return Py_NewRef(self);
    }

    // ASCII in, ASCII out: the result always fits a maxchar of 127.
    PyObject *res = PyUnicode_New(len, 127);
    if (res == NULL) {
        return NULL;
    }
    Py_UCS1 *res_data = PyUnicode_1BYTE_DATA(res);
    for (Py_ssize_t i = 0; i < len; i++) {
        res_data[i] = ascii_lower_table[data[i]];
    }
    return res;
}

// Implementation of s.upper()
//
// Returns a new reference to the uppercased string, or NULL with an
// exception set on failure.
//
// Only exact, pure-ASCII str objects are converted inline. Everything else
// is delegated to the object's own upper() method, because full Unicode
// uppercasing is not a per-code-point mapping:
//   - U+00DF (sharp s) uppercases to "SS" and U+FB03 (ffi ligature) to
//     "FFI", so the result can be longer than the input;
//   - U+00FF uppercases to U+0178 and U+00B5 to U+039C, which do not fit in
//     the 1-byte storage that a Latin-1 input would be given.
// Writing Py_UNICODE_TOUPPER(ch) into a same-length, same-maxchar buffer gets
// all of these wrong.
PyObject *CPyStr_Upper(PyObject *self) {
    if (PyUnicode_READY(self) == -1) {
        return NULL;
    }

    // Non-ASCII text and str subclasses take the generic (always correct) path.
    if (!PyUnicode_CheckExact(self) || PyUnicode_MAX_CHAR_VALUE(self) > 127) {
        return PyObject_CallMethod(self, "upper", NULL);
    }

    Py_ssize_t len = PyUnicode_GET_LENGTH(self);
    const Py_UCS1 *data = PyUnicode_1BYTE_DATA(self);

    // Fast path: if no character changes, hand back the same (immutable)
    // object instead of allocating a copy. This also covers the empty string.
    Py_ssize_t first_changed = 0;
    while (first_changed < len
           && ascii_upper_table[data[first_changed]] == data[first_changed]) {
        first_changed++;
    }
    if (first_changed == len) {
        return Py_NewRef(self);
    }

    // ASCII in, ASCII out: the result always fits a maxchar of 127.
    PyObject *res = PyUnicode_New(len, 127);
    if (res == NULL) {
        return NULL;
    }
    Py_UCS1 *res_data = PyUnicode_1BYTE_DATA(res);
    for (Py_ssize_t i = 0; i < len; i++) {
        res_data[i] = ascii_upper_table[data[i]];
    }
    return res;
}
18 changes: 18 additions & 0 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,21 @@
c_function_name="CPyStr_Ord",
error_kind=ERR_MAGIC,
)

# str.lower()
# Registers a primitive for the "lower" method on a str_rprimitive receiver.
# The call is emitted as the C function CPyStr_Lower and yields a str_rprimitive.
# NOTE(review): ERR_MAGIC presumably means failure is signalled through a
# special return value (NULL for object results) -- confirm against its definition.
method_op(
name="lower",
arg_types=[str_rprimitive],
return_type=str_rprimitive,
c_function_name="CPyStr_Lower",
error_kind=ERR_MAGIC,
)

# str.upper()
# Registers a primitive for the "upper" method on a str_rprimitive receiver.
# The call is emitted as the C function CPyStr_Upper and yields a str_rprimitive.
# NOTE(review): ERR_MAGIC presumably means failure is signalled through a
# special return value (NULL for object results) -- confirm against its definition.
method_op(
name="upper",
arg_types=[str_rprimitive],
return_type=str_rprimitive,
c_function_name="CPyStr_Upper",
error_kind=ERR_MAGIC,
)
3 changes: 2 additions & 1 deletion mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def lstrip(self, item: Optional[str] = None) -> str: pass
def rstrip(self, item: Optional[str] = None) -> str: pass
def join(self, x: Iterable[str]) -> str: pass
def format(self, *args: Any, **kwargs: Any) -> str: ...
def upper(self) -> str: ...
def startswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
def endswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
def replace(self, old: str, new: str, maxcount: int=...) -> str: ...
Expand All @@ -122,6 +121,8 @@ def rpartition(self, sep: str, /) -> Tuple[str, str, str]: ...
def removeprefix(self, prefix: str, /) -> str: ...
def removesuffix(self, suffix: str, /) -> str: ...
def islower(self) -> bool: ...
def lower(self) -> str: ...
def upper(self) -> str: ...

class float:
def __init__(self, x: object) -> None: pass
Expand Down
20 changes: 20 additions & 0 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,23 @@ L0:
r3 = box(native_int, r1)
r4 = unbox(int, r3)
return r4

[case testLower]
def do_lower(s: str) -> str:
return s.lower()
[out]
def do_lower(s):
s, r0 :: str
L0:
r0 = CPyStr_Lower(s)
return r0

[case testUpper]
def do_upper(s: str) -> str:
return s.upper()
[out]
def do_upper(s):
s, r0 :: str
L0:
r0 = CPyStr_Upper(s)
return r0
18 changes: 18 additions & 0 deletions mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -906,3 +906,21 @@ def test_count_multi_start_end_emoji() -> None:
assert string.count("😴😴😴", 0, 12) == 1, string.count("😴😴😴", 0, 12)
assert string.count("🚀🚀🚀", 0, 12) == 2, string.count("🚀🚀🚀", 0, 12)
assert string.count("ñññ", 0, 12) == 1, string.count("ñññ", 0, 12)

[case testLower]
def test_str_lower() -> None:
    """Check that str.lower() matches regular Python semantics.

    Covers empty, ASCII, Latin-1 and astral inputs, plus the Unicode special
    cases where lowercasing is not a simple one-to-one code point mapping.
    """
    # Basic cases.
    assert "".lower() == ""
    assert "ABC".lower() == "abc"
    assert "abc".lower() == "abc"
    assert "AbC123".lower() == "abc123"
    assert "áÉÍ".lower() == "áéí"
    assert "😴🚀".lower() == "😴🚀"

    # Special cases.
    assert "SS".lower() == "ss"
    # Capital sigma: regular form when isolated, final form at the end of a word.
    assert "\u03a3".lower() == "\u03c3"
    assert "\u0391\u03a3".lower() == "\u03b1\u03c2"
    # Capital I with dot above lowers to two code points (length changes).
    assert "\u0130".lower() == "i\u0307"
    assert len("\u0130".lower()) == 2
    # Kelvin sign lowers to plain ASCII 'k' (result is narrower than the input).
    assert "\u212a".lower() == "k"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also test special cases (verify that this agrees with normal Python semantics):

  • 'SS'.lower() == 'ss'
  • 'Σ'.lower()
  • 'İ'.lower() (changes length!)


[case testUpper]
def test_str_upper() -> None:
    """Check that str.upper() matches regular Python semantics.

    Covers empty, ASCII, Latin-1 and astral inputs, plus the Unicode special
    cases where uppercasing is not a simple one-to-one code point mapping.
    """
    # Basic cases.
    assert "".upper() == ""
    assert "abc".upper() == "ABC"
    assert "ABC".upper() == "ABC"
    assert "AbC123".upper() == "ABC123"
    assert "áéí".upper() == "ÁÉÍ"
    assert "😴🚀".upper() == "😴🚀"

    # Special cases where the length increases.
    assert "\u00df".upper() == "SS"
    assert "\ufb03".upper() == "FFI"
    # Special cases where the result needs wider storage than the Latin-1 input.
    assert "\u00ff".upper() == "\u0178"
    assert "\u00b5".upper() == "\u039c"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also test special cases (verify that this agrees with normal Python semantics):

  • 'ß'.upper() == 'SS'
  • 'ffi'.upper() (length increases!)