diff --git a/.gitignore b/.gitignore
index 90710c6db9..7977055ce0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -109,3 +109,4 @@ doc/source/quickstart/.ipynb_checkpoints/
 dist
 .python-version
 answer_nosetests.xml
+.venv/
diff --git a/yt/_typing.py b/yt/_typing.py
index 642717632b..8449a26ab7 100644
--- a/yt/_typing.py
+++ b/yt/_typing.py
@@ -1,6 +1,7 @@
-from typing import Any, Optional, TypeAlias
+from typing import Optional, TypeAlias
 
 import numpy as np
+import numpy.typing as npt
 import unyt as un
 
 FieldDescT = tuple[str, tuple[str, list[str], str | None]]
@@ -12,12 +13,12 @@ FieldKey = tuple[FieldType, FieldName]
 ImplicitFieldKey = FieldName
 AnyFieldKey = FieldKey | ImplicitFieldKey
 
-DomainDimensions = tuple[int, ...] | list[int] | np.ndarray
+DomainDimensions = tuple[int, ...] | list[int] | npt.NDArray
 
 ParticleCoordinateTuple = tuple[
     str,  # particle type
-    tuple[np.ndarray, np.ndarray, np.ndarray],  # xyz
-    float | np.ndarray,  # hsml
+    tuple[npt.NDArray, npt.NDArray, npt.NDArray],  # xyz
+    float | npt.NDArray,  # hsml
 ]
 
 # Geometry specific types
@@ -33,5 +34,5 @@
 
 # np.ndarray[...] syntax is runtime-valid from numpy 1.22, we quote it until our minimal
 # runtime requirement is bumped to, or beyond this version
-MaskT = Optional["np.ndarray[Any, np.dtype[np.bool_]]"]
-AlphaT = Optional["np.ndarray[Any, np.dtype[np.float64]]"]
+MaskT = Optional["npt.NDArray[np.bool_]"]
+AlphaT = Optional["npt.NDArray[np.float64]"]
diff --git a/yt/frontends/artio/data_structures.py b/yt/frontends/artio/data_structures.py
index 2338d40e00..1a29dc0721 100644
--- a/yt/frontends/artio/data_structures.py
+++ b/yt/frontends/artio/data_structures.py
@@ -3,6 +3,7 @@
 from collections import defaultdict
 
 import numpy as np
+import numpy.typing as npt
 
 from yt.data_objects.field_data import YTFieldData
 from yt.data_objects.index_subobjects.octree_subset import OctreeSubset
@@ -339,10 +340,10 @@ def _read_fluid_fields(self, fields, dobj, chunk=None):
 
     def _icoords_to_fcoords(
         self,
-        icoords: np.ndarray,
-        ires: np.ndarray,
+        icoords: npt.NDArray,
+        ires: npt.NDArray,
         axes: tuple[int, ...] | None = None,
-    ) -> tuple[np.ndarray, np.ndarray]:
+    ) -> tuple[npt.NDArray, npt.NDArray]:
         """
         Accepts icoords and ires and returns appropriate fcoords and fwidth.
         Mostly useful for cases where we have irregularly spaced or structured
diff --git a/yt/frontends/ramses/hilbert.py b/yt/frontends/ramses/hilbert.py
index 5f203c8f79..460321ab80 100644
--- a/yt/frontends/ramses/hilbert.py
+++ b/yt/frontends/ramses/hilbert.py
@@ -1,6 +1,7 @@
-from typing import Any, Optional
+from typing import Optional
 
 import numpy as np
+import numpy.typing as npt
 
 from yt.data_objects.selection_objects.region import YTRegion
 from yt.geometry.selection_routines import (
@@ -48,9 +49,7 @@
 )
 
 
-def hilbert3d(
-    ijk: "np.ndarray[Any, np.dtype[np.int64]]", bit_length: int
-) -> "np.ndarray[Any, np.dtype[np.float64]]":
+def hilbert3d(ijk: "npt.NDArray[np.int64]", bit_length: int) -> "npt.NDArray[np.float64]":
     """Compute the order using Hilbert indexing.
 
     Arguments
@@ -70,11 +69,11 @@
 def get_intersecting_cpus(
     ds,
     region: YTRegion,
-    LE: Optional["np.ndarray[Any, np.dtype[np.float64]]"] = None,
+    LE: Optional["npt.NDArray[np.float64]"] = None,
     dx: float = 1.0,
     dx_cond: float | None = None,
     factor: float = 4.0,
-    bound_keys: Optional["np.ndarray[Any, np.dtype[np.float64]]"] = None,
+    bound_keys: Optional["npt.NDArray[np.float64]"] = None,
 ) -> set[int]:
     """
     Find the subset of CPUs that intersect the bbox in a recursive fashion.
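
Note: the hunks above replace the verbose quoted form "np.ndarray[Any, np.dtype[...]]" with the equivalent numpy.typing alias. A minimal standalone sketch of how the rewritten aliases are consumed (not part of this patch; clip_negative is an invented helper, not a yt API):

    from typing import Optional

    import numpy as np
    import numpy.typing as npt

    # npt.NDArray[np.float64] is shorthand for np.ndarray[Any, np.dtype[np.float64]],
    # so aliases such as MaskT keep exactly the same meaning after the rewrite.
    MaskT = Optional["npt.NDArray[np.bool_]"]

    def clip_negative(
        values: npt.NDArray[np.float64], mask: MaskT = None
    ) -> npt.NDArray[np.float64]:
        # Zero out negative entries, optionally restricted to a boolean mask.
        out = values.copy()
        selection = mask if mask is not None else np.ones_like(out, dtype=np.bool_)
        out[selection & (out < 0.0)] = 0.0
        return out

    print(clip_negative(np.array([-1.0, 2.0]), mask=np.array([True, True])))
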
@@ -119,8 +118,8 @@ def get_intersecting_cpus(
 
 def get_cpu_list_cuboid(
     ds,
-    X: "np.ndarray[Any, np.dtype[np.float64]]",
-    bound_keys: "np.ndarray[Any, np.dtype[np.float64]]",
+    X: "npt.NDArray[np.float64]",
+    bound_keys: "npt.NDArray[np.float64]",
 ) -> set[int]:
     """
     Return the list of the CPU intersecting with the cuboid containing the positions.
diff --git a/yt/frontends/ramses/io.py b/yt/frontends/ramses/io.py
index aae9258535..cd48fab2dc 100644
--- a/yt/frontends/ramses/io.py
+++ b/yt/frontends/ramses/io.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, Union
 
 import numpy as np
+import numpy.typing as npt
 from unyt import unyt_array
 
 from yt._maintenance.deprecation import issue_deprecation_warning
@@ -37,7 +38,7 @@ def convert_ramses_ages(ds, conformal_ages):
 
 
 def convert_ramses_conformal_time_to_physical_time(
-    ds, conformal_time: np.ndarray
+    ds, conformal_time: npt.NDArray
 ) -> unyt_array:
     """
     Convert conformal times (as defined in RAMSES) to physical times.
@@ -82,7 +83,7 @@ def _ramses_particle_binary_file_handler(
     subset: "RAMSESDomainSubset",
     fields: list[FieldKey],
     count: int,
-) -> dict[FieldKey, np.ndarray]:
+) -> dict[FieldKey, npt.NDArray]:
     """General file handler for binary file, called by _read_particle_subset
 
     Parameters
@@ -96,7 +97,7 @@ def _ramses_particle_binary_file_handler(
     count: integer
         The number of elements to count
     """
-    tr = {}
+    tr: dict[FieldKey, npt.NDArray] = {}
     ds = subset.domain.ds
     foffsets = particle_handler.field_offsets
     fname = particle_handler.fname
@@ -130,7 +131,7 @@ def _ramses_particle_csv_file_handler(
     subset: "RAMSESDomainSubset",
     fields: list[FieldKey],
     count: int,
-) -> dict[FieldKey, np.ndarray]:
+) -> dict[FieldKey, npt.NDArray]:
     """General file handler for csv file, called by _read_particle_subset
 
     Parameters
@@ -146,7 +147,7 @@ def _ramses_particle_csv_file_handler(
     """
     from yt.utilities.on_demand_imports import _pandas as pd
 
-    tr = {}
+    tr: dict[FieldKey, npt.NDArray] = {}
    ds = subset.domain.ds
     foffsets = particle_handler.field_offsets
     fname = particle_handler.fname
diff --git a/yt/frontends/ramses/particle_handlers.py b/yt/frontends/ramses/particle_handlers.py
index 10815f233c..588d2ae200 100644
--- a/yt/frontends/ramses/particle_handlers.py
+++ b/yt/frontends/ramses/particle_handlers.py
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING, Any
 
 import numpy as np
+import numpy.typing as npt
 
 from yt._typing import FieldKey
 from yt.config import ytcfg
@@ -71,7 +72,7 @@ class ParticleFileHandler(abc.ABC, HandlerMixin):
     # assumed to be `self`).
     reader: Callable[
         ["RAMSESDomainSubset", list[FieldKey], int],
-        dict[FieldKey, np.ndarray],
+        dict[FieldKey, npt.NDArray],
     ]
 
     # Name of the config section (if any)
@@ -162,7 +163,7 @@ def header(self) -> dict[str, Any]:
             self.read_header()
         return self._header
 
-    def handle_field(self, field: FieldKey, data_dict: dict[FieldKey, np.ndarray]):
+    def handle_field(self, field: FieldKey, data_dict: dict[FieldKey, npt.NDArray]):
         """
         This function allows custom code to be called to handle special
         cases, such as the particle birth time.
@@ -173,7 +174,7 @@ def handle_field(self, field: FieldKey, data_dict: dict[FieldKey, np.ndarray]):
         ----------
         field : FieldKey
             The field name.
-        data_dict : dict[FieldKey, np.ndarray]
+        data_dict : dict[FieldKey, npt.NDArray]
             A dictionary containing the data.
 
         By default, this function does nothing.
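
Note: besides renaming the spelling, the io.py hunks above annotate the empty dicts (`tr: dict[FieldKey, npt.NDArray] = {}`). A standalone sketch of why that helps (assumed names, not yt code): with a bare `tr = {}`, the checker has no value type to verify later insertions against.

    import numpy as np
    import numpy.typing as npt

    FieldKey = tuple[str, str]  # mirrors the alias from yt._typing

    def _read_fields(fields: list[FieldKey], count: int) -> dict[FieldKey, npt.NDArray]:
        tr: dict[FieldKey, npt.NDArray] = {}  # explicit value type, as in the handlers above
        for field in fields:
            # Placeholder read: the real handlers fill this from the RAMSES particle files.
            tr[field] = np.zeros(count, dtype="float64")
        return tr

    print(_read_fields([("io", "particle_mass")], 4))
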
@@ -346,7 +347,7 @@ def birth_file_fname(self):
     def has_birth_file(self):
         return os.path.exists(self.birth_file_fname)
 
-    def handle_field(self, field: FieldKey, data_dict: dict[FieldKey, np.ndarray]):
+    def handle_field(self, field: FieldKey, data_dict: dict[FieldKey, npt.NDArray]):
         _ptype, fname = field
         if not (fname == "particle_birth_time" and self.ds.cosmological_simulation):
             return
@@ -492,7 +493,7 @@ def read_header(self):
         self._field_offsets = field_offsets
         self._field_types = _pfields
 
-    def handle_field(self, field: FieldKey, data_dict: dict[FieldKey, np.ndarray]):
+    def handle_field(self, field: FieldKey, data_dict: dict[FieldKey, npt.NDArray]):
         _ptype, fname = field
         if not (fname == "particle_birth_time" and self.ds.cosmological_simulation):
             return
diff --git a/yt/frontends/rockstar/data_structures.py b/yt/frontends/rockstar/data_structures.py
index 3d613d3b5f..e7f8bbe21e 100644
--- a/yt/frontends/rockstar/data_structures.py
+++ b/yt/frontends/rockstar/data_structures.py
@@ -1,9 +1,10 @@
 import glob
 import os
 from functools import cached_property
-from typing import Any, Optional
+from typing import Optional
 
 import numpy as np
+import numpy.typing as npt
 
 from yt.data_objects.static_output import ParticleDataset
 from yt.frontends.halo_catalog.data_structures import HaloCatalogFile
@@ -21,7 +22,7 @@ class RockstarBinaryFile(HaloCatalogFile):
     header: dict
     _position_offset: int
     _member_offset: int
-    _Npart: "np.ndarray[Any, np.dtype[np.int64]]"
+    _Npart: "npt.NDArray[np.int64]"
     _ids_halos: list[int]
     _file_size: int
 
@@ -47,9 +48,7 @@ def __init__(self, ds, io, filename, file_id, range):
 
         super().__init__(ds, io, filename, file_id, range)
 
-    def _read_member(
-        self, ihalo: int
-    ) -> Optional["np.ndarray[Any, np.dtype[np.int64]]"]:
+    def _read_member(self, ihalo: int) -> Optional["npt.NDArray[np.int64]"]:
         if ihalo not in self._ids_halos:
             return None
 
diff --git a/yt/frontends/stream/misc.py b/yt/frontends/stream/misc.py
index 8271b015a4..bd47c05647 100644
--- a/yt/frontends/stream/misc.py
+++ b/yt/frontends/stream/misc.py
@@ -1,12 +1,13 @@
 import numpy as np
+import numpy.typing as npt
 
 from yt._typing import DomainDimensions
 
 
 def _validate_cell_widths(
-    cell_widths: list[np.ndarray],
+    cell_widths: list[npt.NDArray],
     domain_dimensions: DomainDimensions,
-) -> list[np.ndarray]:
+) -> list[npt.NDArray]:
     # check dimensionality
     if (nwids := len(cell_widths)) != (ndims := len(domain_dimensions)):
         raise ValueError(
diff --git a/yt/geometry/coordinates/coordinate_handler.py b/yt/geometry/coordinates/coordinate_handler.py
index aeb5f8833b..4808172ddd 100644
--- a/yt/geometry/coordinates/coordinate_handler.py
+++ b/yt/geometry/coordinates/coordinate_handler.py
@@ -2,9 +2,10 @@
 import weakref
 from functools import cached_property
 from numbers import Number
-from typing import Any, Literal, overload
+from typing import Literal, overload
 
 import numpy as np
+import numpy.typing as npt
 
 from yt._typing import AxisOrder
 from yt.funcs import fix_unitary, is_sequence, parse_center_array, validate_width_tuple
@@ -158,7 +159,7 @@ def pixelize(
         periodic=True,
         *,
         return_mask: Literal[False],
-    ) -> "np.ndarray[Any, np.dtype[np.float64]]": ...
+    ) -> "npt.NDArray[np.float64]": ...
 
     @overload
     def pixelize(
@@ -172,9 +173,7 @@ def pixelize(
         periodic=True,
         *,
         return_mask: Literal[True],
-    ) -> tuple[
-        "np.ndarray[Any, np.dtype[np.float64]]", "np.ndarray[Any, np.dtype[np.bool_]]"
-    ]: ...
+    ) -> tuple["npt.NDArray[np.float64]", "npt.NDArray[np.bool_]"]: ...
 
     @abc.abstractmethod
     def pixelize(
diff --git a/yt/geometry/geometry_handler.py b/yt/geometry/geometry_handler.py
index c3c08f289e..09f2ef248e 100644
--- a/yt/geometry/geometry_handler.py
+++ b/yt/geometry/geometry_handler.py
@@ -3,6 +3,7 @@
 import weakref
 
 import numpy as np
+import numpy.typing as npt
 
 from yt._maintenance.deprecation import issue_deprecation_warning
 from yt.config import ytcfg
@@ -51,10 +52,10 @@ def _detect_output_fields(self):
 
     def _icoords_to_fcoords(
         self,
-        icoords: np.ndarray,
-        ires: np.ndarray,
+        icoords: npt.NDArray,
+        ires: npt.NDArray,
         axes: tuple[int, ...] | None = None,
-    ) -> tuple[np.ndarray, np.ndarray]:
+    ) -> tuple[npt.NDArray, npt.NDArray]:
         # What's the use of raising NotImplementedError for this, when it's an
         # abstract base class? Well, only *some* of the subclasses have it --
         # and for those that *don't*, we should not be calling it -- and since
diff --git a/yt/geometry/grid_geometry_handler.py b/yt/geometry/grid_geometry_handler.py
index c23b50ded5..1954bd73c0 100644
--- a/yt/geometry/grid_geometry_handler.py
+++ b/yt/geometry/grid_geometry_handler.py
@@ -3,6 +3,7 @@
 from collections import defaultdict
 
 import numpy as np
+import numpy.typing as npt
 
 from yt.arraytypes import blankRecordArray
 from yt.config import ytcfg
@@ -447,10 +448,10 @@ def _chunk_io(
 
     def _icoords_to_fcoords(
         self,
-        icoords: np.ndarray,
-        ires: np.ndarray,
+        icoords: npt.NDArray,
+        ires: npt.NDArray,
         axes: tuple[int, ...] | None = None,
-    ) -> tuple[np.ndarray, np.ndarray]:
+    ) -> tuple[npt.NDArray, npt.NDArray]:
         """
         Accepts icoords and ires and returns appropriate fcoords and fwidth.
         Mostly useful for cases where we have irregularly spaced or structured
diff --git a/yt/geometry/oct_geometry_handler.py b/yt/geometry/oct_geometry_handler.py
index 6abb32e92c..f63c713b0b 100644
--- a/yt/geometry/oct_geometry_handler.py
+++ b/yt/geometry/oct_geometry_handler.py
@@ -1,4 +1,5 @@
 import numpy as np
+import numpy.typing as npt
 
 from yt.fields.field_detector import FieldDetector
 from yt.geometry.geometry_handler import Index
@@ -119,10 +120,10 @@ def _mesh_sampling_particle_field(data):
 
     def _icoords_to_fcoords(
         self,
-        icoords: np.ndarray,
-        ires: np.ndarray,
+        icoords: npt.NDArray,
+        ires: npt.NDArray,
         axes: tuple[int, ...] | None = None,
-    ) -> tuple[np.ndarray, np.ndarray]:
+    ) -> tuple[npt.NDArray, npt.NDArray]:
         """
         Accepts icoords and ires and returns appropriate fcoords and fwidth.
         Mostly useful for cases where we have irregularly spaced or structured
diff --git a/yt/loaders.py b/yt/loaders.py
index 27423a5c7d..908bc9d828 100644
--- a/yt/loaders.py
+++ b/yt/loaders.py
@@ -15,6 +15,7 @@
 from urllib.parse import urlsplit
 
 import numpy as np
+import numpy.typing as npt
 from more_itertools import always_iterable
 
 from yt._maintenance.deprecation import (
@@ -687,7 +688,7 @@ def load_amr_grids(
 
 
 def load_particles(
-    data: Mapping[AnyFieldKey, np.ndarray | tuple[np.ndarray, str]],
+    data: Mapping[AnyFieldKey, npt.NDArray | tuple[npt.NDArray, str]],
     length_unit=None,
     bbox=None,
     sim_time=None,
@@ -826,7 +827,7 @@ def parse_unit(unit, dimension):
     field_units, data, _ = process_data(data)
     sfh = StreamDictFieldHandler()
 
-    pdata: dict[AnyFieldKey, np.ndarray | tuple[np.ndarray, str]] = {}
+    pdata: dict[AnyFieldKey, npt.NDArray | tuple[npt.NDArray, str]] = {}
     for key in data.keys():
         field: FieldKey
         if not isinstance(key, tuple):
@@ -1818,7 +1819,7 @@ def load_hdf5_file(
     fn: Union[str, "os.PathLike[str]"],
     root_node: str | None = "/",
     fields: list[str] | None = None,
-    bbox: np.ndarray | None = None,
+    bbox: npt.NDArray | None = None,
     nchunks: int = 0,
     dataset_arguments: dict | None = None,
 ):
diff --git a/yt/utilities/io_handler.py b/yt/utilities/io_handler.py
index d5aeee1169..dd0adb41f2 100644
--- a/yt/utilities/io_handler.py
+++ b/yt/utilities/io_handler.py
@@ -5,6 +5,7 @@
 from functools import _make_key, lru_cache
 
 import numpy as np
+import numpy.typing as npt
 
 from yt._typing import FieldKey, ParticleCoordinateTuple
 from yt.geometry.selection_routines import GridSelector
@@ -96,7 +97,7 @@ def _read_data(self, grid, field):
 
     def _read_fluid_selection(
         self, chunks, selector, fields: list[FieldKey], size
-    ) -> Mapping[FieldKey, np.ndarray]:
+    ) -> Mapping[FieldKey, npt.NDArray]:
         # This function has an interesting history. It previously was mandate
         # to be defined by all of the subclasses. But, to avoid having to
         # rewrite a whole bunch of IO handlers all at once, and to allow a
@@ -165,8 +166,8 @@ def _read_particle_data_file(self, data_file, ptf, selector=None):
 
     def _read_particle_selection(
         self, chunks, selector, fields: list[FieldKey]
-    ) -> dict[FieldKey, np.ndarray]:
-        data: dict[FieldKey, list[np.ndarray]] = {}
+    ) -> dict[FieldKey, npt.NDArray]:
+        data: dict[FieldKey, list[npt.NDArray]] = {}
 
         # Initialize containers for tracking particle, field information
         # ptf (particle field types) maps particle type to list of on-disk fields to read
@@ -196,7 +197,7 @@ def _read_particle_selection(
             for field_f in field_maps[field_r]:
                 data[field_f].append(vals)
 
-        rv: dict[FieldKey, np.ndarray] = {}  # the return dictionary
+        rv: dict[FieldKey, npt.NDArray] = {}  # the return dictionary
         fields = list(data.keys())
         for field_f in fields:
             # We need to ensure the arrays have the right shape if there are no
diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py
index 46ee88cdbb..e883f7efa3 100644
--- a/yt/visualization/_handlers.py
+++ b/yt/visualization/_handlers.py
@@ -4,6 +4,7 @@
 import matplotlib as mpl
 import numpy as np
+import numpy.typing as npt
 import unyt as un
 from matplotlib.colors import Colormap, LogNorm, Normalize, SymLogNorm
 from unyt import unyt_quantity
@@ -291,7 +292,7 @@ def linthresh(self, newval: Quantity | float | None) -> None:
         if newval is not None:
             self.norm_type = SymLogNorm
 
-    def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize:
+    def get_norm(self, data: npt.NDArray, *args, **kw) -> Normalize:
         if self.norm is not None:
             return self.norm
 
diff --git a/yt/visualization/fixed_resolution.py b/yt/visualization/fixed_resolution.py
index 8edd33c222..2da25532a6 100644
--- a/yt/visualization/fixed_resolution.py
+++ b/yt/visualization/fixed_resolution.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING
 
 import numpy as np
+import numpy.typing as npt
 
 from yt._maintenance.deprecation import issue_deprecation_warning
 from yt._typing import FieldKey, MaskT
@@ -227,7 +228,7 @@ def render(self, item):
         # this method exists for clarity of intention
         return self[item]
 
-    def _apply_filters(self, buffer: np.ndarray) -> np.ndarray:
+    def _apply_filters(self, buffer: npt.NDArray) -> npt.NDArray:
         for f in self._filters:
             buffer = f(buffer)
         return buffer
diff --git a/yt/visualization/fixed_resolution_filters.py b/yt/visualization/fixed_resolution_filters.py
index 25ca0a0250..b747619105 100644
--- a/yt/visualization/fixed_resolution_filters.py
+++ b/yt/visualization/fixed_resolution_filters.py
@@ -2,6 +2,7 @@
 from functools import update_wrapper, wraps
 
 import numpy as np
+import numpy.typing as npt
 
 from yt._maintenance.deprecation import issue_deprecation_warning
 from yt.visualization.fixed_resolution import FixedResolutionBuffer
@@ -60,10 +61,10 @@ def __init__(self, *args, **kwargs):
         pass
 
     @abstractmethod
-    def apply(self, buff: np.ndarray) -> np.ndarray:
+    def apply(self, buff: npt.NDArray) -> npt.NDArray:
         pass
 
-    def __call__(self, buff: np.ndarray) -> np.ndarray:
+    def __call__(self, buff: npt.NDArray) -> npt.NDArray:
         # alias to apply
         return self.apply(buff)
 
diff --git a/yt/visualization/plot_modifications.py b/yt/visualization/plot_modifications.py
index 91d7700cf0..3f9badeac5 100644
--- a/yt/visualization/plot_modifications.py
+++ b/yt/visualization/plot_modifications.py
@@ -9,6 +9,7 @@
 import matplotlib
 import numpy as np
+import numpy.typing as npt
 import rlic
 from unyt import unyt_quantity
 
@@ -1057,7 +1058,7 @@ def __call__(self, plot) -> None:
         else:
             clim = self.clim
 
-        levels: np.ndarray | int
+        levels: npt.NDArray | int
         if clim is not None:
             levels = np.linspace(clim[0], clim[1], self.levels)
         else:
diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py
index 9a1bbc9e8e..44c2cd7c04 100644
--- a/yt/visualization/plot_window.py
+++ b/yt/visualization/plot_window.py
@@ -6,6 +6,7 @@
 
 import matplotlib
 import numpy as np
+import numpy.typing as npt
 from more_itertools import always_iterable
 from unyt.exceptions import UnitConversionError
 
@@ -1393,7 +1394,7 @@ class NormalPlot:
     """
 
     @staticmethod
-    def sanitize_normal_vector(ds, normal) -> str | np.ndarray:
+    def sanitize_normal_vector(ds, normal) -> str | npt.NDArray:
        """Return the name of a cartesian axis whener possible, or a 3-element
        1D ndarray of float64 in any other valid case.
        Fail with a descriptive error message otherwise.
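
Note: the coordinate_handler.py hunks earlier in this diff keep the Literal-based overload pattern for pixelize's return_mask argument, only shortening the return annotations. A standalone sketch of that pattern (make_buffer is an invented example, not a yt function):

    from typing import Literal, overload

    import numpy as np
    import numpy.typing as npt

    @overload
    def make_buffer(n: int, *, return_mask: Literal[False]) -> npt.NDArray[np.float64]: ...
    @overload
    def make_buffer(
        n: int, *, return_mask: Literal[True]
    ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.bool_]]: ...
    def make_buffer(n, *, return_mask=False):
        buff = np.zeros(n, dtype=np.float64)
        if return_mask:
            # The mask reports which pixels were actually filled.
            return buff, np.zeros(n, dtype=np.bool_)
        return buff

    buff = make_buffer(16, return_mask=False)       # checker infers NDArray[float64]
    buff, mask = make_buffer(16, return_mask=True)  # checker infers the tuple form
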