Skip to content

Commit de9cd5d

Browse files
authored
update indexing in quickstart (#115)
followup on #114, fix some things and update the quickstart to demo the new indexing syntax this allows
1 parent a480e96 commit de9cd5d

File tree

7 files changed

+235
-204
lines changed

7 files changed

+235
-204
lines changed

docs/examples/quickstart.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,12 @@
2828

2929
# check CHD
3030
assert chd.data["head"][0, 0].item() == 1.0
31-
assert chd.data["head"][0, 99].item() == 0.0
32-
assert np.allclose(chd.data["head"][:, 1:99].data.todense(), np.full(98, 1e30))
31+
assert chd.data.head.sel(per=0)[99].item() == 0.0
32+
assert np.allclose(chd.data.head[:, 1:99], np.full(98, 1e30))
3333

34-
# TODO: xarray index aliasing nlay/ncol/nrow to k/i/j?
35-
# assert chd.data["head"].loc(dict(k=0, i=0, j=0)) == 1.
36-
# assert chd.data["head"].loc(dict(k=0, i=9, j=9)) == 0.
34+
# check DIS
35+
assert dis.data.botm.sel(lay=0, col=0, row=0) == 0.0
3736

3837
# check OC
39-
assert oc.data["save_head"][0].item() == "all"
40-
assert oc.data["save_budget"][0].item() == "all"
38+
assert oc.data["save_head"][0] == "all"
39+
assert oc.data.save_head.sel(per=0) == "all"

flopy4/mf6/gwf/dis.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ class Dis(Package):
2222
default=False, metadata={"block": "options"}
2323
)
2424
nlay: int = dim(
25+
# disable the otherwise automatic coordinate variable
26+
# because we're going to create another one for this
27+
# dimension with a different name via a custom index
2528
coord=False,
2629
scope="gwf",
2730
default=1,
@@ -46,36 +49,40 @@ class Dis(Package):
4649
},
4750
)
4851
delr: NDArray[np.floating] = array(
49-
dims=("ncol",),
5052
default=1.0,
53+
dims=("ncol",),
5154
metadata={"block": "griddata"},
5255
converter=Converter(convert_array, takes_self=True, takes_field=True),
5356
)
5457
delc: NDArray[np.floating] = array(
55-
dims=("nrow",),
5658
default=1.0,
59+
dims=("nrow",),
5760
metadata={"block": "griddata"},
5861
converter=Converter(convert_array, takes_self=True, takes_field=True),
5962
)
6063
top: NDArray[np.floating] = array(
61-
dims=("ncol", "nrow"),
6264
default=1.0,
65+
dims=("ncol", "nrow"),
6366
metadata={"block": "griddata"},
6467
converter=Converter(convert_array, takes_self=True, takes_field=True),
6568
)
6669
botm: NDArray[np.floating] = array(
67-
dims=("ncol", "nrow", "nlay"),
6870
default=0.0,
71+
dims=("ncol", "nrow", "nlay"),
6972
metadata={"block": "griddata"},
7073
converter=Converter(convert_array, takes_self=True, takes_field=True),
7174
)
7275
idomain: NDArray[np.integer] = array(
73-
dims=("ncol", "nrow", "nlay"),
7476
default=1,
77+
dims=("ncol", "nrow", "nlay"),
7578
metadata={"block": "griddata"},
7679
converter=Converter(convert_array, takes_self=True, takes_field=True),
7780
)
78-
nnodes: int = dim(scope="gwf", init=False)
81+
nnodes: int = dim(
82+
# coord=False,
83+
scope="gwf",
84+
init=False,
85+
)
7986

8087
def __attrs_post_init__(self):
8188
self.nnodes = self.ncol * self.nrow * self.nlay

flopy4/mf6/gwf/oc.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,31 @@ class Period:
5353
format: Optional[Format] = field(
5454
default=None, init=False, metadata={"block": "options"}
5555
)
56-
save_head: Optional[NDArray[np.str_ | np.integer]] = array(
57-
dims=("nper",),
56+
save_head: Optional[NDArray[np.object_]] = array(
57+
Steps,
5858
default="all",
59+
dims=("nper",),
5960
metadata={"block": "perioddata"},
6061
converter=Converter(convert_array, takes_self=True, takes_field=True),
6162
)
62-
save_budget: Optional[NDArray[np.str_ | np.integer]] = array(
63-
dims=("nper",),
63+
save_budget: Optional[NDArray[np.object_]] = array(
64+
Steps,
6465
default="all",
66+
dims=("nper",),
6567
metadata={"block": "perioddata"},
6668
converter=Converter(convert_array, takes_self=True, takes_field=True),
6769
)
68-
print_head: Optional[NDArray[np.str_ | np.integer]] = array(
69-
dims=("nper",),
70+
print_head: Optional[NDArray[np.object_]] = array(
71+
Steps,
7072
default="all",
73+
dims=("nper",),
7174
metadata={"block": "perioddata"},
7275
converter=Converter(convert_array, takes_self=True, takes_field=True),
7376
)
74-
print_budget: Optional[NDArray[np.str_ | np.integer]] = array(
75-
dims=("nper",),
77+
print_budget: Optional[NDArray[np.object_]] = array(
78+
Steps,
7679
default="all",
80+
dims=("nper",),
7781
metadata={"block": "perioddata"},
7882
converter=Converter(convert_array, takes_self=True, takes_field=True),
7983
)

flopy4/mf6/indexes.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,23 @@
55

66

77
def alias(dataset: xr.Dataset, old_name: str, new_name: str) -> PandasIndex:
8-
"""Alias a dimension coordinate to a coordinate with a different name."""
8+
"""
9+
Alias a dimension coordinate to a coordinate with a different name.
10+
Suggested by Benoit Bovy https://github.com/pydata/xarray/pull/10076#issuecomment-2809041994.
11+
"""
912
try:
1013
size = dataset.sizes[old_name]
1114
except KeyError:
12-
try:
13-
size = dataset.dims[old_name]
14-
except KeyError:
15-
size = dataset.attrs[old_name]
15+
size = dataset.attrs[old_name]
1616
return PandasIndex(pd.RangeIndex(size, name=new_name), dim=old_name)
1717

1818

19-
class GridIndex(Index):
19+
class MetaIndex(Index):
20+
"""
21+
Combine multiple indexes into a single index.
22+
Adapted from https://docs.xarray.dev/en/stable/internals/how-to-create-custom-index.html#meta-indexes.
23+
"""
24+
2025
def __init__(self, indices):
2126
self._indices = indices
2227

@@ -39,13 +44,29 @@ def sel(self, labels):
3944
results.append(index.sel({k: labels[k]}))
4045
return merge_sel_results(results)
4146

47+
def to_pandas_index(self) -> pd.Index:
48+
# from https://github.com/corteva/rioxarray/pull/846/files#diff-917105823f61e63ef4afde8bed408a6c249e375690e56bc800406676f02551d8R418
49+
if len(self._indices) == 1:
50+
index = next(iter(self._indices.values()))
51+
if isinstance(index, PandasIndex):
52+
return index.to_pandas_index()
53+
54+
raise ValueError("Cannot convert MetaIndex to pandas.Index")
55+
4256

43-
def grid_index(dataset: xr.Dataset) -> GridIndex:
44-
return GridIndex(
57+
def grid_index(dataset: xr.Dataset) -> MetaIndex:
58+
return MetaIndex(
4559
{
46-
# k collides with npf.k so use "lay"
60+
# TODO add 'per' (stress period)
4761
"lay": alias(dataset, "nlay", "lay"),
4862
"col": alias(dataset, "ncol", "col"),
4963
"row": alias(dataset, "nrow", "row"),
64+
# "node": alias(dataset, "nnodes", "node"),
65+
# TODO: adding node breaks the other three.
66+
# and just having node by itself works. why?
5067
}
5168
)
69+
70+
71+
def time_index(dataset: xr.Dataset) -> PandasIndex:
72+
return alias(dataset, "nper", "per")

flopy4/mf6/tdis.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from xattree import ROOT, array, dim, field, xattree
88

99
from flopy4.mf6.converters import convert_array
10+
from flopy4.mf6.indexes import time_index
1011
from flopy4.mf6.package import Package
1112

1213

13-
@xattree
14+
@xattree(index=time_index, index_scope=ROOT)
1415
class Tdis(Package):
1516
@define
1617
class PeriodData:
@@ -19,6 +20,7 @@ class PeriodData:
1920
tsmult: float
2021

2122
nper: int = dim(
23+
coord=False,
2224
default=1,
2325
scope=ROOT,
2426
metadata={"block": "dimensions"},

0 commit comments

Comments
 (0)