Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
480b4e5
feat: Updated BaseConfig class for non primitive fields
abhiramvad May 14, 2025
eba84d2
Merge branch 'oumi-ai:main' into base-config-non-primitive-types
abhiramvad May 17, 2025
35b3839
Merge branch 'oumi-ai:main' into base-config-non-primitive-types
abhiramvad May 24, 2025
26d6889
Merge branch 'oumi-ai:main' into base-config-non-primitive-types
abhiramvad Jun 7, 2025
148a4c9
review comments
abhiramvad Jun 11, 2025
69e07ea
Merge branch 'oumi-ai:main' into base-config-non-primitive-types
abhiramvad Jun 11, 2025
8d6d551
basic config tests
abhiramvad Jun 11, 2025
305bfda
Merge branch 'main' into base-config-non-primitive-types
wizeng23 Jun 11, 2025
9a2439f
fixed pre commit check issues
abhiramvad Jul 20, 2025
118c98b
Merge branch 'main' into base-config-non-primitive-types
abhiramvad Jul 20, 2025
7720a7e
Merge branch 'main' into base-config-non-primitive-types
abhiramvad Jul 21, 2025
0aec965
fixed test failures
abhiramvad Jul 26, 2025
a3fd539
Merge branch 'main' into base-config-non-primitive-types
abhiramvad Jul 27, 2025
f01e1b3
Merge branch 'oumi-ai:main' into base-config-non-primitive-types
abhiramvad Aug 7, 2025
1d12c39
Update base_config.py
abhiramvad Aug 7, 2025
3ca7cca
Update test_base_config.py
abhiramvad Aug 7, 2025
2d87fc3
Merge branch 'main' into base-config-non-primitive-types
wizeng23 Aug 8, 2025
a756ae2
fixed pre commit errors
abhiramvad Aug 17, 2025
beaed70
Merge branch 'oumi-ai:main' into base-config-non-primitive-types
abhiramvad Aug 17, 2025
7d7e943
Merge branch 'main' into base-config-non-primitive-types
wizeng23 Aug 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 98 additions & 3 deletions src/oumi/core/configs/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.

import dataclasses
import inspect
import logging
import re
from collections.abc import Iterator
from enum import Enum
from io import StringIO
from pathlib import Path
from typing import Any, Optional, TypeVar, Union, cast
Expand All @@ -28,6 +30,74 @@

_CLI_IGNORED_PREFIXES = ["--local-rank"]

# Set of primitive types that OmegaConf can handle directly
_PRIMITIVE_TYPES = {str, int, float, bool, type(None), bytes, Path, Enum}


def _is_primitive_type(value: Any) -> bool:
"""Check if a value is of a primitive type that OmegaConf can handle."""
return (
isinstance(value, (str, int, float, bool, bytes))
or value is None
or isinstance(value, Path)
or isinstance(value, Enum)
)


def _handle_non_primitives(config: Any, removed_paths, path: str = "") -> Any:
"""Recursively process config object to handle non-primitive values.

Args:
config: The config object to process
removed_paths: Set to track paths of removed non-primitive values
path: The current path in the config (for logging)

Returns:
The processed config with non-primitive values removed
"""
if isinstance(config, list):
return [
_handle_non_primitives(item, removed_paths, f"{path}[{i}]")
for i, item in enumerate(config)
]

if isinstance(config, dict):
result = {}
for key, value in config.items():
current_path = f"{path}.{key}" if path else key
if _is_primitive_type(value):
result[key] = value
else:
# Recursively process nested dictionaries and other non-primitive values
processed_value = _handle_non_primitives(
value, removed_paths, current_path
)
if processed_value is not None:
result[key] = processed_value
else:
removed_paths.add(current_path)
result[key] = None
return result

if _is_primitive_type(config):
return config

# Try to convert functions to their source code
if callable(config):
try:
# Lambda functions and built-in functions can't have source extracted
source = inspect.getsource(config)
# Only return source if we successfully got it
return source
except (TypeError, OSError):
# Can't get source for lambdas, built-ins, or C extensions
removed_paths.add(path)
return None

# For any other type, remove it and track the path
removed_paths.add(path)
return None


def _filter_ignored_args(arg_list: list[str]) -> list[str]:
"""Filters out ignored CLI arguments."""
Expand Down Expand Up @@ -57,8 +127,31 @@ def _read_config_without_interpolation(config_path: str) -> str:
@dataclasses.dataclass
class BaseConfig:
def to_yaml(self, config_path: Union[str, Path, StringIO]) -> None:
"""Saves the configuration to a YAML file."""
OmegaConf.save(config=self, f=config_path)
"""Saves the configuration to a YAML file.

Non-primitive values are removed and warnings are logged.

Args:
config_path: Path to save the config to
"""
# Convert the dataclass to an OmegaConf structure first
omega_config = OmegaConf.structured(self)
config_dict = OmegaConf.to_container(omega_config, resolve=True)
removed_paths = set()
processed_config = _handle_non_primitives(
config_dict, removed_paths=removed_paths
)

# Log warnings for removed values
if removed_paths:
logger = logging.getLogger(__name__)
logger.warning(
"The following non-primitive values were removed from the config "
"as they cannot be saved to YAML:\n"
+ "\n".join(f"- {path}" for path in sorted(removed_paths))
)

OmegaConf.save(config=processed_config, f=config_path)

@classmethod
def from_yaml(
Expand Down Expand Up @@ -182,7 +275,9 @@ def print_config(self, logger: Optional[logging.Logger] = None) -> None:
if logger is None:
logger = logging.getLogger(__name__)

config_yaml = OmegaConf.to_yaml(self, resolve=True)
# Convert the dataclass to an OmegaConf structure first
omega_config = OmegaConf.structured(self)
config_yaml = OmegaConf.to_yaml(omega_config, resolve=True)
logger.info(f"Configuration:\n{config_yaml}")

def finalize_and_validate(self) -> None:
Expand Down
Loading
Loading