Skip to content

Commit 077d0b6

Browse files
committed
Added return_descr/arg_descr for correct typing in typing::Callable
1 parent 9877aca commit 077d0b6

File tree

5 files changed

+59
-17
lines changed

5 files changed

+59
-17
lines changed

include/pybind11/detail/descr.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,5 +174,15 @@ constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
174174
return const_name("{") + descr + const_name("}");
175175
}
176176

177+
template <size_t N, typename... Ts>
178+
constexpr descr<N + 4, Ts...> arg_descr(const descr<N, Ts...> &descr) {
179+
return const_name("@^") + descr + const_name("@^");
180+
}
181+
182+
template <size_t N, typename... Ts>
183+
constexpr descr<N + 4, Ts...> return_descr(const descr<N, Ts...> &descr) {
184+
return const_name("@$") + descr + const_name("@$");
185+
}
186+
177187
PYBIND11_NAMESPACE_END(detail)
178188
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)

include/pybind11/pybind11.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,11 @@ class cpp_function : public function {
440440
std::string signature;
441441
size_t type_index = 0, arg_index = 0;
442442
bool is_starred = false;
443+
// `is_return_value` is true if we are currently inside the return type of the signature.
444+
// The same is true for `use_return_value`, except for forced usage of arg/return type
445+
// using @^/@$.
443446
bool is_return_value = false;
447+
bool use_return_value = false;
444448
for (const auto *pc = text; *pc != '\0'; ++pc) {
445449
const auto c = *pc;
446450

@@ -495,11 +499,21 @@ class cpp_function : public function {
495499
signature += detail::quote_cpp_type_name(detail::clean_type_id(t->name()));
496500
}
497501
} else if (c == '@') {
502+
// `@^ ... @^` and `@$ ... @$` are used to force arg/return value type (see
503+
// typing::Callable/detail::arg_descr/detail::return_descr)
504+
if ((*(pc + 1) == '^' && is_return_value)
505+
|| (*(pc + 1) == '$' && !is_return_value)) {
506+
use_return_value = !use_return_value;
507+
}
508+
if (*(pc + 1) == '^' || *(pc + 1) == '$') {
509+
++pc;
510+
continue;
511+
}
498512
// Handle types that differ depending on whether they appear
499-
// in an argument or a return value position
500-
// For named arguments (py::arg()) with noconvert set, use return value type
513+
// in an argument or a return value position (see io_name<text1, text2>).
514+
// For named arguments (py::arg()) with noconvert set, return value type is used.
501515
++pc;
502-
if (!is_return_value
516+
if (!use_return_value
503517
&& !(arg_index < rec->args.size() && !rec->args[arg_index].convert)) {
504518
while (*pc && *pc != '@')
505519
signature += *pc++;
@@ -518,6 +532,7 @@ class cpp_function : public function {
518532
} else {
519533
if (c == '-' && *(pc + 1) == '>') {
520534
is_return_value = true;
535+
use_return_value = true;
521536
}
522537
signature += c;
523538
}

include/pybind11/typing.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,16 +188,19 @@ template <typename Return, typename... Args>
188188
struct handle_type_name<typing::Callable<Return(Args...)>> {
189189
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
190190
static constexpr auto name
191-
= const_name("Callable[[") + ::pybind11::detail::concat(make_caster<Args>::name...)
192-
+ const_name("], ") + make_caster<retval_type>::name + const_name("]");
191+
= const_name("Callable[[")
192+
+ ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster<Args>::name)...)
193+
+ const_name("], ") + ::pybind11::detail::return_descr(make_caster<retval_type>::name)
194+
+ const_name("]");
193195
};
194196

195197
template <typename Return>
196198
struct handle_type_name<typing::Callable<Return(ellipsis)>> {
197199
// PEP 484 specifies this syntax for defining only return types of callables
198200
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
199-
static constexpr auto name
200-
= const_name("Callable[..., ") + make_caster<retval_type>::name + const_name("]");
201+
static constexpr auto name = const_name("Callable[..., ")
202+
+ ::pybind11::detail::return_descr(make_caster<retval_type>::name)
203+
+ const_name("]");
201204
};
202205

203206
template <typename T>

tests/test_pytypes.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,12 @@ TEST_SUBMODULE(pytypes, m) {
11391139
m.def("identity_iterable", [](const py::typing::Iterable<RealNumber> &x) { return x; });
11401140
// Iterator<T>
11411141
m.def("identity_iterator", [](const py::typing::Iterator<RealNumber> &x) { return x; });
1142+
// Callable<R(A)> identity
1143+
m.def("identity_callable",
1144+
[](const py::typing::Callable<RealNumber(const RealNumber &)> &x) { return x; });
1145+
// Callable<R(...)> identity
1146+
m.def("identity_callable_ellipsis",
1147+
[](const py::typing::Callable<RealNumber(py::ellipsis)> &x) { return x; });
11421148
// Callable<R(A)>
11431149
m.def("apply_callable",
11441150
[](const RealNumber &x, const py::typing::Callable<RealNumber(const RealNumber &)> &f) {

tests/test_pytypes.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,18 +1252,26 @@ def test_arg_return_type_hints(doc):
12521252
doc(m.identity_iterator)
12531253
== "identity_iterator(arg0: Iterator[Union[float, int]]) -> Iterator[float]"
12541254
)
1255+
# Callable<R(A)> identity
1256+
assert (
1257+
doc(m.identity_callable)
1258+
== "identity_callable(arg0: Callable[[Union[float, int]], float]) -> Callable[[Union[float, int]], float]"
1259+
)
1260+
# Callable<R(...)> identity
1261+
assert (
1262+
doc(m.identity_callable_ellipsis)
1263+
== "identity_callable_ellipsis(arg0: Callable[..., float]) -> Callable[..., float]"
1264+
)
12551265
# Callable<R(A)>
1256-
# TODO: Needs support for arg/return environments
1257-
# assert (
1258-
# doc(m.apply_callable)
1259-
# == "apply_callable(arg0: Union[float, int], arg1: Callable[[Union[float, int]], float]) -> float"
1260-
# )
1266+
assert (
1267+
doc(m.apply_callable)
1268+
== "apply_callable(arg0: Union[float, int], arg1: Callable[[Union[float, int]], float]) -> float"
1269+
)
12611270
# Callable<R(...)>
1262-
# TODO: Needs support for arg/return environments
1263-
# assert (
1264-
# doc(m.apply_callable_ellipsis)
1265-
# == "apply_callable_ellipsis(arg0: Union[float, int], arg1: Callable[..., float]) -> float"
1266-
# )
1271+
assert (
1272+
doc(m.apply_callable_ellipsis)
1273+
== "apply_callable_ellipsis(arg0: Union[float, int], arg1: Callable[..., float]) -> float"
1274+
)
12671275
# Union<T1, T2>
12681276
assert (
12691277
doc(m.identity_union)

0 commit comments

Comments
 (0)