diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index a160dfad19..89d3d43a7a 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1606,15 +1606,9 @@ inline void object::cast() && { PYBIND11_NAMESPACE_BEGIN(detail) -// forward declaration (definition in attr.h) -struct function_record; - // forward declaration (definition in pybind11.h) -std::string generate_function_signature(const char *type_caster_name_field, - function_record *func_rec, - const std::type_info *const *types, - size_t &type_index, - size_t &arg_index); +template +std::string generate_type_signature(); // Declared in pytypes.h: template ::value, int>> @@ -1637,10 +1631,7 @@ str_attr_accessor object_api::attr_with_type_hint(const char *key) const { throw std::runtime_error("__annotations__[\"" + std::string(key) + "\"] was set already."); } - const char *text = make_caster::name.text; - - size_t unused = 0; - ann[key] = generate_function_signature(text, nullptr, nullptr, unused, unused); + ann[key] = generate_type_signature(); return {derived(), key}; } diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 9499fa7048..cf11e373dc 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -112,7 +112,6 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel size_t &arg_index) { std::string signature; bool is_starred = false; - bool is_annotation = func_rec == nullptr; // `is_return_value.top()` is true if we are currently inside the return type of the // signature. Using `@^`/`@$` we can force types to be arg/return types while `@!` pops // back to the previous state. @@ -199,9 +198,7 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel // For named arguments (py::arg()) with noconvert set, return value type is used. ++pc; if (!is_return_value.top() - && (is_annotation - || !(arg_index < func_rec->args.size() - && !func_rec->args[arg_index].convert))) { + && (!(arg_index < func_rec->args.size() && !func_rec->args[arg_index].convert))) { while (*pc != '\0' && *pc != '@') { signature += *pc++; } @@ -232,6 +229,19 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel return signature; } +template +inline std::string generate_type_signature() { + static constexpr auto caster_name_field = make_caster::name; + PYBIND11_DESCR_CONSTEXPR auto descr_types = decltype(caster_name_field)::types(); + // Create a default function_record to ensure the function signature has the proper + // configuration e.g. no_convert. + auto func_rec = function_record(); + size_t type_index = 0; + size_t arg_index = 0; + return generate_function_signature( + caster_name_field.text, &func_rec, descr_types.data(), type_index, arg_index); +} + #if defined(_MSC_VER) # define PYBIND11_COMPAT_STRDUP _strdup #else diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index 5160e9f408..cbbf42178b 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -1058,6 +1058,21 @@ TEST_SUBMODULE(pytypes, m) { // Exercises py::handle overload: m.attr_with_type_hint>(py::str("set_str")) = py::set(); + struct foo_t {}; + struct foo2 {}; + struct foo3 {}; + + pybind11::class_(m, "foo"); + pybind11::class_(m, "foo2"); + pybind11::class_(m, "foo3"); + m.attr_with_type_hint("foo") = foo_t{}; + + m.attr_with_type_hint>("foo_union") = foo_t{}; + + // Include to ensure this does not crash + struct foo4 {}; + m.attr_with_type_hint("foo4") = 3; + struct Empty {}; py::class_(m, "EmptyAnnotationClass"); diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index deb4b06d41..c189c27a15 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -1157,6 +1157,11 @@ def test_module_attribute_types() -> None: assert module_annotations["list_int"] == "list[typing.SupportsInt]" assert module_annotations["set_str"] == "set[str]" + assert module_annotations["foo"] == "pybind11_tests.pytypes.foo" + assert ( + module_annotations["foo_union"] + == "Union[pybind11_tests.pytypes.foo, pybind11_tests.pytypes.foo2, pybind11_tests.pytypes.foo3]" + ) @pytest.mark.skipif(