Branch data Line data Source code
1 : : // types.UnionType -- used to represent e.g. Union[int, str], int | str
2 : : #include "Python.h"
3 : : #include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK
4 : : #include "pycore_unionobject.h"
5 : : #include "structmember.h"
6 : :
7 : :
8 : : static PyObject *make_union(PyObject *);
9 : :
10 : :
11 : : typedef struct {
12 : : PyObject_HEAD
13 : : PyObject *args;
14 : : PyObject *parameters;
15 : : } unionobject;
16 : :
17 : : static void
18 : 16 : unionobject_dealloc(PyObject *self)
19 : : {
20 : 16 : unionobject *alias = (unionobject *)self;
21 : :
22 : 16 : _PyObject_GC_UNTRACK(self);
23 : :
24 : 16 : Py_XDECREF(alias->args);
25 : 16 : Py_XDECREF(alias->parameters);
26 : 16 : Py_TYPE(self)->tp_free(self);
27 : 16 : }
28 : :
29 : : static int
30 : 128 : union_traverse(PyObject *self, visitproc visit, void *arg)
31 : : {
32 : 128 : unionobject *alias = (unionobject *)self;
33 [ + - - + ]: 128 : Py_VISIT(alias->args);
34 [ - + - - ]: 128 : Py_VISIT(alias->parameters);
35 : 128 : return 0;
36 : : }
37 : :
38 : : static Py_hash_t
39 : 0 : union_hash(PyObject *self)
40 : : {
41 : 0 : unionobject *alias = (unionobject *)self;
42 : 0 : PyObject *args = PyFrozenSet_New(alias->args);
43 [ # # ]: 0 : if (args == NULL) {
44 : 0 : return (Py_hash_t)-1;
45 : : }
46 : 0 : Py_hash_t hash = PyObject_Hash(args);
47 : 0 : Py_DECREF(args);
48 : 0 : return hash;
49 : : }
50 : :
51 : : static PyObject *
52 : 0 : union_richcompare(PyObject *a, PyObject *b, int op)
53 : : {
54 [ # # # # : 0 : if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) {
# # ]
55 : 0 : Py_RETURN_NOTIMPLEMENTED;
56 : : }
57 : :
58 : 0 : PyObject *a_set = PySet_New(((unionobject*)a)->args);
59 [ # # ]: 0 : if (a_set == NULL) {
60 : 0 : return NULL;
61 : : }
62 : 0 : PyObject *b_set = PySet_New(((unionobject*)b)->args);
63 [ # # ]: 0 : if (b_set == NULL) {
64 : 0 : Py_DECREF(a_set);
65 : 0 : return NULL;
66 : : }
67 : 0 : PyObject *result = PyObject_RichCompare(a_set, b_set, op);
68 : 0 : Py_DECREF(b_set);
69 : 0 : Py_DECREF(a_set);
70 : 0 : return result;
71 : : }
72 : :
73 : : static int
74 : 17 : is_same(PyObject *left, PyObject *right)
75 : : {
76 [ + + - + ]: 17 : int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
77 [ - + ]: 17 : return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
78 : : }
79 : :
80 : : static int
81 : 16 : contains(PyObject **items, Py_ssize_t size, PyObject *obj)
82 : : {
83 [ + + ]: 33 : for (int i = 0; i < size; i++) {
84 : 17 : int is_duplicate = is_same(items[i], obj);
85 [ - + ]: 17 : if (is_duplicate) { // -1 or 1
86 : 0 : return is_duplicate;
87 : : }
88 : : }
89 : 16 : return 0;
90 : : }
91 : :
92 : : static PyObject *
93 : 16 : merge(PyObject **items1, Py_ssize_t size1,
94 : : PyObject **items2, Py_ssize_t size2)
95 : : {
96 : 16 : PyObject *tuple = NULL;
97 : 16 : Py_ssize_t pos = 0;
98 : :
99 [ + + ]: 32 : for (int i = 0; i < size2; i++) {
100 : 16 : PyObject *arg = items2[i];
101 : 16 : int is_duplicate = contains(items1, size1, arg);
102 [ - + ]: 16 : if (is_duplicate < 0) {
103 : 0 : Py_XDECREF(tuple);
104 : 0 : return NULL;
105 : : }
106 [ - + ]: 16 : if (is_duplicate) {
107 : 0 : continue;
108 : : }
109 : :
110 [ + - ]: 16 : if (tuple == NULL) {
111 : 16 : tuple = PyTuple_New(size1 + size2 - i);
112 [ - + ]: 16 : if (tuple == NULL) {
113 : 0 : return NULL;
114 : : }
115 [ + + ]: 33 : for (; pos < size1; pos++) {
116 : 17 : PyObject *a = items1[pos];
117 : 17 : PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a));
118 : : }
119 : : }
120 : 16 : PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg));
121 : 16 : pos++;
122 : : }
123 : :
124 [ + - ]: 16 : if (tuple) {
125 : 16 : (void) _PyTuple_Resize(&tuple, pos);
126 : : }
127 : 16 : return tuple;
128 : : }
129 : :
130 : : static PyObject **
131 : 32 : get_types(PyObject **obj, Py_ssize_t *size)
132 : : {
133 [ + + ]: 32 : if (*obj == Py_None) {
134 : 4 : *obj = (PyObject *)&_PyNone_Type;
135 : : }
136 [ + + ]: 32 : if (_PyUnion_Check(*obj)) {
137 : 1 : PyObject *args = ((unionobject *) *obj)->args;
138 : 1 : *size = PyTuple_GET_SIZE(args);
139 : 1 : return &PyTuple_GET_ITEM(args, 0);
140 : : }
141 : : else {
142 : 31 : *size = 1;
143 : 31 : return obj;
144 : : }
145 : : }
146 : :
147 : : static int
148 : 38 : is_unionable(PyObject *obj)
149 : : {
150 [ + + ]: 34 : return (obj == Py_None ||
151 [ + + ]: 45 : PyType_Check(obj) ||
152 [ + + + + ]: 83 : _PyGenericAlias_Check(obj) ||
153 : 4 : _PyUnion_Check(obj));
154 : : }
155 : :
156 : : PyObject *
157 : 19 : _Py_union_type_or(PyObject* self, PyObject* other)
158 : : {
159 [ + - + + ]: 19 : if (!is_unionable(self) || !is_unionable(other)) {
160 : 3 : Py_RETURN_NOTIMPLEMENTED;
161 : : }
162 : :
163 : : Py_ssize_t size1, size2;
164 : 16 : PyObject **items1 = get_types(&self, &size1);
165 : 16 : PyObject **items2 = get_types(&other, &size2);
166 : 16 : PyObject *tuple = merge(items1, size1, items2, size2);
167 [ - + ]: 16 : if (tuple == NULL) {
168 [ # # ]: 0 : if (PyErr_Occurred()) {
169 : 0 : return NULL;
170 : : }
171 : 0 : return Py_NewRef(self);
172 : : }
173 : :
174 : 16 : PyObject *new_union = make_union(tuple);
175 : 16 : Py_DECREF(tuple);
176 : 16 : return new_union;
177 : : }
178 : :
179 : : static int
180 : 0 : union_repr_item(_PyUnicodeWriter *writer, PyObject *p)
181 : : {
182 : 0 : PyObject *qualname = NULL;
183 : 0 : PyObject *module = NULL;
184 : : PyObject *tmp;
185 : 0 : PyObject *r = NULL;
186 : : int err;
187 : :
188 [ # # ]: 0 : if (p == (PyObject *)&_PyNone_Type) {
189 : 0 : return _PyUnicodeWriter_WriteASCIIString(writer, "None", 4);
190 : : }
191 : :
192 [ # # ]: 0 : if (_PyObject_LookupAttr(p, &_Py_ID(__origin__), &tmp) < 0) {
193 : 0 : goto exit;
194 : : }
195 : :
196 [ # # ]: 0 : if (tmp) {
197 : 0 : Py_DECREF(tmp);
198 [ # # ]: 0 : if (_PyObject_LookupAttr(p, &_Py_ID(__args__), &tmp) < 0) {
199 : 0 : goto exit;
200 : : }
201 [ # # ]: 0 : if (tmp) {
202 : : // It looks like a GenericAlias
203 : 0 : Py_DECREF(tmp);
204 : 0 : goto use_repr;
205 : : }
206 : : }
207 : :
208 [ # # ]: 0 : if (_PyObject_LookupAttr(p, &_Py_ID(__qualname__), &qualname) < 0) {
209 : 0 : goto exit;
210 : : }
211 [ # # ]: 0 : if (qualname == NULL) {
212 : 0 : goto use_repr;
213 : : }
214 [ # # ]: 0 : if (_PyObject_LookupAttr(p, &_Py_ID(__module__), &module) < 0) {
215 : 0 : goto exit;
216 : : }
217 [ # # # # ]: 0 : if (module == NULL || module == Py_None) {
218 : 0 : goto use_repr;
219 : : }
220 : :
221 : : // Looks like a class
222 [ # # # # ]: 0 : if (PyUnicode_Check(module) &&
223 : 0 : _PyUnicode_EqualToASCIIString(module, "builtins"))
224 : : {
225 : : // builtins don't need a module name
226 : 0 : r = PyObject_Str(qualname);
227 : 0 : goto exit;
228 : : }
229 : : else {
230 : 0 : r = PyUnicode_FromFormat("%S.%S", module, qualname);
231 : 0 : goto exit;
232 : : }
233 : :
234 : 0 : use_repr:
235 : 0 : r = PyObject_Repr(p);
236 : 0 : exit:
237 : 0 : Py_XDECREF(qualname);
238 : 0 : Py_XDECREF(module);
239 [ # # ]: 0 : if (r == NULL) {
240 : 0 : return -1;
241 : : }
242 : 0 : err = _PyUnicodeWriter_WriteStr(writer, r);
243 : 0 : Py_DECREF(r);
244 : 0 : return err;
245 : : }
246 : :
247 : : static PyObject *
248 : 0 : union_repr(PyObject *self)
249 : : {
250 : 0 : unionobject *alias = (unionobject *)self;
251 : 0 : Py_ssize_t len = PyTuple_GET_SIZE(alias->args);
252 : :
253 : : _PyUnicodeWriter writer;
254 : 0 : _PyUnicodeWriter_Init(&writer);
255 [ # # ]: 0 : for (Py_ssize_t i = 0; i < len; i++) {
256 [ # # # # ]: 0 : if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " | ", 3) < 0) {
257 : 0 : goto error;
258 : : }
259 : 0 : PyObject *p = PyTuple_GET_ITEM(alias->args, i);
260 [ # # ]: 0 : if (union_repr_item(&writer, p) < 0) {
261 : 0 : goto error;
262 : : }
263 : : }
264 : 0 : return _PyUnicodeWriter_Finish(&writer);
265 : 0 : error:
266 : 0 : _PyUnicodeWriter_Dealloc(&writer);
267 : 0 : return NULL;
268 : : }
269 : :
270 : : static PyMemberDef union_members[] = {
271 : : {"__args__", T_OBJECT, offsetof(unionobject, args), READONLY},
272 : : {0}
273 : : };
274 : :
275 : : static PyObject *
276 : 0 : union_getitem(PyObject *self, PyObject *item)
277 : : {
278 : 0 : unionobject *alias = (unionobject *)self;
279 : : // Populate __parameters__ if needed.
280 [ # # ]: 0 : if (alias->parameters == NULL) {
281 : 0 : alias->parameters = _Py_make_parameters(alias->args);
282 [ # # ]: 0 : if (alias->parameters == NULL) {
283 : 0 : return NULL;
284 : : }
285 : : }
286 : :
287 : 0 : PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
288 [ # # ]: 0 : if (newargs == NULL) {
289 : 0 : return NULL;
290 : : }
291 : :
292 : : PyObject *res;
293 : 0 : Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
294 [ # # ]: 0 : if (nargs == 0) {
295 : 0 : res = make_union(newargs);
296 : : }
297 : : else {
298 : 0 : res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0));
299 [ # # ]: 0 : for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
300 : 0 : PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
301 : 0 : Py_SETREF(res, PyNumber_Or(res, arg));
302 [ # # ]: 0 : if (res == NULL) {
303 : 0 : break;
304 : : }
305 : : }
306 : : }
307 : 0 : Py_DECREF(newargs);
308 : 0 : return res;
309 : : }
310 : :
311 : : static PyMappingMethods union_as_mapping = {
312 : : .mp_subscript = union_getitem,
313 : : };
314 : :
315 : : static PyObject *
316 : 0 : union_parameters(PyObject *self, void *Py_UNUSED(unused))
317 : : {
318 : 0 : unionobject *alias = (unionobject *)self;
319 [ # # ]: 0 : if (alias->parameters == NULL) {
320 : 0 : alias->parameters = _Py_make_parameters(alias->args);
321 [ # # ]: 0 : if (alias->parameters == NULL) {
322 : 0 : return NULL;
323 : : }
324 : : }
325 : 0 : return Py_NewRef(alias->parameters);
326 : : }
327 : :
328 : : static PyGetSetDef union_properties[] = {
329 : : {"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.UnionType.", NULL},
330 : : {0}
331 : : };
332 : :
333 : : static PyNumberMethods union_as_number = {
334 : : .nb_or = _Py_union_type_or, // Add __or__ function
335 : : };
336 : :
337 : : static const char* const cls_attrs[] = {
338 : : "__module__", // Required for compatibility with typing module
339 : : NULL,
340 : : };
341 : :
342 : : static PyObject *
343 : 0 : union_getattro(PyObject *self, PyObject *name)
344 : : {
345 : 0 : unionobject *alias = (unionobject *)self;
346 [ # # ]: 0 : if (PyUnicode_Check(name)) {
347 : 0 : for (const char * const *p = cls_attrs; ; p++) {
348 [ # # ]: 0 : if (*p == NULL) {
349 : 0 : break;
350 : : }
351 [ # # ]: 0 : if (_PyUnicode_EqualToASCIIString(name, *p)) {
352 : 0 : return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name);
353 : : }
354 : : }
355 : : }
356 : 0 : return PyObject_GenericGetAttr(self, name);
357 : : }
358 : :
359 : : PyObject *
360 : 0 : _Py_union_args(PyObject *self)
361 : : {
362 : : assert(_PyUnion_Check(self));
363 : 0 : return ((unionobject *) self)->args;
364 : : }
365 : :
366 : : PyTypeObject _PyUnion_Type = {
367 : : PyVarObject_HEAD_INIT(&PyType_Type, 0)
368 : : .tp_name = "types.UnionType",
369 : : .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
370 : : "\n"
371 : : "E.g. for int | str"),
372 : : .tp_basicsize = sizeof(unionobject),
373 : : .tp_dealloc = unionobject_dealloc,
374 : : .tp_alloc = PyType_GenericAlloc,
375 : : .tp_free = PyObject_GC_Del,
376 : : .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
377 : : .tp_traverse = union_traverse,
378 : : .tp_hash = union_hash,
379 : : .tp_getattro = union_getattro,
380 : : .tp_members = union_members,
381 : : .tp_richcompare = union_richcompare,
382 : : .tp_as_mapping = &union_as_mapping,
383 : : .tp_as_number = &union_as_number,
384 : : .tp_repr = union_repr,
385 : : .tp_getset = union_properties,
386 : : };
387 : :
388 : : static PyObject *
389 : 16 : make_union(PyObject *args)
390 : : {
391 : : assert(PyTuple_CheckExact(args));
392 : :
393 : 16 : unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
394 [ - + ]: 16 : if (result == NULL) {
395 : 0 : return NULL;
396 : : }
397 : :
398 : 16 : result->parameters = NULL;
399 : 16 : result->args = Py_NewRef(args);
400 : 16 : _PyObject_GC_TRACK(result);
401 : 16 : return (PyObject*)result;
402 : : }
|