LCOV - code coverage report
Current view: top level - Objects - unionobject.c (source / functions) Hit Total Coverage
Test: CPython 3.12 LCOV report [commit 5e6661bce9] Lines: 70 192 36.5 %
Date: 2023-03-20 08:15:36 Functions: 9 17 52.9 %
Branches: 36 122 29.5 %

           Branch data     Line data    Source code
       1                 :            : // types.UnionType -- used to represent e.g. Union[int, str], int | str
       2                 :            : #include "Python.h"
       3                 :            : #include "pycore_object.h"  // _PyObject_GC_TRACK/UNTRACK
       4                 :            : #include "pycore_unionobject.h"
       5                 :            : #include "structmember.h"
       6                 :            : 
       7                 :            : 
       8                 :            : static PyObject *make_union(PyObject *);
       9                 :            : 
      10                 :            : 
      11                 :            : typedef struct {
      12                 :            :     PyObject_HEAD
      13                 :            :     PyObject *args;
      14                 :            :     PyObject *parameters;
      15                 :            : } unionobject;
      16                 :            : 
      17                 :            : static void
      18                 :         16 : unionobject_dealloc(PyObject *self)
      19                 :            : {
      20                 :         16 :     unionobject *alias = (unionobject *)self;
      21                 :            : 
      22                 :         16 :     _PyObject_GC_UNTRACK(self);
      23                 :            : 
      24                 :         16 :     Py_XDECREF(alias->args);
      25                 :         16 :     Py_XDECREF(alias->parameters);
      26                 :         16 :     Py_TYPE(self)->tp_free(self);
      27                 :         16 : }
      28                 :            : 
      29                 :            : static int
      30                 :        128 : union_traverse(PyObject *self, visitproc visit, void *arg)
      31                 :            : {
      32                 :        128 :     unionobject *alias = (unionobject *)self;
      33   [ +  -  -  + ]:        128 :     Py_VISIT(alias->args);
      34   [ -  +  -  - ]:        128 :     Py_VISIT(alias->parameters);
      35                 :        128 :     return 0;
      36                 :            : }
      37                 :            : 
      38                 :            : static Py_hash_t
      39                 :          0 : union_hash(PyObject *self)
      40                 :            : {
      41                 :          0 :     unionobject *alias = (unionobject *)self;
      42                 :          0 :     PyObject *args = PyFrozenSet_New(alias->args);
      43         [ #  # ]:          0 :     if (args == NULL) {
      44                 :          0 :         return (Py_hash_t)-1;
      45                 :            :     }
      46                 :          0 :     Py_hash_t hash = PyObject_Hash(args);
      47                 :          0 :     Py_DECREF(args);
      48                 :          0 :     return hash;
      49                 :            : }
      50                 :            : 
      51                 :            : static PyObject *
      52                 :          0 : union_richcompare(PyObject *a, PyObject *b, int op)
      53                 :            : {
      54   [ #  #  #  #  :          0 :     if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) {
                   #  # ]
      55                 :          0 :         Py_RETURN_NOTIMPLEMENTED;
      56                 :            :     }
      57                 :            : 
      58                 :          0 :     PyObject *a_set = PySet_New(((unionobject*)a)->args);
      59         [ #  # ]:          0 :     if (a_set == NULL) {
      60                 :          0 :         return NULL;
      61                 :            :     }
      62                 :          0 :     PyObject *b_set = PySet_New(((unionobject*)b)->args);
      63         [ #  # ]:          0 :     if (b_set == NULL) {
      64                 :          0 :         Py_DECREF(a_set);
      65                 :          0 :         return NULL;
      66                 :            :     }
      67                 :          0 :     PyObject *result = PyObject_RichCompare(a_set, b_set, op);
      68                 :          0 :     Py_DECREF(b_set);
      69                 :          0 :     Py_DECREF(a_set);
      70                 :          0 :     return result;
      71                 :            : }
      72                 :            : 
      73                 :            : static int
      74                 :         17 : is_same(PyObject *left, PyObject *right)
      75                 :            : {
      76   [ +  +  -  + ]:         17 :     int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
      77         [ -  + ]:         17 :     return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
      78                 :            : }
      79                 :            : 
      80                 :            : static int
      81                 :         16 : contains(PyObject **items, Py_ssize_t size, PyObject *obj)
      82                 :            : {
      83         [ +  + ]:         33 :     for (int i = 0; i < size; i++) {
      84                 :         17 :         int is_duplicate = is_same(items[i], obj);
      85         [ -  + ]:         17 :         if (is_duplicate) {  // -1 or 1
      86                 :          0 :             return is_duplicate;
      87                 :            :         }
      88                 :            :     }
      89                 :         16 :     return 0;
      90                 :            : }
      91                 :            : 
      92                 :            : static PyObject *
      93                 :         16 : merge(PyObject **items1, Py_ssize_t size1,
      94                 :            :       PyObject **items2, Py_ssize_t size2)
      95                 :            : {
      96                 :         16 :     PyObject *tuple = NULL;
      97                 :         16 :     Py_ssize_t pos = 0;
      98                 :            : 
      99         [ +  + ]:         32 :     for (int i = 0; i < size2; i++) {
     100                 :         16 :         PyObject *arg = items2[i];
     101                 :         16 :         int is_duplicate = contains(items1, size1, arg);
     102         [ -  + ]:         16 :         if (is_duplicate < 0) {
     103                 :          0 :             Py_XDECREF(tuple);
     104                 :          0 :             return NULL;
     105                 :            :         }
     106         [ -  + ]:         16 :         if (is_duplicate) {
     107                 :          0 :             continue;
     108                 :            :         }
     109                 :            : 
     110         [ +  - ]:         16 :         if (tuple == NULL) {
     111                 :         16 :             tuple = PyTuple_New(size1 + size2 - i);
     112         [ -  + ]:         16 :             if (tuple == NULL) {
     113                 :          0 :                 return NULL;
     114                 :            :             }
     115         [ +  + ]:         33 :             for (; pos < size1; pos++) {
     116                 :         17 :                 PyObject *a = items1[pos];
     117                 :         17 :                 PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a));
     118                 :            :             }
     119                 :            :         }
     120                 :         16 :         PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg));
     121                 :         16 :         pos++;
     122                 :            :     }
     123                 :            : 
     124         [ +  - ]:         16 :     if (tuple) {
     125                 :         16 :         (void) _PyTuple_Resize(&tuple, pos);
     126                 :            :     }
     127                 :         16 :     return tuple;
     128                 :            : }
     129                 :            : 
     130                 :            : static PyObject **
     131                 :         32 : get_types(PyObject **obj, Py_ssize_t *size)
     132                 :            : {
     133         [ +  + ]:         32 :     if (*obj == Py_None) {
     134                 :          4 :         *obj = (PyObject *)&_PyNone_Type;
     135                 :            :     }
     136         [ +  + ]:         32 :     if (_PyUnion_Check(*obj)) {
     137                 :          1 :         PyObject *args = ((unionobject *) *obj)->args;
     138                 :          1 :         *size = PyTuple_GET_SIZE(args);
     139                 :          1 :         return &PyTuple_GET_ITEM(args, 0);
     140                 :            :     }
     141                 :            :     else {
     142                 :         31 :         *size = 1;
     143                 :         31 :         return obj;
     144                 :            :     }
     145                 :            : }
     146                 :            : 
     147                 :            : static int
     148                 :         38 : is_unionable(PyObject *obj)
     149                 :            : {
     150         [ +  + ]:         34 :     return (obj == Py_None ||
     151         [ +  + ]:         45 :         PyType_Check(obj) ||
     152   [ +  +  +  + ]:         83 :         _PyGenericAlias_Check(obj) ||
     153                 :          4 :         _PyUnion_Check(obj));
     154                 :            : }
     155                 :            : 
     156                 :            : PyObject *
     157                 :         19 : _Py_union_type_or(PyObject* self, PyObject* other)
     158                 :            : {
     159   [ +  -  +  + ]:         19 :     if (!is_unionable(self) || !is_unionable(other)) {
     160                 :          3 :         Py_RETURN_NOTIMPLEMENTED;
     161                 :            :     }
     162                 :            : 
     163                 :            :     Py_ssize_t size1, size2;
     164                 :         16 :     PyObject **items1 = get_types(&self, &size1);
     165                 :         16 :     PyObject **items2 = get_types(&other, &size2);
     166                 :         16 :     PyObject *tuple = merge(items1, size1, items2, size2);
     167         [ -  + ]:         16 :     if (tuple == NULL) {
     168         [ #  # ]:          0 :         if (PyErr_Occurred()) {
     169                 :          0 :             return NULL;
     170                 :            :         }
     171                 :          0 :         return Py_NewRef(self);
     172                 :            :     }
     173                 :            : 
     174                 :         16 :     PyObject *new_union = make_union(tuple);
     175                 :         16 :     Py_DECREF(tuple);
     176                 :         16 :     return new_union;
     177                 :            : }
     178                 :            : 
     179                 :            : static int
     180                 :          0 : union_repr_item(_PyUnicodeWriter *writer, PyObject *p)
     181                 :            : {
     182                 :          0 :     PyObject *qualname = NULL;
     183                 :          0 :     PyObject *module = NULL;
     184                 :            :     PyObject *tmp;
     185                 :          0 :     PyObject *r = NULL;
     186                 :            :     int err;
     187                 :            : 
     188         [ #  # ]:          0 :     if (p == (PyObject *)&_PyNone_Type) {
     189                 :          0 :         return _PyUnicodeWriter_WriteASCIIString(writer, "None", 4);
     190                 :            :     }
     191                 :            : 
     192         [ #  # ]:          0 :     if (_PyObject_LookupAttr(p, &_Py_ID(__origin__), &tmp) < 0) {
     193                 :          0 :         goto exit;
     194                 :            :     }
     195                 :            : 
     196         [ #  # ]:          0 :     if (tmp) {
     197                 :          0 :         Py_DECREF(tmp);
     198         [ #  # ]:          0 :         if (_PyObject_LookupAttr(p, &_Py_ID(__args__), &tmp) < 0) {
     199                 :          0 :             goto exit;
     200                 :            :         }
     201         [ #  # ]:          0 :         if (tmp) {
     202                 :            :             // It looks like a GenericAlias
     203                 :          0 :             Py_DECREF(tmp);
     204                 :          0 :             goto use_repr;
     205                 :            :         }
     206                 :            :     }
     207                 :            : 
     208         [ #  # ]:          0 :     if (_PyObject_LookupAttr(p, &_Py_ID(__qualname__), &qualname) < 0) {
     209                 :          0 :         goto exit;
     210                 :            :     }
     211         [ #  # ]:          0 :     if (qualname == NULL) {
     212                 :          0 :         goto use_repr;
     213                 :            :     }
     214         [ #  # ]:          0 :     if (_PyObject_LookupAttr(p, &_Py_ID(__module__), &module) < 0) {
     215                 :          0 :         goto exit;
     216                 :            :     }
     217   [ #  #  #  # ]:          0 :     if (module == NULL || module == Py_None) {
     218                 :          0 :         goto use_repr;
     219                 :            :     }
     220                 :            : 
     221                 :            :     // Looks like a class
     222   [ #  #  #  # ]:          0 :     if (PyUnicode_Check(module) &&
     223                 :          0 :         _PyUnicode_EqualToASCIIString(module, "builtins"))
     224                 :            :     {
     225                 :            :         // builtins don't need a module name
     226                 :          0 :         r = PyObject_Str(qualname);
     227                 :          0 :         goto exit;
     228                 :            :     }
     229                 :            :     else {
     230                 :          0 :         r = PyUnicode_FromFormat("%S.%S", module, qualname);
     231                 :          0 :         goto exit;
     232                 :            :     }
     233                 :            : 
     234                 :          0 : use_repr:
     235                 :          0 :     r = PyObject_Repr(p);
     236                 :          0 : exit:
     237                 :          0 :     Py_XDECREF(qualname);
     238                 :          0 :     Py_XDECREF(module);
     239         [ #  # ]:          0 :     if (r == NULL) {
     240                 :          0 :         return -1;
     241                 :            :     }
     242                 :          0 :     err = _PyUnicodeWriter_WriteStr(writer, r);
     243                 :          0 :     Py_DECREF(r);
     244                 :          0 :     return err;
     245                 :            : }
     246                 :            : 
     247                 :            : static PyObject *
     248                 :          0 : union_repr(PyObject *self)
     249                 :            : {
     250                 :          0 :     unionobject *alias = (unionobject *)self;
     251                 :          0 :     Py_ssize_t len = PyTuple_GET_SIZE(alias->args);
     252                 :            : 
     253                 :            :     _PyUnicodeWriter writer;
     254                 :          0 :     _PyUnicodeWriter_Init(&writer);
     255         [ #  # ]:          0 :      for (Py_ssize_t i = 0; i < len; i++) {
     256   [ #  #  #  # ]:          0 :         if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " | ", 3) < 0) {
     257                 :          0 :             goto error;
     258                 :            :         }
     259                 :          0 :         PyObject *p = PyTuple_GET_ITEM(alias->args, i);
     260         [ #  # ]:          0 :         if (union_repr_item(&writer, p) < 0) {
     261                 :          0 :             goto error;
     262                 :            :         }
     263                 :            :     }
     264                 :          0 :     return _PyUnicodeWriter_Finish(&writer);
     265                 :          0 : error:
     266                 :          0 :     _PyUnicodeWriter_Dealloc(&writer);
     267                 :          0 :     return NULL;
     268                 :            : }
     269                 :            : 
     270                 :            : static PyMemberDef union_members[] = {
     271                 :            :         {"__args__", T_OBJECT, offsetof(unionobject, args), READONLY},
     272                 :            :         {0}
     273                 :            : };
     274                 :            : 
     275                 :            : static PyObject *
     276                 :          0 : union_getitem(PyObject *self, PyObject *item)
     277                 :            : {
     278                 :          0 :     unionobject *alias = (unionobject *)self;
     279                 :            :     // Populate __parameters__ if needed.
     280         [ #  # ]:          0 :     if (alias->parameters == NULL) {
     281                 :          0 :         alias->parameters = _Py_make_parameters(alias->args);
     282         [ #  # ]:          0 :         if (alias->parameters == NULL) {
     283                 :          0 :             return NULL;
     284                 :            :         }
     285                 :            :     }
     286                 :            : 
     287                 :          0 :     PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
     288         [ #  # ]:          0 :     if (newargs == NULL) {
     289                 :          0 :         return NULL;
     290                 :            :     }
     291                 :            : 
     292                 :            :     PyObject *res;
     293                 :          0 :     Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
     294         [ #  # ]:          0 :     if (nargs == 0) {
     295                 :          0 :         res = make_union(newargs);
     296                 :            :     }
     297                 :            :     else {
     298                 :          0 :         res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0));
     299         [ #  # ]:          0 :         for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
     300                 :          0 :             PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
     301                 :          0 :             Py_SETREF(res, PyNumber_Or(res, arg));
     302         [ #  # ]:          0 :             if (res == NULL) {
     303                 :          0 :                 break;
     304                 :            :             }
     305                 :            :         }
     306                 :            :     }
     307                 :          0 :     Py_DECREF(newargs);
     308                 :          0 :     return res;
     309                 :            : }
     310                 :            : 
     311                 :            : static PyMappingMethods union_as_mapping = {
     312                 :            :     .mp_subscript = union_getitem,
     313                 :            : };
     314                 :            : 
     315                 :            : static PyObject *
     316                 :          0 : union_parameters(PyObject *self, void *Py_UNUSED(unused))
     317                 :            : {
     318                 :          0 :     unionobject *alias = (unionobject *)self;
     319         [ #  # ]:          0 :     if (alias->parameters == NULL) {
     320                 :          0 :         alias->parameters = _Py_make_parameters(alias->args);
     321         [ #  # ]:          0 :         if (alias->parameters == NULL) {
     322                 :          0 :             return NULL;
     323                 :            :         }
     324                 :            :     }
     325                 :          0 :     return Py_NewRef(alias->parameters);
     326                 :            : }
     327                 :            : 
     328                 :            : static PyGetSetDef union_properties[] = {
     329                 :            :     {"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.UnionType.", NULL},
     330                 :            :     {0}
     331                 :            : };
     332                 :            : 
     333                 :            : static PyNumberMethods union_as_number = {
     334                 :            :         .nb_or = _Py_union_type_or, // Add __or__ function
     335                 :            : };
     336                 :            : 
     337                 :            : static const char* const cls_attrs[] = {
     338                 :            :         "__module__",  // Required for compatibility with typing module
     339                 :            :         NULL,
     340                 :            : };
     341                 :            : 
     342                 :            : static PyObject *
     343                 :          0 : union_getattro(PyObject *self, PyObject *name)
     344                 :            : {
     345                 :          0 :     unionobject *alias = (unionobject *)self;
     346         [ #  # ]:          0 :     if (PyUnicode_Check(name)) {
     347                 :          0 :         for (const char * const *p = cls_attrs; ; p++) {
     348         [ #  # ]:          0 :             if (*p == NULL) {
     349                 :          0 :                 break;
     350                 :            :             }
     351         [ #  # ]:          0 :             if (_PyUnicode_EqualToASCIIString(name, *p)) {
     352                 :          0 :                 return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name);
     353                 :            :             }
     354                 :            :         }
     355                 :            :     }
     356                 :          0 :     return PyObject_GenericGetAttr(self, name);
     357                 :            : }
     358                 :            : 
     359                 :            : PyObject *
     360                 :          0 : _Py_union_args(PyObject *self)
     361                 :            : {
     362                 :            :     assert(_PyUnion_Check(self));
     363                 :          0 :     return ((unionobject *) self)->args;
     364                 :            : }
     365                 :            : 
     366                 :            : PyTypeObject _PyUnion_Type = {
     367                 :            :     PyVarObject_HEAD_INIT(&PyType_Type, 0)
     368                 :            :     .tp_name = "types.UnionType",
     369                 :            :     .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
     370                 :            :               "\n"
     371                 :            :               "E.g. for int | str"),
     372                 :            :     .tp_basicsize = sizeof(unionobject),
     373                 :            :     .tp_dealloc = unionobject_dealloc,
     374                 :            :     .tp_alloc = PyType_GenericAlloc,
     375                 :            :     .tp_free = PyObject_GC_Del,
     376                 :            :     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
     377                 :            :     .tp_traverse = union_traverse,
     378                 :            :     .tp_hash = union_hash,
     379                 :            :     .tp_getattro = union_getattro,
     380                 :            :     .tp_members = union_members,
     381                 :            :     .tp_richcompare = union_richcompare,
     382                 :            :     .tp_as_mapping = &union_as_mapping,
     383                 :            :     .tp_as_number = &union_as_number,
     384                 :            :     .tp_repr = union_repr,
     385                 :            :     .tp_getset = union_properties,
     386                 :            : };
     387                 :            : 
     388                 :            : static PyObject *
     389                 :         16 : make_union(PyObject *args)
     390                 :            : {
     391                 :            :     assert(PyTuple_CheckExact(args));
     392                 :            : 
     393                 :         16 :     unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
     394         [ -  + ]:         16 :     if (result == NULL) {
     395                 :          0 :         return NULL;
     396                 :            :     }
     397                 :            : 
     398                 :         16 :     result->parameters = NULL;
     399                 :         16 :     result->args = Py_NewRef(args);
     400                 :         16 :     _PyObject_GC_TRACK(result);
     401                 :         16 :     return (PyObject*)result;
     402                 :            : }

Generated by: LCOV version 1.14