Skip to content

Commit 0d44d72

Browse files
authored
Make stl.h list|set|map_caster more user friendly. (#4686)
* Add `test_pass_std_vector_int()`, `test_pass_std_set_int()` in test_stl * Change `list_caster` to also accept generator objects (`PyGen_Check(src.ptr()`). Note for completeness: This is a more conservative change than google/pybind11clif#30042 * Drop in (currently unpublished) PyCLIF code, use in `list_caster`, adjust tests. * Use `PyObjectTypeIsConvertibleToStdSet()` in `set_caster`, adjust tests. * Use `PyObjectTypeIsConvertibleToStdMap()` in `map_caster`, add tests. * Simplify `list_caster` `load()` implementation, push str/bytes check into `PyObjectTypeIsConvertibleToStdVector()`. * clang-tidy cleanup with a few extra `(... != 0)` to be more consistent. * Also use `PyObjectTypeIsConvertibleToStdVector()` in `array_caster`. * Update comment pointing to clif/python/runtime.cc (code is unchanged). * Comprehensive test coverage, enhanced set_caster load implementation. * Resolve clang-tidy eror. * Add a long C++ comment explaining what led to the `PyObjectTypeIsConvertibleTo*()` implementations. * Minor function name change in test. * strcmp -> std::strcmp (thanks @Skylion007 for catching this) * Add `PyCallable_Check(items)` in `PyObjectTypeIsConvertibleToStdMap()` * Resolve clang-tidy error * Use `PyMapping_Items()` instead of `src.attr("items")()`, to be internally consistent with `PyMapping_Check()` * Update link to PyCLIF sources. * Fix typo (thanks @wangxf123456 for catching this) * Add `test_pass_std_vector_int()`, `test_pass_std_set_int()` in test_stl * Change `list_caster` to also accept generator objects (`PyGen_Check(src.ptr()`). Note for completeness: This is a more conservative change than google/pybind11clif#30042 * Drop in (currently unpublished) PyCLIF code, use in `list_caster`, adjust tests. * Use `PyObjectTypeIsConvertibleToStdSet()` in `set_caster`, adjust tests. * Use `PyObjectTypeIsConvertibleToStdMap()` in `map_caster`, add tests. * Simplify `list_caster` `load()` implementation, push str/bytes check into `PyObjectTypeIsConvertibleToStdVector()`. * clang-tidy cleanup with a few extra `(... != 0)` to be more consistent. * Also use `PyObjectTypeIsConvertibleToStdVector()` in `array_caster`. * Update comment pointing to clif/python/runtime.cc (code is unchanged). * Comprehensive test coverage, enhanced set_caster load implementation. * Resolve clang-tidy eror. * Add a long C++ comment explaining what led to the `PyObjectTypeIsConvertibleTo*()` implementations. * Minor function name change in test. * strcmp -> std::strcmp (thanks @Skylion007 for catching this) * Add `PyCallable_Check(items)` in `PyObjectTypeIsConvertibleToStdMap()` * Resolve clang-tidy error * Use `PyMapping_Items()` instead of `src.attr("items")()`, to be internally consistent with `PyMapping_Check()` * Update link to PyCLIF sources. * Fix typo (thanks @wangxf123456 for catching this) * Fix typo discovered by new version of codespell.
1 parent 4a06eca commit 0d44d72

File tree

3 files changed

+337
-33
lines changed

3 files changed

+337
-33
lines changed

include/pybind11/stl.h

Lines changed: 177 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "detail/common.h"
1414

1515
#include <deque>
16+
#include <initializer_list>
1617
#include <list>
1718
#include <map>
1819
#include <ostream>
@@ -35,6 +36,89 @@
3536
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
3637
PYBIND11_NAMESPACE_BEGIN(detail)
3738

39+
//
40+
// Begin: Equivalent of
41+
// https://github.com/google/clif/blob/ae4eee1de07cdf115c0c9bf9fec9ff28efce6f6c/clif/python/runtime.cc#L388-L438
42+
/*
43+
The three `PyObjectTypeIsConvertibleTo*()` functions below are
44+
the result of converging the behaviors of pybind11 and PyCLIF
45+
(http://github.com/google/clif).
46+
47+
Originally PyCLIF was extremely far on the permissive side of the spectrum,
48+
while pybind11 was very far on the strict side. Originally PyCLIF accepted any
49+
Python iterable as input for a C++ `vector`/`set`/`map` argument, as long as
50+
the elements were convertible. The obvious (in hindsight) problem was that
51+
any empty Python iterable could be passed to any of these C++ types, e.g. `{}`
52+
was accepted for C++ `vector`/`set` arguments, or `[]` for C++ `map` arguments.
53+
54+
The functions below strike a practical permissive-vs-strict compromise,
55+
informed by tens of thousands of use cases in the wild. A main objective is
56+
to prevent accidents and improve readability:
57+
58+
- Python literals must match the C++ types.
59+
60+
- For C++ `set`: The potentially reducing conversion from a Python sequence
61+
(e.g. Python `list` or `tuple`) to a C++ `set` must be explicit, by going
62+
through a Python `set`.
63+
64+
- However, a Python `set` can still be passed to a C++ `vector`. The rationale
65+
is that this conversion is not reducing. Implicit conversions of this kind
66+
are also fairly commonly used, therefore enforcing explicit conversions
67+
would have an unfavorable cost : benefit ratio; more sloppily speaking,
68+
such an enforcement would be more annoying than helpful.
69+
*/
70+
71+
inline bool PyObjectIsInstanceWithOneOfTpNames(PyObject *obj,
72+
std::initializer_list<const char *> tp_names) {
73+
if (PyType_Check(obj)) {
74+
return false;
75+
}
76+
const char *obj_tp_name = Py_TYPE(obj)->tp_name;
77+
for (const auto *tp_name : tp_names) {
78+
if (std::strcmp(obj_tp_name, tp_name) == 0) {
79+
return true;
80+
}
81+
}
82+
return false;
83+
}
84+
85+
inline bool PyObjectTypeIsConvertibleToStdVector(PyObject *obj) {
86+
if (PySequence_Check(obj) != 0) {
87+
return !PyUnicode_Check(obj) && !PyBytes_Check(obj);
88+
}
89+
return (PyGen_Check(obj) != 0) || (PyAnySet_Check(obj) != 0)
90+
|| PyObjectIsInstanceWithOneOfTpNames(
91+
obj, {"dict_keys", "dict_values", "dict_items", "map", "zip"});
92+
}
93+
94+
inline bool PyObjectTypeIsConvertibleToStdSet(PyObject *obj) {
95+
return (PyAnySet_Check(obj) != 0) || PyObjectIsInstanceWithOneOfTpNames(obj, {"dict_keys"});
96+
}
97+
98+
inline bool PyObjectTypeIsConvertibleToStdMap(PyObject *obj) {
99+
if (PyDict_Check(obj)) {
100+
return true;
101+
}
102+
// Implicit requirement in the conditions below:
103+
// A type with `.__getitem__()` & `.items()` methods must implement these
104+
// to be compatible with https://docs.python.org/3/c-api/mapping.html
105+
if (PyMapping_Check(obj) == 0) {
106+
return false;
107+
}
108+
PyObject *items = PyObject_GetAttrString(obj, "items");
109+
if (items == nullptr) {
110+
PyErr_Clear();
111+
return false;
112+
}
113+
bool is_convertible = (PyCallable_Check(items) != 0);
114+
Py_DECREF(items);
115+
return is_convertible;
116+
}
117+
118+
//
119+
// End: Equivalent of clif/python/runtime.cc
120+
//
121+
38122
/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for
39123
/// forwarding a container element). Typically used indirect via forwarded_type(), below.
40124
template <typename T, typename U>
@@ -66,24 +150,40 @@ struct set_caster {
66150
}
67151
void reserve_maybe(const anyset &, void *) {}
68152

69-
public:
70-
bool load(handle src, bool convert) {
71-
if (!isinstance<anyset>(src)) {
72-
return false;
73-
}
74-
auto s = reinterpret_borrow<anyset>(src);
75-
value.clear();
76-
reserve_maybe(s, &value);
77-
for (auto entry : s) {
153+
bool convert_iterable(const iterable &itbl, bool convert) {
154+
for (const auto &it : itbl) {
78155
key_conv conv;
79-
if (!conv.load(entry, convert)) {
156+
if (!conv.load(it, convert)) {
80157
return false;
81158
}
82159
value.insert(cast_op<Key &&>(std::move(conv)));
83160
}
84161
return true;
85162
}
86163

164+
bool convert_anyset(anyset s, bool convert) {
165+
value.clear();
166+
reserve_maybe(s, &value);
167+
return convert_iterable(s, convert);
168+
}
169+
170+
public:
171+
bool load(handle src, bool convert) {
172+
if (!PyObjectTypeIsConvertibleToStdSet(src.ptr())) {
173+
return false;
174+
}
175+
if (isinstance<anyset>(src)) {
176+
value.clear();
177+
return convert_anyset(reinterpret_borrow<anyset>(src), convert);
178+
}
179+
if (!convert) {
180+
return false;
181+
}
182+
assert(isinstance<iterable>(src));
183+
value.clear();
184+
return convert_iterable(reinterpret_borrow<iterable>(src), convert);
185+
}
186+
87187
template <typename T>
88188
static handle cast(T &&src, return_value_policy policy, handle parent) {
89189
if (!std::is_lvalue_reference<T>::value) {
@@ -115,15 +215,10 @@ struct map_caster {
115215
}
116216
void reserve_maybe(const dict &, void *) {}
117217

118-
public:
119-
bool load(handle src, bool convert) {
120-
if (!isinstance<dict>(src)) {
121-
return false;
122-
}
123-
auto d = reinterpret_borrow<dict>(src);
218+
bool convert_elements(const dict &d, bool convert) {
124219
value.clear();
125220
reserve_maybe(d, &value);
126-
for (auto it : d) {
221+
for (const auto &it : d) {
127222
key_conv kconv;
128223
value_conv vconv;
129224
if (!kconv.load(it.first.ptr(), convert) || !vconv.load(it.second.ptr(), convert)) {
@@ -134,6 +229,25 @@ struct map_caster {
134229
return true;
135230
}
136231

232+
public:
233+
bool load(handle src, bool convert) {
234+
if (!PyObjectTypeIsConvertibleToStdMap(src.ptr())) {
235+
return false;
236+
}
237+
if (isinstance<dict>(src)) {
238+
return convert_elements(reinterpret_borrow<dict>(src), convert);
239+
}
240+
if (!convert) {
241+
return false;
242+
}
243+
auto items = reinterpret_steal<object>(PyMapping_Items(src.ptr()));
244+
if (!items) {
245+
throw error_already_set();
246+
}
247+
assert(isinstance<iterable>(items));
248+
return convert_elements(dict(reinterpret_borrow<iterable>(items)), convert);
249+
}
250+
137251
template <typename T>
138252
static handle cast(T &&src, return_value_policy policy, handle parent) {
139253
dict d;
@@ -166,13 +280,35 @@ struct list_caster {
166280
using value_conv = make_caster<Value>;
167281

168282
bool load(handle src, bool convert) {
169-
if (!isinstance<sequence>(src) || isinstance<bytes>(src) || isinstance<str>(src)) {
283+
if (!PyObjectTypeIsConvertibleToStdVector(src.ptr())) {
284+
return false;
285+
}
286+
if (isinstance<sequence>(src)) {
287+
return convert_elements(src, convert);
288+
}
289+
if (!convert) {
170290
return false;
171291
}
172-
auto s = reinterpret_borrow<sequence>(src);
292+
// Designed to be behavior-equivalent to passing tuple(src) from Python:
293+
// The conversion to a tuple will first exhaust the generator object, to ensure that
294+
// the generator is not left in an unpredictable (to the caller) partially-consumed
295+
// state.
296+
assert(isinstance<iterable>(src));
297+
return convert_elements(tuple(reinterpret_borrow<iterable>(src)), convert);
298+
}
299+
300+
private:
301+
template <typename T = Type, enable_if_t<has_reserve_method<T>::value, int> = 0>
302+
void reserve_maybe(const sequence &s, Type *) {
303+
value.reserve(s.size());
304+
}
305+
void reserve_maybe(const sequence &, void *) {}
306+
307+
bool convert_elements(handle seq, bool convert) {
308+
auto s = reinterpret_borrow<sequence>(seq);
173309
value.clear();
174310
reserve_maybe(s, &value);
175-
for (const auto &it : s) {
311+
for (const auto &it : seq) {
176312
value_conv conv;
177313
if (!conv.load(it, convert)) {
178314
return false;
@@ -182,13 +318,6 @@ struct list_caster {
182318
return true;
183319
}
184320

185-
private:
186-
template <typename T = Type, enable_if_t<has_reserve_method<T>::value, int> = 0>
187-
void reserve_maybe(const sequence &s, Type *) {
188-
value.reserve(s.size());
189-
}
190-
void reserve_maybe(const sequence &, void *) {}
191-
192321
public:
193322
template <typename T>
194323
static handle cast(T &&src, return_value_policy policy, handle parent) {
@@ -237,12 +366,8 @@ struct array_caster {
237366
return size == Size;
238367
}
239368

240-
public:
241-
bool load(handle src, bool convert) {
242-
if (!isinstance<sequence>(src)) {
243-
return false;
244-
}
245-
auto l = reinterpret_borrow<sequence>(src);
369+
bool convert_elements(handle seq, bool convert) {
370+
auto l = reinterpret_borrow<sequence>(seq);
246371
if (!require_size(l.size())) {
247372
return false;
248373
}
@@ -257,6 +382,25 @@ struct array_caster {
257382
return true;
258383
}
259384

385+
public:
386+
bool load(handle src, bool convert) {
387+
if (!PyObjectTypeIsConvertibleToStdVector(src.ptr())) {
388+
return false;
389+
}
390+
if (isinstance<sequence>(src)) {
391+
return convert_elements(src, convert);
392+
}
393+
if (!convert) {
394+
return false;
395+
}
396+
// Designed to be behavior-equivalent to passing tuple(src) from Python:
397+
// The conversion to a tuple will first exhaust the generator object, to ensure that
398+
// the generator is not left in an unpredictable (to the caller) partially-consumed
399+
// state.
400+
assert(isinstance<iterable>(src));
401+
return convert_elements(tuple(reinterpret_borrow<iterable>(src)), convert);
402+
}
403+
260404
template <typename T>
261405
static handle cast(T &&src, return_value_policy policy, handle parent) {
262406
list l(src.size());

tests/test_stl.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,14 @@ struct type_caster<ReferenceSensitiveOptional<T>>
167167
} // namespace detail
168168
} // namespace PYBIND11_NAMESPACE
169169

170+
int pass_std_vector_int(const std::vector<int> &v) {
171+
int zum = 100;
172+
for (const int i : v) {
173+
zum += 2 * i;
174+
}
175+
return zum;
176+
}
177+
170178
TEST_SUBMODULE(stl, m) {
171179
// test_vector
172180
m.def("cast_vector", []() { return std::vector<int>{1}; });
@@ -546,4 +554,30 @@ TEST_SUBMODULE(stl, m) {
546554
[]() { return new std::vector<bool>(4513); },
547555
// Without explicitly specifying `take_ownership`, this function leaks.
548556
py::return_value_policy::take_ownership);
557+
558+
m.def("pass_std_vector_int", pass_std_vector_int);
559+
m.def("pass_std_vector_pair_int", [](const std::vector<std::pair<int, int>> &v) {
560+
int zum = 0;
561+
for (const auto &ij : v) {
562+
zum += ij.first * 100 + ij.second;
563+
}
564+
return zum;
565+
});
566+
m.def("pass_std_array_int_2", [](const std::array<int, 2> &a) {
567+
return pass_std_vector_int(std::vector<int>(a.begin(), a.end())) + 1;
568+
});
569+
m.def("pass_std_set_int", [](const std::set<int> &s) {
570+
int zum = 200;
571+
for (const int i : s) {
572+
zum += 3 * i;
573+
}
574+
return zum;
575+
});
576+
m.def("pass_std_map_int", [](const std::map<int, int> &m) {
577+
int zum = 500;
578+
for (const auto &p : m) {
579+
zum += p.first * 1000 + p.second;
580+
}
581+
return zum;
582+
});
549583
}

0 commit comments

Comments
 (0)