Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ The following are all type-casting methods of `Env`:
- `env.uuid`
- `env.log_level`
- `env.path` (casts to a [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html))
- `env.enum` (casts to any given enum type in `type` keyword argument)

## Reading `.env` files

Expand Down
12 changes: 12 additions & 0 deletions environs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import typing
import types
from collections.abc import Mapping
from enum import Enum
from urllib.parse import urlparse, ParseResult
from pathlib import Path

Expand Down Expand Up @@ -176,6 +177,16 @@ def _preprocess_json(value: str, **kwargs):
return pyjson.loads(value)


_EnumT = typing.TypeVar("_EnumT", bound=Enum)


def _enum_parser(value, *, type: typing.Type[_EnumT], **kwargs) -> _EnumT:
try:
return type[value]
except Exception:
Comment thread
tomgrin10 marked this conversation as resolved.
Outdated
raise ma.ValidationError(f"Not a valid {type} enum.")
Comment thread
tomgrin10 marked this conversation as resolved.
Outdated


def _dj_db_url_parser(value: str, **kwargs) -> dict:
try:
import dj_database_url
Expand Down Expand Up @@ -275,6 +286,7 @@ class Env:
timedelta = _field2method(ma.fields.TimeDelta, "timedelta")
uuid = _field2method(ma.fields.UUID, "uuid")
url = _field2method(URLField, "url")
enum = _func2method(_enum_parser, "enum")
dj_db_url = _func2method(_dj_db_url_parser, "dj_db_url")
dj_email_url = _func2method(_dj_email_url_parser, "dj_email_url")
dj_cache_url = _func2method(_dj_cache_url_parser, "dj_cache_url")
Expand Down
19 changes: 15 additions & 4 deletions tests/test_environs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import urllib.parse
import pathlib
from decimal import Decimal
from enum import Enum

import dj_database_url
import dj_email_url
Expand Down Expand Up @@ -35,6 +36,12 @@ class FauxTestException(Exception):
pass


class DayEnum(Enum):
SUNDAY = 1
MONDAY = 2
TUESDAY = 3


class TestCasting:
def test_call(self, set_env, env):
set_env({"STR": "foo", "INT": "42"})
Expand Down Expand Up @@ -200,6 +207,10 @@ def test_invalid_url(self, url, set_env, env):
env.url("URL")
assert 'Environment variable "URL" invalid' in excinfo.value.args[0]

def test_enum_cast(self, set_env, env):
set_env({"DAY": "SUNDAY"})
assert env.enum("DAY", type=DayEnum) == DayEnum.SUNDAY


class TestEnvFileReading:
def test_read_env(self, env):
Expand Down Expand Up @@ -316,17 +327,17 @@ def https_url(value):
def test_parser_function_can_take_extra_arguments(self, set_env, env):
set_env({"ENV": "dev"})

@env.parser_for("enum")
def enum_parser(value, choices):
@env.parser_for("choice")
def choice_parser(value, choices):
Comment thread
sloria marked this conversation as resolved.
if value not in choices:
raise environs.EnvError("Invalid!")
return value

assert env.enum("ENV", choices=["dev", "prod"]) == "dev"
assert env.choice("ENV", choices=["dev", "prod"]) == "dev"

set_env({"ENV": "invalid"})
with pytest.raises(environs.EnvError):
env.enum("ENV", choices=["dev", "prod"])
env.choice("ENV", choices=["dev", "prod"])

def test_add_parser_from_field(self, set_env, env):
class HTTPSURL(fields.Field):
Expand Down