|
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
4 | 4 | from functools import lru_cache |
| 5 | +from pathlib import Path |
5 | 6 |
|
| 7 | +import yaml |
6 | 8 | from anomalib.models import list_models |
7 | 9 | from fastapi import APIRouter |
| 10 | +from loguru import logger |
8 | 11 |
|
9 | 12 | from api.endpoints import API_PREFIX |
10 | | -from pydantic_models import TrainableModelList |
| 13 | +from pydantic_models import TrainableModel, TrainableModelList |
11 | 14 |
|
12 | 15 | router = APIRouter(prefix=f"{API_PREFIX}/trainable-models", tags=["Trainable Models"]) |
13 | 16 |
|
| 17 | +CORE_DIR = Path(__file__).parent.parent.parent / "core" |
| 18 | + |
14 | 19 |
|
15 | 20 | @lru_cache |
16 | | -def _get_trainable_models() -> TrainableModelList: # pragma: no cover |
17 | | - """Return list of trainable models with optional descriptions. |
| 21 | +def _load_model_metadata() -> TrainableModelList: |
| 22 | + """Load and cache model metadata from YAML configuration file. |
| 23 | +
|
| 24 | + The model metadata is stored in ``core/model_metadata.yaml`` and contains |
| 25 | + information about each model including display name, description, family, |
| 26 | + performance metrics, and whether it is recommended/supported. |
| 27 | +
|
| 28 | + Models are validated against anomalib's available models list. Any model |
| 29 | + in the metadata that doesn't exist in anomalib will be logged as a warning |
| 30 | + and excluded from the results. |
18 | 31 |
|
19 | | - The available models are retrieved from ``anomalib.models.list_models``. Currently, only |
20 | | - the model names are returned. Descriptions can be added manually in the |
21 | | - ``_MODEL_DESCRIPTIONS`` mapping below. |
| 32 | + Returns: |
| 33 | + TrainableModelList containing only supported models that exist in |
| 34 | + anomalib, sorted by recommended first, then by name. |
22 | 35 | """ |
23 | | - model_names = sorted(list_models(case="pascal")) |
| 36 | + metadata_path = CORE_DIR / "model_metadata.yaml" |
| 37 | + |
| 38 | + with open(metadata_path, encoding="utf-8") as f: |
| 39 | + raw_models = yaml.safe_load(f) |
| 40 | + |
| 41 | + if not raw_models: |
| 42 | + return TrainableModelList(trainable_models=[]) |
| 43 | + |
| 44 | + available_models = set(list_models(case="snake")) |
24 | 45 |
|
25 | | - return TrainableModelList(trainable_models=model_names) |
| 46 | + models = [] |
| 47 | + for m in raw_models: |
| 48 | + if not m.get("supported", True): |
| 49 | + continue |
| 50 | + |
| 51 | + model_id = m.get("id", "") |
| 52 | + if model_id not in available_models: |
| 53 | + logger.warning(f"Model '{model_id}' from metadata not found in anomalib.models.list_models()") |
| 54 | + continue |
| 55 | + |
| 56 | + models.append(TrainableModel(**m)) |
| 57 | + |
| 58 | + models.sort(key=lambda m: (not m.recommended, m.name)) |
| 59 | + |
| 60 | + return TrainableModelList(trainable_models=models) |
26 | 61 |
|
27 | 62 |
|
28 | 63 | @router.get("", summary="List trainable models") |
29 | 64 | async def list_trainable_models() -> TrainableModelList: |
30 | | - """GET endpoint returning available trainable model names.""" |
| 65 | + """GET endpoint returning available trainable model metadata. |
31 | 66 |
|
32 | | - return _get_trainable_models() |
| 67 | + Returns a list of trainable models with their metadata including: |
| 68 | + - id: Model identifier for training API |
| 69 | + - name: Human-readable display name |
| 70 | + - description: Brief model description |
| 71 | + - family: Model architecture family (memory_bank, distribution, etc.) |
| 72 | + - recommended: Whether model is recommended for new users |
| 73 | + - metrics: Training and inference speed scores (1-3 scale) |
| 74 | + - parameters: Model size in millions of parameters |
| 75 | + """ |
| 76 | + return _load_model_metadata() |
0 commit comments