Skip to content

Commit 9b1786d

Browse files
Restore Django 5.0+ model_field_choices fixer (#542)
Fixes #417. Restore the fixer with a new restriction that we only act when the choices class is defined in the same file and can be reliably detected as subclassing one of Django's enumeration types. Co-authored-by: Thibaut Decombe <[email protected]>
1 parent 964b6fe commit 9b1786d

File tree

4 files changed

+525
-0
lines changed

4 files changed

+525
-0
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ Changelog
66

77
* Add Django 5.2+ fixer ``staticfiles_find_all`` to rewrite calls to the staticfiles ``find()`` function using the old argument name ``all`` to the new name ``find_all``.
88

9+
* Restore Django 5.0+ fixer ``model_field_choices``, with a restriction to only apply when the enumeration type is defined in the same file.
10+
11+
Thanks to Thibaut Decombe for initially contributing it in `PR #369 <https://github.com/adamchainz/django-upgrade/pull/369>`__.
12+
913
* Add Django 2.0+ compatibility import rewrite of `django.http.cookie.SimpleCookie` to `http.cookies.SimpleCookie`.
1014

1115
Thanks to Thibaut Decombe in `PR #537 <https://github.com/adamchainz/django-upgrade/pull/537>`__.

README.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,26 @@ Django 5.0
333333

334334
`Release Notes <https://docs.djangoproject.com/en/5.0/releases/5.0/>`__
335335

336+
Model field enumeration type ``.choices``
337+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
338+
339+
**Name:** ``model_field_choices``
340+
341+
Drop ``.choices`` for model field ``choices`` parameters, where the value is a class defined in the same file.
342+
This change is possible because the ``choices`` parameter now accepts enumeration types directly.
343+
344+
.. code-block:: diff
345+
346+
from django.db import models
347+
348+
class Suit(models.IntegerChoices):
349+
HEARTS = 1
350+
...
351+
352+
class Card(models.Model):
353+
- suit = models.IntegerField(choices=Suit.choices, default=Suit.DEFAULT)
354+
+ suit = models.IntegerField(choices=Suit, default=Suit.DEFAULT)
355+
336356
``format_html()`` calls
337357
~~~~~~~~~~~~~~~~~~~~~~~
338358

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""
2+
Drop `.choices` for model field `choices` parameters:
3+
https://docs.djangoproject.com/en/5.0/releases/5.0/#forms
4+
"""
5+
6+
from __future__ import annotations
7+
8+
import ast
9+
from collections import defaultdict
10+
from collections.abc import Iterable
11+
from functools import partial
12+
from typing import DefaultDict
13+
from typing import cast
14+
from weakref import WeakKeyDictionary
15+
16+
from tokenize_rt import Offset
17+
from tokenize_rt import Token
18+
19+
from django_upgrade.ast import ast_start_offset
20+
from django_upgrade.data import Fixer
21+
from django_upgrade.data import State
22+
from django_upgrade.data import TokenFunc
23+
from django_upgrade.tokens import OP
24+
from django_upgrade.tokens import find_last_token
25+
from django_upgrade.tokens import reverse_find
26+
27+
fixer = Fixer(
28+
__name__,
29+
min_version=(5, 0),
30+
)
31+
32+
# Cache defined enumeration types by module
33+
module_defined_enumeration_types: WeakKeyDictionary[ast.Module, dict[str, int]] = (
34+
WeakKeyDictionary()
35+
)
36+
37+
38+
def defined_enumeration_types(module: ast.Module, up_to_line: int) -> set[str]:
39+
"""
40+
Return a set of enumeration type class names defined in the given module, up to a line number.
41+
"""
42+
if module not in module_defined_enumeration_types:
43+
enum_dict = {}
44+
from_imports: DefaultDict[str, set[str]] = defaultdict(set)
45+
for node in module.body:
46+
if (
47+
isinstance(node, ast.ImportFrom)
48+
and node.level == 0
49+
and node.module is not None
50+
):
51+
from_imports[node.module].update(
52+
name.name
53+
for name in node.names
54+
if name.asname is None and name.name != "*"
55+
)
56+
elif isinstance(node, ast.ClassDef):
57+
# Check if the class inherits from one of Django's choice types
58+
for base in node.bases:
59+
if _is_django_choices_type(from_imports, base):
60+
enum_dict[node.name] = node.lineno
61+
break
62+
module_defined_enumeration_types[module] = enum_dict
63+
64+
return {
65+
name
66+
for name, line in module_defined_enumeration_types[module].items()
67+
if line <= up_to_line
68+
}
69+
70+
71+
DJANGO_CHOICES_TYPES = {
72+
"TextChoices",
73+
"IntegerChoices",
74+
"Choices",
75+
}
76+
77+
78+
def _is_django_choices_type(
79+
from_imports: DefaultDict[str, set[str]], node: ast.expr
80+
) -> bool:
81+
"""Check if an AST node refers to a Django enumeration type base class."""
82+
return (
83+
isinstance(node, ast.Name)
84+
and node.id in DJANGO_CHOICES_TYPES
85+
and (
86+
node.id in from_imports["django.db.models"]
87+
or node.id in from_imports["django.db.models.enums"]
88+
)
89+
) or (
90+
isinstance(node, ast.Attribute)
91+
and node.attr in DJANGO_CHOICES_TYPES
92+
and isinstance(node.value, ast.Name)
93+
and (
94+
(node.value.id == "models" and node.value.id in from_imports["django.db"])
95+
or (
96+
node.value.id == "enums"
97+
and node.value.id in from_imports["django.db.models"]
98+
)
99+
)
100+
)
101+
102+
103+
@fixer.register(ast.Call)
104+
def visit_Call(
105+
state: State,
106+
node: ast.Call,
107+
parents: tuple[ast.AST, ...],
108+
) -> Iterable[tuple[Offset, TokenFunc]]:
109+
if (
110+
state.looks_like_models_file
111+
and (
112+
(
113+
isinstance(node.func, ast.Attribute)
114+
and isinstance(node.func.value, ast.Name)
115+
and node.func.attr.endswith("Field")
116+
)
117+
or (isinstance(node.func, ast.Name) and node.func.id.endswith("Field"))
118+
)
119+
and any(
120+
kw.arg == "choices"
121+
and isinstance(kw.value, ast.Attribute)
122+
and (target_node := kw.value).attr == "choices"
123+
and isinstance(target_node.value, ast.Name)
124+
and (
125+
target_node.value.id
126+
in defined_enumeration_types(
127+
cast(ast.Module, parents[0]),
128+
node.lineno,
129+
)
130+
)
131+
for kw in node.keywords
132+
)
133+
):
134+
yield ast_start_offset(target_node), partial(remove_choices, node=target_node)
135+
136+
137+
def remove_choices(tokens: list[Token], i: int, node: ast.Attribute) -> None:
138+
j = find_last_token(tokens, i, node=node)
139+
i = reverse_find(tokens, j, name=OP, src=".")
140+
del tokens[i : j + 1]

0 commit comments

Comments
 (0)