Skip to content

Commit 3136d4f

Browse files
committed
Only reuse the pointer object when it matches the _type_ of the container
Closes #107940. Also, solves a related yet undiscovered issue where an array of pointers reuses the array's memory for the pointer objects.
1 parent 71962e5 commit 3136d4f

File tree

2 files changed

+95
-24
lines changed

2 files changed

+95
-24
lines changed

Lib/test/test_ctypes/test_cast.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import sys
22
import unittest
3-
from ctypes import (Structure, Union, POINTER, cast, sizeof, addressof,
4-
c_void_p, c_char_p, c_wchar_p,
5-
c_byte, c_short, c_int)
3+
from ctypes import (Structure, Union, pointer, POINTER, sizeof, addressof,
4+
c_void_p, c_char_p, c_wchar_p, cast,
5+
c_byte, c_short, c_int, c_int16)
66

77

88
class Test(unittest.TestCase):
@@ -95,6 +95,71 @@ class MyUnion(Union):
9595
_fields_ = [("a", c_int)]
9696
self.assertRaises(TypeError, cast, array, MyUnion)
9797

98+
def test_pointer_identity(self):
99+
class Struct(Structure):
100+
_fields_ = [('a', c_int16)]
101+
Struct3 = 3 * Struct
102+
c_array = (2 * Struct3)(
103+
Struct3(Struct(a=1), Struct(a=2), Struct(a=3)),
104+
Struct3(Struct(a=4), Struct(a=5), Struct(a=6))
105+
)
106+
self.assertEqual(c_array[0][0].a, 1)
107+
self.assertEqual(c_array[0][1].a, 2)
108+
self.assertEqual(c_array[0][2].a, 3)
109+
self.assertEqual(c_array[1][0].a, 4)
110+
self.assertEqual(c_array[1][1].a, 5)
111+
self.assertEqual(c_array[1][2].a, 6)
112+
p_obj = cast(pointer(c_array), POINTER(pointer(c_array)._type_))
113+
obj = p_obj.contents
114+
self.assertEqual(obj[0][0].a, 1)
115+
self.assertEqual(obj[0][1].a, 2)
116+
self.assertEqual(obj[0][2].a, 3)
117+
self.assertEqual(obj[1][0].a, 4)
118+
self.assertEqual(obj[1][1].a, 5)
119+
self.assertEqual(obj[1][2].a, 6)
120+
p_obj = cast(pointer(c_array[0]), POINTER(pointer(c_array)._type_))
121+
obj = p_obj.contents
122+
self.assertEqual(obj[0][0].a, 1)
123+
self.assertEqual(obj[0][1].a, 2)
124+
self.assertEqual(obj[0][2].a, 3)
125+
self.assertEqual(obj[1][0].a, 4)
126+
self.assertEqual(obj[1][1].a, 5)
127+
self.assertEqual(obj[1][2].a, 6)
128+
StructPointer = POINTER(Struct)
129+
s1 = Struct(a=10)
130+
s2 = Struct(a=20)
131+
s3 = Struct(a=30)
132+
pointer_array = (3 * StructPointer)(pointer(s1), pointer(s2), pointer(s3))
133+
self.assertEqual(pointer_array[0][0].a, 10)
134+
self.assertEqual(pointer_array[1][0].a, 20)
135+
self.assertEqual(pointer_array[2][0].a, 30)
136+
self.assertEqual(pointer_array[0].contents.a, 10)
137+
self.assertEqual(pointer_array[1].contents.a, 20)
138+
self.assertEqual(pointer_array[2].contents.a, 30)
139+
p_obj = cast(pointer(pointer_array[0]), POINTER(pointer(pointer_array)._type_))
140+
obj = p_obj.contents
141+
self.assertEqual(obj[0][0].a, 10)
142+
self.assertEqual(obj[1][0].a, 20)
143+
self.assertEqual(obj[2][0].a, 30)
144+
self.assertEqual(obj[0].contents.a, 10)
145+
self.assertEqual(obj[1].contents.a, 20)
146+
self.assertEqual(obj[2].contents.a, 30)
147+
class StructWithPointers(Structure):
148+
_fields_ = [("s1", POINTER(Struct)), ("s2", POINTER(Struct))]
149+
struct = StructWithPointers(s1=pointer(s1), s2=pointer(s2))
150+
p_obj = pointer(struct)
151+
obj = p_obj.contents
152+
self.assertEqual(obj.s1[0].a, 10)
153+
self.assertEqual(obj.s2[0].a, 20)
154+
self.assertEqual(obj.s1.contents.a, 10)
155+
self.assertEqual(obj.s2.contents.a, 20)
156+
p_obj = cast(pointer(struct), POINTER(pointer(pointer_array)._type_))
157+
obj = p_obj.contents
158+
self.assertEqual(obj[0][0].a, 10)
159+
self.assertEqual(obj[1][0].a, 20)
160+
self.assertEqual(obj[0].contents.a, 10)
161+
self.assertEqual(obj[1].contents.a, 20)
162+
98163

99164
if __name__ == "__main__":
100165
unittest.main()

Modules/_ctypes/_ctypes.c

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5139,8 +5139,8 @@ static PyObject *
51395139
Pointer_get_contents(CDataObject *self, void *closure)
51405140
{
51415141
StgDictObject *stgdict;
5142-
PyObject *keep, *ptr_probe;
5143-
CDataObject *ptr2ptr;
5142+
PyObject *ptr2ptr;
5143+
CDataObject *p2p;
51445144

51455145
if (*(void **)self->b_ptr == NULL) {
51465146
PyErr_SetString(PyExc_ValueError,
@@ -5150,30 +5150,36 @@ Pointer_get_contents(CDataObject *self, void *closure)
51505150

51515151
stgdict = PyObject_stgdict((PyObject *)self);
51525152
assert(stgdict); /* Cannot be NULL for pointer instances */
5153+
assert(stgdict->proto);
51535154

5154-
keep = GetKeepedObjects(self);
5155-
if (keep != NULL) {
5156-
// check if it's a pointer to a pointer:
5157-
// pointers will have '0' key in the _objects
5158-
ptr_probe = PyDict_GetItemString(keep, "0");
5159-
5160-
if (ptr_probe != NULL) {
5161-
ptr2ptr = (CDataObject*) PyDict_GetItemString(keep, "1");
5162-
if (ptr2ptr == NULL) {
5163-
PyErr_SetString(PyExc_ValueError,
5164-
"Unexpected NULL pointer in _objects");
5165-
return NULL;
5166-
}
5167-
// don't construct a new object,
5155+
if (self->b_objects != NULL && PyDict_CheckExact(self->b_objects)) {
5156+
// Pointer_set_contents uses KeepRef(self, 1, value); we retrieve that
5157+
ptr2ptr = PyDict_GetItemString(self->b_objects, "1");
5158+
if (ptr2ptr == NULL) {
5159+
PyErr_SetString(PyExc_ValueError,
5160+
"Unexpected NULL pointer in _objects");
5161+
return NULL;
5162+
}
5163+
// if our base pointer is cast from another type,
5164+
// its `_type_` proto will be incompatible with the
5165+
// type of the object stored in `b_objects["1"]` because
5166+
// `_objects` is shared between casts and the original.
5167+
int res = PyObject_IsInstance(ptr2ptr, stgdict->proto);
5168+
if (res == -1) {
5169+
return NULL;
5170+
}
5171+
if (res) {
5172+
// It's not a cast: don't construct a new object,
51685173
// return existing one instead to preserve refcount
5174+
p2p = (CDataObject*) ptr2ptr;
51695175
assert(
5170-
*(void**) self->b_ptr == ptr2ptr->b_ptr ||
5171-
*(void**) self->b_value.c == ptr2ptr->b_ptr ||
5172-
*(void**) self->b_ptr == ptr2ptr->b_value.c ||
5173-
*(void**) self->b_value.c == ptr2ptr->b_value.c
5176+
*(void**) self->b_ptr == p2p->b_ptr ||
5177+
*(void**) self->b_value.c == p2p->b_ptr ||
5178+
*(void**) self->b_ptr == p2p->b_value.c ||
5179+
*(void**) self->b_value.c == p2p->b_value.c
51745180
); // double-check that we are returning the same thing
51755181
Py_INCREF(ptr2ptr);
5176-
return (PyObject *) ptr2ptr;
5182+
return ptr2ptr;
51775183
}
51785184
}
51795185

0 commit comments

Comments
 (0)