Skip to content

Commit 3ec786c

Browse files
committed
clean up little more
1 parent ee7b6d3 commit 3ec786c

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

xarray/core/dataset.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9088,7 +9088,9 @@ def polyfit(
90889088
"""
90899089
from xarray.core.dataarray import DataArray
90909090

9091-
variables = {}
9091+
# TODO: This can be narrowed to be Variable only if we figure out how to
9092+
# handle the coordinate values for singular values
9093+
variables: dict[Hashable, DataArray | Variable] = {}
90929094
skipna_da = skipna
90939095

90949096
x = np.asarray(_ensure_numeric(self.coords[dim]).astype(np.float64))
@@ -9120,16 +9122,16 @@ def polyfit(
91209122
rank = np.linalg.matrix_rank(lhs)
91219123

91229124
if full:
9123-
rank = DataArray(rank, name=xname + "matrix_rank")
9124-
variables[rank.name] = rank
9125+
rank = Variable(dims=(), data=rank)
9126+
variables[xname + "matrix_rank"] = rank
91259127
_sing = np.linalg.svd(lhs, compute_uv=False)
9128+
# Using a DataArray here because `degree_dim` coordinate values need not
91269129
sing = DataArray(
91279130
_sing,
91289131
dims=(degree_dim,),
91299132
coords={degree_dim: np.arange(rank - 1, -1, -1)},
9130-
name=xname + "singular_values",
91319133
)
9132-
variables[sing.name] = sing
9134+
variables[xname + "singular_values"] = sing
91339135

91349136
# If we have a coordinate get its underlying dimension.
91359137
(true_dim,) = self.coords[dim].dims
@@ -9184,29 +9186,32 @@ def polyfit(
91849186
# Thus a ReprObject => polyfit was called on a DataArray
91859187
name = ""
91869188

9187-
coeffs = Variable(data=coeffs / scale_da, dims=(degree_dim,) + other_dims)
9188-
variables[name + "polyfit_coefficients"] = coeffs
9189+
variables[name + "polyfit_coefficients"] = Variable(
9190+
data=coeffs / scale_da, dims=(degree_dim,) + other_dims
9191+
)
91899192

91909193
if full or (cov is True):
9191-
residuals = Variable(
9194+
variables[name + "polyfit_residuals"] = Variable(
91929195
data=residuals if var.ndim > 1 else residuals.squeeze(),
91939196
dims=other_dims,
91949197
)
9195-
variables[name + "polyfit_residuals"] = residuals
91969198

91979199
if cov:
91989200
Vbase = np.linalg.inv(np.dot(lhs.T, lhs))
91999201
Vbase /= np.outer(scale, scale)
9202+
if TYPE_CHECKING:
9203+
fac: int | DataArray | Variable
92009204
if cov == "unscaled":
92019205
fac = 1
92029206
else:
92039207
if x.shape[0] <= order:
92049208
raise ValueError(
92059209
"The number of data points must exceed order to scale the covariance matrix."
92069210
)
9207-
fac = residuals / (x.shape[0] - order)
9208-
covariance = DataArray(Vbase, dims=("cov_i", "cov_j")) * fac
9209-
variables[name + "polyfit_covariance"] = covariance
9211+
fac = variables[name + "polyfit_residuals"] / (x.shape[0] - order)
9212+
variables[name + "polyfit_covariance"] = (
9213+
Variable(data=Vbase, dims=("cov_i", "cov_j")) * fac
9214+
)
92109215

92119216
return type(self)(
92129217
data_vars=variables,

0 commit comments

Comments
 (0)