Skip to content

Commit 5b642fc

Browse files
committed
Fix casting of classes with mixed polymorphic inheritance
A polymorphic class can inherit from a non-polymorphic base.
1 parent 2b4477e commit 5b642fc

File tree

5 files changed

+91
-4
lines changed

5 files changed

+91
-4
lines changed

include/pybind11/attr.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ struct function_record {
200200
/// Special data structure which (temporarily) holds metadata about a bound class
201201
struct type_record {
202202
PYBIND11_NOINLINE type_record()
203-
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), module_local(false) { }
203+
: multiple_inheritance(false), polymorphic(false), dynamic_attr(false),
204+
buffer_protocol(false), module_local(false) { }
204205

205206
/// Handle to the parent scope
206207
handle scope;
@@ -238,6 +239,9 @@ struct type_record {
238239
/// Multiple inheritance marker
239240
bool multiple_inheritance : 1;
240241

242+
/// Type is polymorphic in C++
243+
bool polymorphic : 1;
244+
241245
/// Does the class manage a __dict__?
242246
bool dynamic_attr : 1;
243247

@@ -250,6 +254,7 @@ struct type_record {
250254
/// Is the class definition local to the module shared object?
251255
bool module_local : 1;
252256

257+
/// Add a base as a template argument -- allows casting to base for non-simple types
253258
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
254259
auto base_info = detail::get_type_info(base, false);
255260
if (!base_info) {
@@ -276,6 +281,24 @@ struct type_record {
276281
if (caster)
277282
base_info->implicit_casts.emplace_back(type, caster);
278283
}
284+
285+
/// Add a base as a runtime argument -- only for simple types
286+
PYBIND11_NOINLINE void add_base(handle base) {
287+
if (!base || !PyType_Check(base.ptr()))
288+
pybind11_fail("generic_type: type \"" + std::string(name) + "\" "
289+
"is trying to register a non-type object as a base");
290+
291+
auto base_ptr = (PyTypeObject *) base.ptr();
292+
auto base_info = detail::get_type_info(base_ptr);
293+
if (polymorphic != base_info->polymorphic) {
294+
pybind11_fail("generic_type: type \"" + std::string(name) + "\" is polymorphic, "
295+
"but its base \"" + std::string(base_ptr->tp_name) + "\" is not. "
296+
"In this case, the base must be specified as a template argument: "
297+
"py::class_<T, Base>(...) instead of py::class_<T>(..., base).");
298+
}
299+
300+
bases.append(base);
301+
}
279302
};
280303

281304
inline function_call::function_call(function_record &f, handle p) :
@@ -392,7 +415,7 @@ template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
392415
/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that)
393416
template <typename T>
394417
struct process_attribute<T, enable_if_t<is_pyobject<T>::value>> : process_attribute_default<handle> {
395-
static void init(const handle &h, type_record *r) { r->bases.append(h); }
418+
static void init(const handle &h, type_record *r) { r->add_base(h); }
396419
};
397420

398421
/// Process a parent class attribute (deprecated, does not support multiple inheritance)

include/pybind11/detail/internals.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,16 @@ struct type_info {
104104
bool simple_type : 1;
105105
/* True if there is no multiple inheritance in this type's inheritance tree */
106106
bool simple_ancestors : 1;
107+
/* Type is polymorphic in C++ */
108+
bool polymorphic : 1;
107109
/* for base vs derived holder_type checks */
108110
bool default_holder : 1;
109111
/* true if this is a type registered with py::module_local */
110112
bool module_local : 1;
111113
};
112114

113115
/// Tracks the `internals` and `type_info` ABI version independent of the main library version
114-
#define PYBIND11_INTERNALS_VERSION 1
116+
#define PYBIND11_INTERNALS_VERSION 2
115117

116118
#if defined(WITH_THREAD)
117119
# define PYBIND11_INTERNALS_KIND ""

include/pybind11/pybind11.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,7 @@ class generic_type : public object {
907907
tinfo->dealloc = rec.dealloc;
908908
tinfo->simple_type = true;
909909
tinfo->simple_ancestors = true;
910+
tinfo->polymorphic = rec.polymorphic;
910911
tinfo->default_holder = rec.default_holder;
911912
tinfo->module_local = rec.module_local;
912913

@@ -925,7 +926,12 @@ class generic_type : public object {
925926
}
926927
else if (rec.bases.size() == 1) {
927928
auto parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr());
928-
tinfo->simple_ancestors = parent_tinfo->simple_ancestors;
929+
if (tinfo->polymorphic == parent_tinfo->polymorphic) {
930+
tinfo->simple_ancestors = parent_tinfo->simple_ancestors;
931+
} else {
932+
mark_parents_nonsimple(tinfo->type);
933+
tinfo->simple_ancestors = false;
934+
}
929935
}
930936

931937
if (rec.module_local) {
@@ -936,6 +942,7 @@ class generic_type : public object {
936942
}
937943

938944
/// Helper function which tags all parents of a type using mult. inheritance
945+
/// or a polymorphic type which inherits from a non-polymorphic base
939946
void mark_parents_nonsimple(PyTypeObject *value) {
940947
auto t = reinterpret_borrow<tuple>(value->tp_bases);
941948
for (handle h : t) {
@@ -1054,6 +1061,7 @@ class class_ : public detail::generic_type {
10541061
record.holder_size = sizeof(holder_type);
10551062
record.init_instance = init_instance;
10561063
record.dealloc = dealloc;
1064+
record.polymorphic = std::is_polymorphic<type>::value;
10571065
record.default_holder = std::is_same<holder_type, std::unique_ptr<type>>::value;
10581066

10591067
set_operator_new<type>(&record);

tests/test_class.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,36 @@ TEST_SUBMODULE(class_, m) {
8181
m.def("pet_name_species", [](const Pet &pet) { return pet.name() + " is a " + pet.species(); });
8282
m.def("dog_bark", [](const Dog &dog) { return dog.bark(); });
8383

84+
// test_mixed_polymorphic_inheritance
85+
struct NonPolymorphicBase {
86+
std::int64_t a, b;
87+
};
88+
struct PolymorphicDerived : NonPolymorphicBase {
89+
PolymorphicDerived() : NonPolymorphicBase{1, 2} { }
90+
virtual ~PolymorphicDerived() { }
91+
};
92+
93+
py::class_<NonPolymorphicBase>(m, "NonPolymorphicBase")
94+
.def_readwrite("a", &NonPolymorphicBase::a)
95+
.def_readwrite("b", &NonPolymorphicBase::b);
96+
97+
py::class_<PolymorphicDerived, NonPolymorphicBase>(m, "PolymorphicDerived")
98+
.def(py::init<>());
99+
100+
m.def("call_with_nonpolymorphic_base", [](const NonPolymorphicBase &x) { return x.b; });
101+
m.def("call_with_polymorphic_derived", [](const PolymorphicDerived &x) { return x.b; });
102+
103+
m.def("register_mixed_polymorphic_base_at_runtime", []() {
104+
struct LocalPolymorphicDerived : NonPolymorphicBase {
105+
virtual ~LocalPolymorphicDerived() = default;
106+
};
107+
108+
auto module = py::module::import("pybind11_tests").attr("class_");
109+
auto base = module.attr("NonPolymorphicBase");
110+
// Expected to throw
111+
py::class_<LocalPolymorphicDerived>(module, "LocalPolymorphicDerived", base);
112+
});
113+
84114
// test_automatic_upcasting
85115
struct BaseClass { virtual ~BaseClass() {} };
86116
struct DerivedClass1 : BaseClass { };

tests/test_class.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,30 @@ def test_inheritance(msg):
7676
assert "No constructor defined!" in str(excinfo.value)
7777

7878

79+
def test_mixed_polymorphic_inheritance():
80+
"""A polymorphic class can inherit members from a non-polymorphic base"""
81+
import re
82+
83+
class PolymorphicDerived(m.PolymorphicDerived):
84+
def __init__(self):
85+
m.PolymorphicDerived.__init__(self)
86+
87+
for x in m.PolymorphicDerived(), PolymorphicDerived():
88+
assert (x.a, x.b) == (1, 2)
89+
x.a = 11
90+
x.b = 22
91+
assert (x.a, x.b) == (11, 22)
92+
assert m.call_with_nonpolymorphic_base(x) == 22
93+
assert m.call_with_polymorphic_derived(x) == 22
94+
95+
with pytest.raises(RuntimeError) as excinfo:
96+
m.register_mixed_polymorphic_base_at_runtime()
97+
assert re.match('generic_type: type ".*LocalPolymorphicDerived" is polymorphic, '
98+
'but its base ".*NonPolymorphicBase" is not', str(excinfo.value))
99+
assert ('In this case, the base must be specified as a template argument: '
100+
'py::class_<T, Base>(...) instead of py::class_<T>(..., base).') in str(excinfo.value)
101+
102+
79103
def test_automatic_upcasting():
80104
assert type(m.return_class_1()).__name__ == "DerivedClass1"
81105
assert type(m.return_class_2()).__name__ == "DerivedClass2"

0 commit comments

Comments
 (0)