Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 73 additions & 210 deletions src/gpgi/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,15 @@
from __future__ import annotations

import enum
import math
import sys
import warnings
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from functools import cached_property, partial, reduce
from itertools import chain
from textwrap import indent
from threading import Lock
from time import monotonic_ns
from typing import TYPE_CHECKING, Literal, assert_never, cast
from typing import TYPE_CHECKING, Literal, cast, final

import numpy as np

Expand All @@ -32,6 +28,14 @@
_deposit_tsc_3D,
_index_particles,
)
from gpgi._spatial_data import (
BasicCoordinatesValidator,
DTypeConsistencyValidator,
FieldMapsValidatorHelper,
Geometry,
GeometryValidator,
Validator,
)
from gpgi._typing import FieldMap, Name
from gpgi.typing import DepositionMethodT, DepositionMethodWithMetadataT

Expand All @@ -43,22 +47,12 @@
if TYPE_CHECKING:
from typing import Any, Self

from numpy.typing import NDArray

from gpgi._typing import FieldMap, HCIArray, Name, Real, RealArray
from gpgi._typing import FieldMap, HCIArray, Name, RealArray


# Type alias: a variable-length tuple whose items are 3-tuples of strings.
# NOTE(review): the meaning of the three strings is not visible in this
# part of the file -- confirm against the code that consumes BoundarySpec.
BoundarySpec = tuple[tuple[str, str, str], ...]


class Geometry(enum.StrEnum):
CARTESIAN = enum.auto()
POLAR = enum.auto()
CYLINDRICAL = enum.auto()
SPHERICAL = enum.auto()
EQUATORIAL = enum.auto()


class DepositionMethod(enum.Enum):
NEAREST_GRID_POINT = enum.auto()
CLOUD_IN_CELL = enum.auto()
Expand Down Expand Up @@ -90,177 +84,28 @@ class DepositionMethod(enum.Enum):
}


class GeometricData(ABC):
    """Declare the attributes shared by objects that carry a geometry."""

    # geometry the axes are expressed in
    geometry: Geometry
    # names of the axes, in order
    axes: tuple[Name, ...]


class CoordinateData(ABC):
    """Declare the attributes shared by objects holding coordinates and fields."""

    # geometry the axes are expressed in
    geometry: Geometry
    # names of the axes, in order
    axes: tuple[Name, ...]
    # mapping of names to coordinate arrays
    coordinates: FieldMap
    # mapping of names to field arrays
    fields: FieldMap


@dataclass
class ReferenceArray:
    """Pair an array with the name it is known by.

    Used as the reference against which the shape of other named arrays
    is compared.
    """

    # name of the field the array belongs to
    name: Name
    # the array itself
    data: NDArray


class ValidatorMixin(GeometricData, ABC):
def __init__(self) -> None:
self._validate()

@abstractmethod
def _validate(self) -> None: ...

def _validate_FieldMaps(
self,
*fmaps: FieldMap | None,
require_shape_equality: bool = False,
require_sorted: bool = False,
**required_attrs: Any,
) -> None:
ref_arr: ReferenceArray | None = None
for name, data in chain.from_iterable(
fm.items() for fm in fmaps if fm is not None
):
if require_shape_equality:
ref_arr = self._validate_shape_equality(name, data, ref_arr)
if require_sorted:
self._validate_sorted_state(name, data)
if required_attrs:
self._validate_required_attributes(name, data, **required_attrs)

def _validate_shape_equality(
self,
name: str,
data: NDArray[Real],
ref_arr: ReferenceArray | None,
) -> ReferenceArray:
if ref_arr is not None and data.shape != ref_arr.data.shape:
raise ValueError(
f"Fields {name!r} and {ref_arr.name!r} "
f"have mismatching shapes {data.shape} and {ref_arr.data.shape}"
)
return ReferenceArray(name, data)

def _validate_sorted_state(self, name: str, data: NDArray[Real]) -> None:
a = data[0]
for i, b in enumerate(data[1:], start=1):
if a > b:
raise ValueError(
f"Field {name!r} is not properly sorted by ascending order. "
f"Got {a} (index {i-1}) > {b} (index {i})"
)
a = b

def _validate_required_attributes(
self,
name: str,
data: NDArray[Real],
**required_attrs: Any,
) -> None:
for attr, expected in required_attrs.items():
if (actual := getattr(data, attr)) != expected:
raise ValueError(
f"Field {name!r} has incorrect {attr} {actual} "
f"(expected {expected})"
)

def _validate_geometry(self) -> None:
match self.geometry:
case Geometry.CARTESIAN:
axes3D = ("x", "y", "z")
case Geometry.POLAR:
axes3D = ("radius", "azimuth", "z")
case Geometry.CYLINDRICAL:
axes3D = ("radius", "z", "azimuth")
case Geometry.SPHERICAL:
axes3D = ("radius", "colatitude", "azimuth")
case Geometry.EQUATORIAL:
axes3D = ("radius", "azimuth", "latitude")
case _ as unreachable: # pragma: no cover
assert_never(unreachable)

for i, (expected, actual) in enumerate(zip(axes3D, self.axes, strict=False)):
if actual != expected:
raise ValueError(
f"Got invalid axis name {actual!r} on position {i}, "
f"with geometry {self.geometry.name.lower()!r}\n"
f"Expected axes ordered as {axes3D[: len(self.axes)]}"
)


# Lower and upper bound, as (min, max), accepted for the coordinates of
# each known axis name.
_AXES_LIMITS: dict[Name, tuple[float, float]] = {
    # cartesian axes are unbounded on both sides
    **dict.fromkeys(("x", "y", "z"), (-np.inf, np.inf)),
    "radius": (0, np.inf),
    "azimuth": (0, 2 * np.pi),
    "colatitude": (0, np.pi),
    "latitude": (-np.pi / 2, np.pi / 2),
}


class _CoordinateValidatorMixin(ValidatorMixin, CoordinateData, ABC):
    """Validation shared by objects that hold coordinate and field arrays.

    After the parent validation has run, the constructor requires all
    coordinate and field arrays to share one dtype and stores it as
    ``self.dtype``.
    """

    def __init__(self) -> None:
        super().__init__()
        dtype_by_name = {
            name: arr.dtype
            for fmap in (self.coordinates, self.fields)
            for name, arr in fmap.items()
        }
        distinct_dtypes = sorted(set(dtype_by_name.values()))
        if len(distinct_dtypes) > 1:
            raise TypeError(
                f"Received mixed data types ({distinct_dtypes}):\n{dtype_by_name}"
            )
        self.dtype = distinct_dtypes.pop()

    def _validate_coordinates(self) -> None:
        """Check each non-empty coordinate array against its axis limits.

        For every axis, the minimum and maximum must be finite and lie
        within the bounds registered in ``_AXES_LIMITS``.

        Raises
        ------
        ValueError
            If a coordinate array has a non-float dtype, or if its minimum
            or maximum is non-finite or outside the allowed bounds.
        """
        for axis in self.axes:
            coord = self.coordinates[axis]
            if len(coord) == 0:
                # nothing to check on an empty array
                continue

            coord_dtype = self._get_safe_datatype(coord)
            to_scalar = coord_dtype.type
            # express the limits in the array's own precision
            lower, upper = map(to_scalar, _AXES_LIMITS[axis])

            smallest = to_scalar(np.min(coord))
            if smallest < lower or not math.isfinite(smallest):
                if math.isfinite(lower):
                    hint = f"minimal value allowed is {lower}"
                else:
                    assert lower == -float("inf")
                    hint = "value must be finite"
                raise ValueError(
                    f"Invalid coordinate data for axis {axis!r} {smallest} ({hint})"
                )

            largest = to_scalar(np.max(coord))
            if largest > upper or not math.isfinite(largest):
                if math.isfinite(upper):
                    hint = f"maximal value allowed is {upper}"
                else:
                    assert upper == float("inf")
                    hint = "value must be finite"
                raise ValueError(
                    f"Invalid coordinate data for axis {axis!r} {largest} ({hint})"
                )

            self.coordinates[axis] = coord.astype(coord_dtype, copy=False)

    def _get_safe_datatype(
        self, reference: NDArray[np.floating] | None = None
    ) -> np.dtype[np.floating]:
        """Return the dtype of ``reference``, which must be a float dtype.

        When ``reference`` is ``None``, the coordinate array of the first
        axis is used instead.

        Raises
        ------
        ValueError
            If the dtype's kind is not ``"f"``.
        """
        source = self.coordinates[self.axes[0]] if reference is None else reference
        found = source.dtype
        if found.kind == "f":
            return found
        raise ValueError(f"Invalid data type {found} (expected a float dtype)")
# the following need to be defined in the same module as Grid and ParticleSet
@final
class GridFieldMapsValidator:
    """Delegate the checks of a ``Grid``'s field maps to ``FieldMapsValidatorHelper``."""

    @classmethod
    def check(cls, data: Grid) -> None:
        """Run the helper on ``data.coordinates``, then on ``data.fields``.

        Coordinates are submitted with ``require_sorted=True`` and a
        required ``ndim`` of 1; fields are required to expose the grid's
        own ``size``, ``ndim`` and ``shape``.
        """
        coordinates_attrs = {"ndim": 1}
        FieldMapsValidatorHelper.check(
            data.coordinates,
            require_sorted=True,
            required_attrs=coordinates_attrs,
        )

        fields = data.fields
        fields_attrs = {
            "size": data.size,
            "ndim": data.ndim,
            "shape": data.shape,
        }
        FieldMapsValidatorHelper.check(fields, required_attrs=fields_attrs)


class Grid(_CoordinateValidatorMixin):
@final
class Grid:
def __init__(
self,
*,
Expand Down Expand Up @@ -289,14 +134,26 @@ def __init__(
self.fields: FieldMap = fields

self.axes = tuple(self.coordinates.keys())
super().__init__()
self._validate()
self.dtype = self.coordinates[self.axes[0]].dtype

self._dx = np.full((3,), -1, dtype=self.coordinates[self.axes[0]].dtype)
for i, ax in enumerate(self.axes):
if self.size == 1 or np.diff(self.coordinates[ax]).std() < 1e-16:
# got a constant step in this direction, store it
self._dx[i] = self.coordinates[ax][1] - self.coordinates[ax][0]

_validators: list[type[Validator[Grid]]] = [
GeometryValidator,
BasicCoordinatesValidator,
GridFieldMapsValidator,
DTypeConsistencyValidator,
]

def _validate(self) -> None:
for validator in self.__class__._validators:
validator.check(self)

def __repr__(self) -> str:
"""Implement repr(Grid(...))."""
return (
Expand All @@ -307,17 +164,6 @@ def __repr__(self) -> str:
")"
)

def _validate(self) -> None:
self._validate_geometry()
self._validate_coordinates()
self._validate_FieldMaps(self.cell_edges, ndim=1, require_sorted=True)
self._validate_FieldMaps(
self.fields,
size=self.size,
ndim=self.ndim,
shape=self.shape,
)

@property
def cell_edges(self) -> FieldMap:
r"""An alias for self.coordinates."""
Expand Down Expand Up @@ -371,7 +217,19 @@ def cell_volumes(self) -> RealArray:
)


class ParticleSet(_CoordinateValidatorMixin):
@final
class ParticleSetFieldMapsValidator:
    """Delegate the checks of a ``ParticleSet``'s field maps to ``FieldMapsValidatorHelper``."""

    @classmethod
    def check(cls, data: ParticleSet) -> None:
        """Run the helper on ``data.coordinates``.

        Coordinates are submitted with ``require_shape_equality=True`` and
        a required ``ndim`` of 1.
        """
        # NOTE(review): only ``data.coordinates`` is submitted here, whereas
        # the inline ``ParticleSet._validate`` passed both ``self.coordinates``
        # and ``self.fields`` to the shape-equality check -- confirm that
        # leaving ``data.fields`` out is intentional.
        coordinates_attrs = {"ndim": 1}
        FieldMapsValidatorHelper.check(
            data.coordinates,
            require_shape_equality=True,
            required_attrs=coordinates_attrs,
        )


@final
class ParticleSet:
def __init__(
self,
*,
Expand Down Expand Up @@ -399,7 +257,19 @@ def __init__(
self.fields: FieldMap = fields

self.axes = tuple(self.coordinates.keys())
super().__init__()
self._validate()
self.dtype = self.coordinates[self.axes[0]].dtype

_validators: list[type[Validator[ParticleSet]]] = [
GeometryValidator,
BasicCoordinatesValidator,
ParticleSetFieldMapsValidator,
DTypeConsistencyValidator,
]

def _validate(self) -> None:
for validator in self.__class__._validators:
validator.check(self)

def __repr__(self) -> str:
"""Implement repr(ParticleSet(...))."""
Expand All @@ -411,13 +281,6 @@ def __repr__(self) -> str:
")"
)

def _validate(self) -> None:
self._validate_geometry()
self._validate_coordinates()
self._validate_FieldMaps(
self.coordinates, self.fields, require_shape_equality=True, ndim=1
)

@property
def count(self) -> int:
r"""The total number of particles in the set."""
Expand All @@ -429,7 +292,8 @@ def ndim(self) -> int:
return len(self.axes)


class Dataset(ValidatorMixin):
@final
class Dataset:
def __init__(
self,
*,
Expand Down Expand Up @@ -461,10 +325,9 @@ def __init__(
self.geometry = geometry

if particles is None:
dt = grid._get_safe_datatype()
particles = ParticleSet(
geometry=grid.geometry,
coordinates={ax: np.array([], dtype=dt) for ax in grid.axes},
coordinates={ax: np.array([], dtype=grid.dtype) for ax in grid.axes},
)

self.grid: Grid = grid
Expand All @@ -478,7 +341,7 @@ def __init__(
self._hci_lock = Lock()
self._deposit_lock = Lock()

super().__init__()
self._validate()

def __repr__(self) -> Name:
"""Implement repr(Dataset(...))."""
Expand Down
Loading
Loading