Skip to content

Commit d10993e

Browse files
hramezaniclaude
andauthored
Fix nested discriminated unions not discovered by env/CLI providers (#816)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ba145e9 commit d10993e

File tree

4 files changed

+77
-3
lines changed

4 files changed

+77
-3
lines changed

pydantic_settings/sources/providers/cli.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,16 @@ class CliMutuallyExclusiveGroup(BaseModel):
9090
pass
9191

9292

93+
def _collect_sub_models(type_: Any, sub_models: list[type[BaseModel]]) -> None:
94+
"""Recursively collect BaseModel subclasses from possibly nested union types."""
95+
stripped = _strip_annotated(type_)
96+
if is_model_class(stripped) or is_pydantic_dataclass(stripped):
97+
sub_models.append(stripped) # type: ignore[arg-type]
98+
elif is_union_origin(get_origin(stripped)):
99+
for arg in get_args(stripped):
100+
_collect_sub_models(arg, sub_models)
101+
102+
93103
class _CliArg(BaseModel):
94104
model: Any
95105
parser: Any
@@ -200,8 +210,7 @@ def sub_models(self) -> list[type[BaseModel]]:
200210
raise SettingsError(
201211
f'CliPositionalArg is not outermost annotation for {self.model.__name__}.{self.field_name}'
202212
)
203-
if is_model_class(_strip_annotated(type_)) or is_pydantic_dataclass(_strip_annotated(type_)):
204-
sub_models.append(_strip_annotated(type_))
213+
_collect_sub_models(type_, sub_models)
205214
return sub_models
206215

207216
@cached_property

pydantic_settings/sources/utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,22 @@ def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
134134

135135
def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
136136
"""Check if a union type contains any complex types."""
137-
return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation))
137+
for arg in get_args(annotation):
138+
if _annotation_is_complex(arg, metadata):
139+
return True
140+
# _annotation_is_complex doesn't handle bare Union types, so when an arg
141+
# is Annotated[Union[X, Y], ...], stripping Annotated yields a bare Union
142+
# that _annotation_is_complex can't evaluate. Recurse into it, but only
143+
# if the Annotated metadata doesn't suppress complexity (e.g. Json).
144+
inner = _strip_annotated(arg)
145+
if inner is not arg:
146+
_, *inner_meta = get_args(arg)
147+
if any(isinstance(md, Json) for md in inner_meta): # type: ignore[misc]
148+
continue
149+
if is_union_origin(get_origin(inner)):
150+
if _union_is_complex(inner, metadata):
151+
return True
152+
return False
138153

139154

140155
def _union_has_strict_types(annotation: type[Any] | None) -> bool:

tests/test_settings.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2362,6 +2362,31 @@ class Settings(BaseSettings):
23622362
assert s.a_or_b.x == 'a'
23632363

23642364

2365+
def test_nested_discriminated_union(env):
2366+
"""Test that nested unions like Union[Annotated[Union[A, B], Discriminator(...)], None] are handled correctly."""
2367+
2368+
class A(BaseModel):
2369+
x: Literal['a'] = 'a'
2370+
a: int = 1
2371+
2372+
class B(BaseModel):
2373+
x: Literal['b'] = 'b'
2374+
b: str = 'hello'
2375+
2376+
AOrB = Annotated[A | B, Discriminator('x')]
2377+
2378+
class Settings(BaseSettings):
2379+
model_config = SettingsConfigDict(env_nested_delimiter='__')
2380+
a_or_b: AOrB | None = None
2381+
2382+
# Test env var for nested model field
2383+
env.set('a_or_b__x', 'b')
2384+
env.set('a_or_b__b', 'world')
2385+
s = Settings()
2386+
assert isinstance(s.a_or_b, B)
2387+
assert s.a_or_b.b == 'world'
2388+
2389+
23652390
def test_nested_model_case_insensitive(env):
23662391
class SubSubSub(BaseModel):
23672392
VaL3: str

tests/test_source_cli.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,6 +1449,31 @@ class Cfg(BaseSettings):
14491449
assert cfg.model_dump() == {'child': {'name': 'new name a', 'diff_a': 'new diff a'}}
14501450

14511451

1452+
def test_cli_nested_discriminated_union():
1453+
"""Test that nested unions like Union[Annotated[Union[A, B], Discriminator(...)], None] are handled by CLI."""
1454+
1455+
class A(BaseModel):
1456+
x: Literal['a'] = 'a'
1457+
a: int = 1
1458+
1459+
class B(BaseModel):
1460+
x: Literal['b'] = 'b'
1461+
b: str = 'hello'
1462+
1463+
AOrB = Annotated[A | B, Discriminator('x')]
1464+
1465+
class Cfg(BaseSettings):
1466+
a_or_b: AOrB | None = None
1467+
1468+
cfg = CliApp.run(Cfg, cli_args=['--a_or_b.x', 'b', '--a_or_b.b', 'world'])
1469+
assert isinstance(cfg.a_or_b, B)
1470+
assert cfg.a_or_b.b == 'world'
1471+
1472+
cfg = CliApp.run(Cfg, cli_args=['--a_or_b.x', 'a', '--a_or_b.a', '42'])
1473+
assert isinstance(cfg.a_or_b, A)
1474+
assert cfg.a_or_b.a == 42
1475+
1476+
14521477
def test_cli_optional_positional_arg(env):
14531478
class Main(BaseSettings):
14541479
model_config = SettingsConfigDict(

0 commit comments

Comments
 (0)