diff --git a/pyproject.toml b/pyproject.toml index 50dd8ce..1759099 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ rest = ["djangorestframework>=3.9,<4.0"] [dependency-groups] dev = [ + "doc8>=1.1.2", "beautifulsoup4>=4.13.3", "coverage>=7.6.12", "darglint>=1.8.1", @@ -104,7 +105,6 @@ dev = [ "typing-extensions>=4.12.2", ] docs = [ - "doc8>=1.1.2", "docutils>=0.21.2", "furo>=2024.8.6", "readme-renderer[md]>=44.0", diff --git a/src/django_enum/fields.py b/src/django_enum/fields.py index 2f249e0..eaedfff 100644 --- a/src/django_enum/fields.py +++ b/src/django_enum/fields.py @@ -673,17 +673,20 @@ def formfield(self, form_class=None, choices_form_class=None, **kwargs): ) is_multi = self.enum and issubclass(self.enum, Flag) - if is_multi and self.enum: + if is_multi: kwargs["empty_value"] = self.enum(0) # why fail? - does this fail for single select too? # kwargs['show_hidden_initial'] = True if not self.strict: kwargs.setdefault( - "widget", NonStrictSelectMultiple if is_multi else NonStrictSelect + "widget", + NonStrictSelectMultiple(enum=self.enum) + if is_multi + else NonStrictSelect, ) elif is_multi: - kwargs.setdefault("widget", FlagSelectMultiple) + kwargs.setdefault("widget", FlagSelectMultiple(enum=self.enum)) form_field = super().formfield( form_class=form_class, diff --git a/src/django_enum/forms.py b/src/django_enum/forms.py index 04fcdba..6e7248c 100644 --- a/src/django_enum/forms.py +++ b/src/django_enum/forms.py @@ -1,8 +1,11 @@ """Enumeration support for django model forms""" +import sys from copy import copy from decimal import DecimalException from enum import Enum, Flag +from functools import reduce +from operator import or_ from typing import Any, Iterable, List, Optional, Protocol, Sequence, Tuple, Type, Union from django.core.exceptions import ValidationError @@ -85,8 +88,39 @@ class FlagSelectMultiple(SelectMultiple): A SelectMultiple widget for EnumFlagFields. """ + enum: Optional[Type[Flag]] -class NonStrictSelectMultiple(NonStrictMixin, SelectMultiple): + def __init__(self, enum: Optional[Type[Flag]] = None, **kwargs): + self.enum = enum + super().__init__(**kwargs) + + def format_value(self, value): + """ + Return a list of the flag's values. + """ + if not isinstance(value, list): + # see impl of ChoiceWidget.optgroups + # it compares the string conversion of the value of each + # choice tuple to the string conversion of the value + # to determine selected options + if self.enum: + if sys.version_info < (3, 11): + return [ + str(flg.value) + for flg in self.enum + if flg in self.enum(value) and flg is not self.enum(0) + ] + else: + return [str(en.value) for en in self.enum(value)] + if isinstance(value, int): + # automagically work for IntFlags even if we weren't given the enum + return [ + str(1 << i) for i in range(value.bit_length()) if (value >> i) & 1 + ] + return value + + +class NonStrictSelectMultiple(NonStrictMixin, FlagSelectMultiple): """ A SelectMultiple widget for non-strict EnumFlagFields that includes any existing non-conforming value as a choice option. @@ -314,6 +348,8 @@ class EnumFlagField(ChoiceFieldMixin, TypedMultipleChoiceField): # type: ignore if strict=False, values can be outside of the enumerations """ + widget = FlagSelectMultiple + def __init__( self, enum: Optional[Type[Flag]] = None, @@ -324,6 +360,10 @@ def __init__( choices: _ChoicesParameter = (), **kwargs, ): + kwargs.setdefault( + "widget", + self.widget(enum=enum) if strict else NonStrictSelectMultiple(enum=enum), + ) super().__init__( enum=enum, empty_value=( @@ -334,3 +374,12 @@ def __init__( choices=choices, **kwargs, ) + + def _coerce(self, value: Any) -> Any: + """Combine the values into a single flag using |""" + if self.enum and isinstance(value, self.enum): + return value + values = TypedMultipleChoiceField._coerce(self, value) # type: ignore[attr-defined] + if values: + return reduce(or_, values) + return self.empty_value diff --git a/tests/test_forms_ep.py b/tests/test_forms_ep.py index 29fb55d..6c36522 100644 --- a/tests/test_forms_ep.py +++ b/tests/test_forms_ep.py @@ -2,8 +2,11 @@ pytest.importorskip("enum_properties") from tests.test_forms import FormTests, TestFormField -from tests.enum_prop.models import EnumTester +from tests.enum_prop.models import EnumTester, BitFieldModel from tests.enum_prop.forms import EnumTesterForm +from tests.examples.models import FlagExample +from django_enum.forms import EnumFlagField, FlagSelectMultiple +from django.forms import ModelForm class EnumPropertiesFormTests(FormTests): @@ -34,6 +37,99 @@ def model_params(self): "no_coerce": "Value 1", } + def test_flag_choices_admin_form(self): + from django.contrib import admin + + admin_class = admin.site._registry.get(BitFieldModel) + self.assertIsInstance( + admin_class.get_form(None).base_fields.get("bit_field_small"), EnumFlagField + ) + + def test_flag_choices_model_form(self): + from tests.examples.models.flag import Permissions + from tests.enum_prop.enums import GNSSConstellation + + class FlagChoicesModelForm(ModelForm): + class Meta(EnumTesterForm.Meta): + model = BitFieldModel + + form = FlagChoicesModelForm( + data={"bit_field_small": [GNSSConstellation.GPS, GNSSConstellation.GLONASS]} + ) + + form.full_clean() + self.assertTrue(form.is_valid()) + self.assertEqual( + form.cleaned_data["bit_field_small"], + GNSSConstellation.GPS | GNSSConstellation.GLONASS, + ) + self.assertIsInstance(form.base_fields["bit_field_small"], EnumFlagField) + + def test_extern_flag_admin_form(self): + from django.contrib import admin + + admin_class = admin.site._registry.get(FlagExample) + self.assertIsInstance( + admin_class.get_form(None).base_fields.get("permissions"), EnumFlagField + ) + + def test_extern_flag_model_form(self): + from tests.examples.models.flag import Permissions + + class FlagModelForm(ModelForm): + class Meta(EnumTesterForm.Meta): + model = FlagExample + + form = FlagModelForm( + data={"permissions": [Permissions.READ, Permissions.WRITE]} + ) + + form.full_clean() + self.assertTrue(form.is_valid()) + self.assertEqual( + form.cleaned_data["permissions"], Permissions.READ | Permissions.WRITE + ) + self.assertIsInstance(form.base_fields["permissions"], EnumFlagField) + + def test_flag_select_multiple_format(self): + from tests.examples.models.flag import Permissions + + widget = FlagSelectMultiple() # no enum + self.assertEqual( + widget.format_value(Permissions.READ | Permissions.WRITE), + [str(Permissions.READ.value), str(Permissions.WRITE.value)], + ) + self.assertEqual( + widget.format_value(Permissions.READ | Permissions.EXECUTE), + [str(Permissions.READ.value), str(Permissions.EXECUTE.value)], + ) + self.assertEqual( + widget.format_value(Permissions.EXECUTE | Permissions.WRITE), + [str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)], + ) + + widget = FlagSelectMultiple(enum=Permissions) # no enum + self.assertEqual( + widget.format_value(Permissions.READ | Permissions.WRITE), + [str(Permissions.READ.value), str(Permissions.WRITE.value)], + ) + self.assertEqual( + widget.format_value(Permissions.READ | Permissions.EXECUTE), + [str(Permissions.READ.value), str(Permissions.EXECUTE.value)], + ) + self.assertEqual( + widget.format_value(Permissions.EXECUTE | Permissions.WRITE), + [str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)], + ) + + # check pass through + self.assertEqual( + widget.format_value( + [str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)] + ), + [str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)], + ) + FormTests = None TestFormField = None diff --git a/uv.lock b/uv.lock index 0692966..05bb4bf 100644 --- a/uv.lock +++ b/uv.lock @@ -575,6 +575,7 @@ dev = [ { name = "django-stubs", extra = ["compatible-mypy"] }, { name = "django-test-migrations" }, { name = "djlint" }, + { name = "doc8" }, { name = "ipdb" }, { name = "matplotlib" }, { name = "mypy" }, @@ -594,7 +595,6 @@ dev = [ { name = "typing-extensions" }, ] docs = [ - { name = "doc8" }, { name = "docutils" }, { name = "furo" }, { name = "readme-renderer", extra = ["md"] }, @@ -635,6 +635,7 @@ dev = [ { name = "django-stubs", extras = ["compatible-mypy"], specifier = ">=5.1.3" }, { name = "django-test-migrations", git = "https://github.com/bckohan/django-test-migrations.git?rev=issue-503" }, { name = "djlint", specifier = ">=1.36.4" }, + { name = "doc8", specifier = ">=1.1.2" }, { name = "ipdb", specifier = ">=0.13.13" }, { name = "matplotlib", specifier = ">=3.9.4" }, { name = "mypy", specifier = ">=1.15.0" }, @@ -654,7 +655,6 @@ dev = [ { name = "typing-extensions", specifier = ">=4.12.2" }, ] docs = [ - { name = "doc8", specifier = ">=1.1.2" }, { name = "docutils", specifier = ">=0.21.2" }, { name = "furo", specifier = ">=2024.8.6" }, { name = "readme-renderer", extras = ["md"], specifier = ">=44.0" },