diff --git a/CHANGELOG.md b/CHANGELOG.md index 639b150..43b2033 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,9 +4,11 @@ Features: +- Add `enum` parser ([#185](https://github.com/sloria/environs/pull/185)). - Add `delimiter` param to `env.list` ([#184](https://github.com/sloria/environs/pull/184)). - Thanks [tomgrin10](https://github.com/tomgrin10?) for the PR. + +Thanks [tomgrin10](https://github.com/tomgrin10?) for the PRs. Bug fixes: diff --git a/README.md b/README.md index 0d17f67..cc6c0dd 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,7 @@ The following are all type-casting methods of `Env`: - `env.uuid` - `env.log_level` - `env.path` (casts to a [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html)) +- `env.enum` (casts to any given enum type specified in `type` keyword argument, accepts optional `ignore_case` keyword argument) ## Reading `.env` files @@ -280,14 +281,14 @@ domain = env.furl("DOMAIN") # => furl('https://myapp.com') # Custom parsers can take extra keyword arguments -@env.parser_for("enum") -def enum_parser(value, choices): +@env.parser_for("choice") +def choice_parser(value, choices): if value not in choices: raise environs.EnvError("Invalid!") return value -color = env.enum("COLOR", choices=["black"]) # => raises EnvError +color = env.choice("COLOR", choices=["black"]) # => raises EnvError ``` ## Usage with Flask diff --git a/environs/__init__.py b/environs/__init__.py index 0d42731..a42b8db 100644 --- a/environs/__init__.py +++ b/environs/__init__.py @@ -9,6 +9,7 @@ import typing import types from collections.abc import Mapping +from enum import Enum from urllib.parse import urlparse, ParseResult from pathlib import Path @@ -177,6 +178,25 @@ def _preprocess_json(value: str, **kwargs): return pyjson.loads(value) +_EnumT = typing.TypeVar("_EnumT", bound=Enum) + + +def _enum_parser(value, type: typing.Type[_EnumT], ignore_case: bool = False) -> _EnumT: + invalid_exc = ma.ValidationError(f"Not a valid '{type.__name__}' enum.") + + if not ignore_case: + try: + return type[value] + except KeyError as error: + raise invalid_exc from error + + for enum_value in type: + if enum_value.name.lower() == value.lower(): + return enum_value + + raise invalid_exc + + def _dj_db_url_parser(value: str, **kwargs) -> dict: try: import dj_database_url @@ -276,6 +296,7 @@ class Env: timedelta = _field2method(ma.fields.TimeDelta, "timedelta") uuid = _field2method(ma.fields.UUID, "uuid") url = _field2method(URLField, "url") + enum = _func2method(_enum_parser, "enum") dj_db_url = _func2method(_dj_db_url_parser, "dj_db_url") dj_email_url = _func2method(_dj_email_url_parser, "dj_email_url") dj_cache_url = _func2method(_dj_cache_url_parser, "dj_cache_url") diff --git a/tests/test_environs.py b/tests/test_environs.py index 7aee702..c2451c5 100644 --- a/tests/test_environs.py +++ b/tests/test_environs.py @@ -5,6 +5,7 @@ import urllib.parse import pathlib from decimal import Decimal +from enum import Enum import dj_database_url import dj_email_url @@ -35,6 +36,12 @@ class FauxTestException(Exception): pass +class DayEnum(Enum): + SUNDAY = 1 + MONDAY = 2 + TUESDAY = 3 + + class TestCasting: def test_call(self, set_env, env): set_env({"STR": "foo", "INT": "42"}) @@ -204,6 +211,24 @@ def test_invalid_url(self, url, set_env, env): env.url("URL") assert 'Environment variable "URL" invalid' in excinfo.value.args[0] + def test_enum_cast(self, set_env, env): + set_env({"DAY": "SUNDAY"}) + assert env.enum("DAY", type=DayEnum) == DayEnum.SUNDAY + + def test_enum_cast_ignore_case(self, set_env, env): + set_env({"DAY": "suNDay"}) + assert env.enum("DAY", type=DayEnum, ignore_case=True) == DayEnum.SUNDAY + + def test_invalid_enum(self, set_env, env): + set_env({"DAY": "suNDay"}) + with pytest.raises(environs.EnvError): + assert env.enum("DAY", type=DayEnum) + + def test_invalid_enum_ignore_case(self, set_env, env): + set_env({"DAY": "SonDAY"}) + with pytest.raises(environs.EnvError): + assert env.enum("DAY", type=DayEnum, ignore_case=True) + class TestEnvFileReading: def test_read_env(self, env): @@ -320,17 +345,17 @@ def https_url(value): def test_parser_function_can_take_extra_arguments(self, set_env, env): set_env({"ENV": "dev"}) - @env.parser_for("enum") - def enum_parser(value, choices): + @env.parser_for("choice") + def choice_parser(value, choices): if value not in choices: raise environs.EnvError("Invalid!") return value - assert env.enum("ENV", choices=["dev", "prod"]) == "dev" + assert env.choice("ENV", choices=["dev", "prod"]) == "dev" set_env({"ENV": "invalid"}) with pytest.raises(environs.EnvError): - env.enum("ENV", choices=["dev", "prod"]) + env.choice("ENV", choices=["dev", "prod"]) def test_add_parser_from_field(self, set_env, env): class HTTPSURL(fields.Field):