Skip to content

Commit c6c8a17

Browse files
[numpy] Fix breakage in users of Python protobufs under NumPy 2.3rc1. (#22171)
As of NumPy 2.3.0rc1, numpy.bool scalars can no longer be interpreted as index values (https://github.com/numpy/numpy/releases/tag/v2.3.0rc1). This causes protobuf no longer to accept a np.bool scalar as a legal value for a boolean field. We have two options: a) either we can change protobuf so that it continues to accept NumPy boolean scalars (this change), or b) decide that protobuf should reject NumPy boolean scalars and that users must update their code to cast to a Python bool explicitly. I have no strong opinion as to which, but option (a) seems less disruptive. No test updates are needed: the existing tests fail under NumPy 2.3. PiperOrigin-RevId: 766629310 Co-authored-by: Peter Hawkins <phawkins@google.com>
1 parent afede60 commit c6c8a17

File tree

3 files changed

+55
-7
lines changed

3 files changed

+55
-7
lines changed

python/convert.c

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,35 @@ bool PyUpb_IsNumpyNdarray(PyObject* obj, const upb_FieldDef* f) {
208208
return is_ndarray;
209209
}
210210

211+
// Reports whether the attribute `attr` of `obj`'s type is a string equal to
// `expected`.
//
// Any failure along the way (the attribute lookup raises, or the attribute
// cannot be viewed as a C string) is treated as "no match": the pending Python
// error is cleared and false is returned, so callers can use this as a pure
// predicate.
static bool PyUpb_TypeAttrIs(PyObject* obj, const char* attr,
                             const char* expected) {
  PyObject* attr_obj = PyObject_GetAttrString((PyObject*)Py_TYPE(obj), attr);
  if (attr_obj == NULL) {
    // Lookup failed and left an exception set; this is only a type probe, so
    // do not let that exception leak into the caller's error checks.
    PyErr_Clear();
    return false;
  }

  // NOTE(review): PyUpb_GetStrData is defined elsewhere; the NULL check below
  // assumes it reports "not a string" by returning NULL -- confirm.
  const char* data = PyUpb_GetStrData(attr_obj);
  bool matches = false;
  if (data == NULL) {
    PyErr_Clear();
  } else {
    // Compare before releasing attr_obj: `data` may point into its storage.
    matches = strcmp(data, expected) == 0;
  }
  Py_DECREF(attr_obj);
  return matches;
}

// Returns true if `obj` is a NumPy boolean scalar, i.e. its type's
// `__module__` is "numpy" and its `__name__` is "bool".
//
// Such scalars need special handling because, as of NumPy 2.3, numpy.bool can
// no longer be interpreted as an index value.
//
// Never leaves a Python error set: objects whose type attributes cannot be
// inspected are simply reported as not being NumPy bools.
bool PyUpb_IsNumpyBoolScalar(PyObject* obj) {
  // Check the module first so non-NumPy objects take the short path.
  return PyUpb_TypeAttrIs(obj, "__module__", "numpy") &&
         PyUpb_TypeAttrIs(obj, "__name__", "bool");
}
229+
230+
// Converts the Python object `obj` into a C bool for the bool field `f`,
// storing the result in `*val`.
//
// Returns false if PyUpb_IsNumpyNdarray() reports `obj` as an ndarray, or if a
// Python error is pending once the conversion has been attempted; returns true
// otherwise.
static bool PyUpb_GetBool(PyObject* obj, const upb_FieldDef* f, bool* val) {
  if (PyUpb_IsNumpyNdarray(obj, f)) {
    return false;
  }

  // NumPy bool scalars are converted by truth-testing; every other object goes
  // through the integer conversion path.
  if (PyUpb_IsNumpyBoolScalar(obj)) {
    *val = PyObject_IsTrue(obj);
  } else {
    *val = PyLong_AsLong(obj);
  }

  // Success is defined by the absence of a pending Python exception.
  return PyErr_Occurred() == NULL;
}
239+
211240
bool PyUpb_PyToUpb(PyObject* obj, const upb_FieldDef* f, upb_MessageValue* val,
212241
upb_Arena* arena) {
213242
switch (upb_FieldDef_CType(f)) {
@@ -230,9 +259,7 @@ bool PyUpb_PyToUpb(PyObject* obj, const upb_FieldDef* f, upb_MessageValue* val,
230259
val->double_val = PyFloat_AsDouble(obj);
231260
return !PyErr_Occurred();
232261
case kUpb_CType_Bool:
233-
if (PyUpb_IsNumpyNdarray(obj, f)) return false;
234-
val->bool_val = PyLong_AsLong(obj);
235-
return !PyErr_Occurred();
262+
return PyUpb_GetBool(obj, f, &val->bool_val);
236263
case kUpb_CType_Bytes: {
237264
char* ptr;
238265
Py_ssize_t size;

python/google/protobuf/internal/type_checkers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,21 @@ class BoolValueChecker(object):
113113
"""Type checker used for bool fields."""
114114

115115
def CheckValue(self, proposed_value):
116-
if not hasattr(proposed_value, '__index__') or (
117-
type(proposed_value).__module__ == 'numpy' and
116+
if not hasattr(proposed_value, '__index__'):
117+
# Under NumPy 2.3, numpy.bool does not have an __index__ method.
118+
if (type(proposed_value).__module__ == 'numpy' and
119+
type(proposed_value).__name__ == 'bool'):
120+
return bool(proposed_value)
121+
message = ('%.1024r has type %s, but expected one of: %s' %
122+
(proposed_value, type(proposed_value), (bool, int)))
123+
raise TypeError(message)
124+
125+
if (type(proposed_value).__module__ == 'numpy' and
118126
type(proposed_value).__name__ == 'ndarray'):
119127
message = ('%.1024r has type %s, but expected one of: %s' %
120128
(proposed_value, type(proposed_value), (bool, int)))
121129
raise TypeError(message)
130+
122131
return bool(proposed_value)
123132

124133
def DefaultValue(self):

python/google/protobuf/pyext/message.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -576,8 +576,20 @@ bool CheckAndGetFloat(PyObject* arg, float* value) {
576576

577577
bool CheckAndGetBool(PyObject* arg, bool* value) {
578578
long long_value = PyLong_AsLong(arg); // NOLINT
579-
if (!strcmp(Py_TYPE(arg)->tp_name, "numpy.ndarray") ||
580-
(long_value == -1 && PyErr_Occurred())) {
579+
if (long_value == -1 && PyErr_Occurred()) {
580+
// In NumPy 2.3, numpy.bool does not have an __index__ method and cannot
581+
// be converted to a long using PyLong_AsLong.
582+
if (!strcmp(Py_TYPE(arg)->tp_name, "numpy.bool")) {
583+
PyErr_Clear();
584+
int is_true = PyObject_IsTrue(arg);
585+
if (is_true >= 0) {
586+
*value = static_cast<bool>(is_true);
587+
return true;
588+
}
589+
}
590+
FormatTypeError(arg, "int, bool");
591+
return false;
592+
} else if (!strcmp(Py_TYPE(arg)->tp_name, "numpy.ndarray")) {
581593
FormatTypeError(arg, "int, bool");
582594
return false;
583595
}

0 commit comments

Comments
 (0)