diff --git a/mattspy/json.py b/mattspy/json.py
new file mode 100644
index 0000000..79dd940
--- /dev/null
+++ b/mattspy/json.py
@@ -0,0 +1,243 @@
+"""JSON serialization for numpy arrays, adapted from json-numpy under MIT.
+
+MIT License
+
+Copyright (c) 2021-2025 Crimson-Crow
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+import os
+import json
+from base64 import b64decode, b64encode
+
+from numpy import frombuffer, generic, ndarray
+from numpy.lib.format import descr_to_dtype, dtype_to_descr
+
+
+def _hint_tuples(item):
+    """See https://stackoverflow.com/a/15721641/1745538"""
+    if isinstance(item, tuple):
+        return {"__tuple__": [_hint_tuples(e) for e in item]}
+    if isinstance(item, list):
+        return [_hint_tuples(e) for e in item]
+    if isinstance(item, dict):
+        return {key: _hint_tuples(value) for key, value in item.items()}
+    return item
+
+
+def _dehint_tuples(item):
+    """See https://stackoverflow.com/a/15721641/1745538"""
+    if isinstance(item, tuple):
+        return tuple([_dehint_tuples(e) for e in item])
+    if isinstance(item, list):
+        return [_dehint_tuples(e) for e in item]
+    if isinstance(item, dict) and "__tuple__" in item:
+        return tuple([_dehint_tuples(e) for e in item["__tuple__"]])
+    return item
+
+
+class _CustomEncoder(json.JSONEncoder):
+    """
+    See https://stackoverflow.com/a/15721641/1745538
+    """
+
+    def encode(self, obj):
+        return super().encode(_hint_tuples(obj))
+
+    def default(self, o):
+        from jax import dtypes
+        import jax.random as jrng
+        import jax.numpy as jnp
+        import numpy as np
+
+        if isinstance(o, jnp.ndarray) and dtypes.issubdtype(o.dtype, dtypes.prng_key):
+            o = jrng.key_data(o)
+            o = np.array(o)
+            data = o.data if o.flags["C_CONTIGUOUS"] else o.tobytes()
+            return {
+                "__jax_rng_key__": b64encode(data).decode(),
+                "dtype": dtype_to_descr(o.dtype),
+                "shape": _hint_tuples(o.shape),
+            }
+
+        if isinstance(o, jnp.ndarray):
+            o = np.array(o)
+            data = o.data if o.flags["C_CONTIGUOUS"] else o.tobytes()
+            return {
+                "__jax__": b64encode(data).decode(),
+                "dtype": dtype_to_descr(o.dtype),
+                "shape": _hint_tuples(o.shape),
+            }
+
+        if isinstance(o, (ndarray, generic)):
+            data = o.data if o.flags["C_CONTIGUOUS"] else o.tobytes()
+            return {
+                "__numpy__": b64encode(data).decode(),
+                "dtype": dtype_to_descr(o.dtype),
+                "shape": _hint_tuples(o.shape),
+            }
+
+        if isinstance(o, np.random.RandomState):
+            return {"__numpy_random_state__": _hint_tuples(o.get_state())}
+
+        if isinstance(o, np.random.Generator):
+            return {"__numpy_random_generator__": _hint_tuples(o.bit_generator.state)}
+
+        raise TypeError(
+            f"Object of type {o.__class__.__name__} is not JSON serializable"
+        )
+
+
+def _object_hook(dct):
+    import jax.random as jrng
+    import jax.numpy as jnp
+    import numpy as np
+
+    if "__jax_rng_key__" in dct:
+        np_obj = frombuffer(
+            b64decode(dct["__jax_rng_key__"]), descr_to_dtype(dct["dtype"])
+        )
+        arr = (
+            np_obj.reshape(shape)
+            if (shape := _dehint_tuples(dct["shape"]))
+            else np_obj[0]
+        )
+        key = jnp.array(arr)
+        return jrng.wrap_key_data(key)
+
+    if "__jax__" in dct:
+        np_obj = frombuffer(b64decode(dct["__jax__"]), descr_to_dtype(dct["dtype"]))
+        arr = (
+            np_obj.reshape(shape)
+            if (shape := _dehint_tuples(dct["shape"]))
+            else np_obj[0]
+        )
+        return jnp.array(arr)
+
+    if "__numpy__" in dct:
+        np_obj = frombuffer(b64decode(dct["__numpy__"]), descr_to_dtype(dct["dtype"]))
+        return (
+            np_obj.reshape(shape)
+            if (shape := _dehint_tuples(dct["shape"]))
+            else np_obj[0]
+        )
+
+    if "__tuple__" in dct:
+        return _dehint_tuples(dct)
+
+    if "__numpy_random_state__" in dct:
+        rng = np.random.RandomState()
+        rng.set_state(_dehint_tuples(dct["__numpy_random_state__"]))
+        return rng
+
+    if "__numpy_random_generator__" in dct:
+        data = _dehint_tuples(dct["__numpy_random_generator__"])
+        bg = getattr(np.random, data["bit_generator"])()
+        bg.state = data
+        return np.random.Generator(bg)
+
+    return dct
+
+
+def dump(*args, **kwargs):
+    return json.dump(*args, cls=_CustomEncoder, **kwargs)
+
+
+def dumps(*args, **kwargs):
+    return json.dumps(*args, cls=_CustomEncoder, **kwargs)
+
+
+def load(*args, **kwargs):
+    return json.load(*args, object_hook=_object_hook, **kwargs)
+
+
+def loads(*args, **kwargs):
+    return json.loads(*args, object_hook=_object_hook, **kwargs)
+
+
+class EstimatorToFromJSONMixin:
+    def _init_from_json(self, **data):
+        for k, v in data.items():
+            setattr(self, k, v)
+
+    def to_json(self, out=None):
+        """Serialize this estimator to JSON.
+
+        Parameters
+        ----------
+        out : file-like object, string, or None, optional
+            If a file-like object, the data is written via its `write`
+            method. If a string, it is treated as a path and the file
+            is created or overwritten. If None, the JSON string is
+            only returned.
+
+        Returns
+        -------
+        data : str
+            The JSON-serialized data as a string.
+        """
+        data = {}
+        for attr in set(self.json_attributes_) | set(self.get_params().keys()):
+            if hasattr(self, attr):
+                data[attr] = getattr(self, attr)
+        data = dumps(data)
+
+        if out is None:
+            pass
+        elif hasattr(out, "write"):
+            out.write(data)
+        else:
+            with open(out, "w") as fp:
+                fp.write(data)
+
+        return data
+
+    @classmethod
+    def from_json(cls, data):
+        """Load an estimator from JSON data.
+
+        Parameters
+        ----------
+        data : str or file-like
+            The JSON data as a string, a path to a JSON file, or an
+            open file-like object.
+
+        Returns
+        -------
+        estimator
+            The deserialized estimator instance.
+        """
+        if hasattr(data, "read"):
+            data = load(data)
+        else:
+            if os.path.exists(data):
+                with open(data, "r") as fp:
+                    data = loads(fp.read())
+            else:
+                data = loads(data)
+
+        obj = cls()
+        params = {k: data[k] for k in obj.get_params() if k in data}
+        obj.set_params(**params)
+        for k in obj.get_params():
+            if k in data:
+                del data[k]
+
+        obj._init_from_json(**data)
+
+        return obj
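A quick usage sketch of the round trip this file implements (illustration only, not part of the patch). `dumps`/`loads` tag each supported object with a sentinel key (`__numpy__`, `__jax__`, `__tuple__`, `__numpy_random_state__`, ...) and the `object_hook` reverses the encoding. Note that `_CustomEncoder.default` imports jax unconditionally, so jax must be installed even for pure-numpy payloads.

```python
import numpy as np
from mattspy.json import dumps, loads

payload = {
    "weights": np.arange(6, dtype=np.float64).reshape(2, 3),
    "shape_hint": (2, 3),  # tuples survive via the "__tuple__" hint
    "rng": np.random.default_rng(seed=42),
}

blob = dumps(payload)  # plain JSON text; array data is base64-encoded
restored = loads(blob)

assert np.array_equal(payload["weights"], restored["weights"])
assert restored["shape_hint"] == (2, 3)
assert payload["rng"].normal() == restored["rng"].normal()
```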
""" + json_attributes_ = ( + "_is_fit", + "_rng", + "_jax_rng_key", + "n_seen_", + "n_weight_grid_", + "n_iter_", + "converged_", + "weights_", + "weight_positions_", + ) + def __init__( self, n_clusters=16, @@ -179,33 +197,39 @@ def partial_fit(self, X, y=None): """ return self._partial_fit(1, X) - def _init_numpy(self, X): - X = validate_data(self, X=X, reset=True) - return X + def _init_from_json(self, X=None, **kwargs): + if X is None and "weights_" in kwargs: + X = np.ones((1, kwargs["weights_"].shape[1])) - def _init_jax(self, X): - self.n_features_in_ = X.shape[1] - return X + self.n_seen_ = kwargs.get( + "n_seen_", + jnp.zeros(self.n_clusters), + ) + self.n_weight_grid_ = kwargs.get( + "n_weight_grid_", + int(np.ceil(np.sqrt(self.n_clusters))), + ) + self.n_iter_ = kwargs.get("n_iter_", 0) - def _partial_fit(self, n_epochs, X, y=None): - if not getattr(self, "_is_fit", False): - self.n_seen_ = jnp.zeros(self.n_clusters) - self.n_weight_grid_ = int(jnp.ceil(jnp.sqrt(self.n_clusters))) - self.n_iter_ = 0 + self._rng = kwargs.get("_rng", check_random_state(self.random_state)) - # rng init - self._rng = check_random_state(self.random_state) + if "_jax_rng_key" in kwargs: + self._jax_rng_key = kwargs["_jax_rng_key"] + else: self._jax_rng_key = jax.random.key( self._rng.randint(low=1, high=int(2**31)) ) - # check inputs and convert to JAX - if not isinstance(X, jnp.ndarray): - X = self._init_numpy(X) - X = jnp.array(X) - else: - X = self._init_jax(X) + # check inputs and convert to JAX + if not isinstance(X, jnp.ndarray): + X = validate_data(self, X=X, reset=True) + X = jnp.array(X) + else: + validate_data(self, X=np.ones((1, X.shape[1])), reset=True) + if "weights_" in kwargs: + self.weights_ = jnp.array(kwargs["weights_"]) + else: # weight init self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key) if X.shape[1] == 1: @@ -241,6 +265,9 @@ def _partial_fit(self, n_epochs, X, y=None): + eigscale[:, 1:2] * sqrt_eval2 * evec2 ) + if "weight_positions_" in kwargs: + self.weight_positions_ = jnp.array(kwargs["weight_positions_"]) + else: # weight position init pos = jnp.linspace(0, 1, self.n_weight_grid_) xp, yp = jnp.meshgrid(pos, pos) @@ -256,6 +283,15 @@ def _partial_fit(self, n_epochs, X, y=None): shape=(self.n_clusters,), ) self.weight_positions_ = self.weight_positions_[rind, :] + + self.converged_ = kwargs.get("converged_", False) + self._is_fit = kwargs.get("_is_fit", True) + + return X + + def _partial_fit(self, n_epochs, X, y=None): + if not getattr(self, "_is_fit", False): + X = self._init_from_json(X) else: if not isinstance(X, jnp.ndarray): X = validate_data(self, X=X, reset=False) @@ -267,7 +303,7 @@ def _partial_fit(self, n_epochs, X, y=None): opt_state = optimizer.init(self.weights_) dw = 1.0 / self.n_weight_grid_ - sigma_frac_dw = jnp.maximum(1.0, self.sigma_frac / dw) + sigma_frac_dw = np.maximum(1.0, self.sigma_frac / dw) converged = False for _ in range(n_epochs): @@ -276,8 +312,8 @@ def _partial_fit(self, n_epochs, X, y=None): self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key) inds = jax.random.permutation(subkey, X.shape[0]) - _sigma_frac = dw * jnp.power( - sigma_frac_dw, jnp.maximum(1.0 - self.n_iter_ / self.max_iter, 0.0) + _sigma_frac = dw * np.power( + sigma_frac_dw, np.maximum(1.0 - self.n_iter_ / self.max_iter, 0.0) ) for start in range(0, X.shape[0], self.batch_size): @@ -314,7 +350,6 @@ def _partial_fit(self, n_epochs, X, y=None): break self.converged_ = converged - self._is_fit = True self.labels_ = _jax_predict_som(self.weights_, X) 
diff --git a/mattspy/som/tests/test_jax_impl.py b/mattspy/som/tests/test_jax_impl.py
index 7f4edc8..93b82f2 100644
--- a/mattspy/som/tests/test_jax_impl.py
+++ b/mattspy/som/tests/test_jax_impl.py
@@ -79,10 +79,42 @@ def test_som_random_state_handling(with_jax, clst):
     assert np.allclose(labels, labels_again)
 
 
-def _apply_label_mapping(y, labels, n_clusters):
-    from scipy.stats import mode
+def test_som_to_from_json_fit(clst):
+    X, y = load_iris(return_X_y=True)
+    clst.fit(X)
+    labels = clst.predict(X)
+    est_json = clst.to_json()
+    ml = _mode_label(y, clst.labels_, clst.n_clusters)
+    assert np.array_equal(np.sort(ml), np.arange(clst.n_clusters))
 
-    vals = {}
-    for k in range(n_clusters):
-        vals[k] = mode(y[labels == k])[0]
-    return jnp.array([vals[int(ll)] for ll in labels])
+    print(est_json)
+
+    new_clst = SOMap.from_json(est_json)
+    assert est_json == new_clst.to_json()
+    assert jnp.array_equal(clst.weights_, new_clst.weights_)
+    new_labels = new_clst.predict(X)
+    assert jnp.allclose(labels, new_labels)
+
+    new_clst.fit(X)
+    assert jnp.array_equal(clst.weights_, new_clst.weights_)
+    new_fit_labels = new_clst.predict(X)
+    assert jnp.allclose(labels, new_fit_labels)
+
+
+def test_som_to_from_json_partial_fit(clst):
+    X, y = load_iris(return_X_y=True)
+    clst.partial_fit(X)
+    new_clst = SOMap.from_json(clst.to_json())
+    for _ in range(399):
+        clst.partial_fit(X)
+        new_clst.partial_fit(X)
+
+    new_labels = new_clst.predict(X)
+    labels = clst.predict(X)
+    assert jnp.allclose(labels, new_labels)
+
+    ml = _mode_label(y, labels, clst.n_clusters)
+    assert np.array_equal(np.sort(ml), np.arange(clst.n_clusters))
+
+    assert jnp.array_equal(clst.weights_, new_clst.weights_)
+    assert jnp.array_equal(clst.n_features_in_, new_clst.n_features_in_)
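For reference when reading the new serialization tests, this is the wire format a single array serializes to (expected output shown for a little-endian machine):

```python
import numpy as np
from mattspy.json import dumps

# a tagged dict: base64 payload, numpy descr string, and a hinted shape tuple
print(dumps(np.arange(3, dtype=np.int16)))
# {"__numpy__": "AAABAAIA", "dtype": "<i2", "shape": {"__tuple__": [3]}}
```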
diff --git a/mattspy/tests/test_json.py b/mattspy/tests/test_json.py
new file mode 100644
index 0000000..606c6b0
--- /dev/null
+++ b/mattspy/tests/test_json.py
@@ -0,0 +1,101 @@
+import numpy as np
+import jax.numpy as jnp
+import jax.random as jrng
+
+import pytest
+
+from mattspy.json import dumps, loads
+
+
+@pytest.mark.parametrize(
+    "val",
+    [
+        np.array(10.0),
+        np.array(10.0, dtype=int),
+        np.array(10.0, dtype=float),
+        np.array(10.0, dtype=np.float32),
+        np.array(10.0, dtype=np.int32),
+        np.array(10.0, dtype=np.float64),
+        np.array(10.0, dtype=np.int64),
+        np.array("342524", dtype="U"),
+        np.array("dsfcsda", dtype="S"),
+        np.arange(10, dtype=int),
+        np.arange(10, dtype=np.int32),
+        np.arange(10, dtype=np.int64),
+        np.arange(10, dtype=np.float64),
+        np.arange(10, dtype=np.float32),
+        np.array([10, np.nan, np.inf], dtype=np.float32),
+        np.array([10, np.nan, np.inf], dtype=np.float64),
+        np.arange(10, dtype=np.complex64),
+        np.arange(10, dtype=np.complex128),
+        np.arange(10, dtype=float),
+        np.array(["%s" % i for i in range(10)], dtype="U"),
+        np.array(["%s" % i for i in range(10)], dtype="S"),
+    ],
+)
+def test_json_numpy(val):
+    sval = loads(dumps([val]))[0]
+    if "U" not in val.dtype.descr[0][1] and "S" not in val.dtype.descr[0][1]:
+        assert np.array_equal(val, sval, equal_nan=True)
+    else:
+        assert np.array_equal(val, sval)
+    assert val.shape == sval.shape
+    assert val.dtype == sval.dtype
+
+
+@pytest.mark.parametrize(
+    "val",
+    [
+        jnp.array(10.0),
+        jnp.array(10.0, dtype=int),
+        jnp.array(10.0, dtype=float),
+        jnp.array(10.0, dtype=jnp.float32),
+        jnp.array(10.0, dtype=jnp.int32),
+        jnp.array(10.0, dtype=jnp.float64),
+        jnp.array(10.0, dtype=jnp.int64),
+        jnp.arange(10, dtype=int),
+        jnp.arange(10, dtype=jnp.int32),
+        jnp.arange(10, dtype=jnp.int64),
+        jnp.arange(10, dtype=jnp.float64),
+        jnp.arange(10, dtype=jnp.float32),
+        jnp.array([10, jnp.nan, jnp.inf], dtype=jnp.float32),
+        jnp.array([10, jnp.nan, jnp.inf], dtype=jnp.float64),
+        jnp.arange(10, dtype=jnp.complex64),
+        jnp.arange(10, dtype=jnp.complex128),
+        jnp.arange(10, dtype=float),
+    ],
+)
+def test_json_jax(val):
+    sval = loads(dumps([val]))[0]
+    if "U" not in val.dtype.descr[0][1] and "S" not in val.dtype.descr[0][1]:
+        assert jnp.array_equal(val, sval, equal_nan=True)
+    else:
+        assert jnp.array_equal(val, sval)
+    assert val.shape == sval.shape
+    assert val.dtype == sval.dtype
+
+
+def test_json_numpy_random():
+    rng = np.random.RandomState(seed=10)
+    srng = loads(dumps([rng]))[0]
+    assert rng.normal() == srng.normal()
+
+    rng = np.random.default_rng(seed=10)
+    srng = loads(dumps([rng]))[0]
+    assert rng.normal() == srng.normal()
+
+    rng = np.random.Generator(np.random.MT19937(seed=10))
+    srng = loads(dumps([rng]))[0]
+    assert rng.normal() == srng.normal()
+
+
+def test_json_tuple():
+    val = (10, (10, 4, 5.0, {"a": 10}))
+    sval = loads(dumps(val))
+    assert val == sval
+
+
+def test_json_jax_random():
+    key = jrng.key(seed=100)
+    skey = loads(dumps([key]))[0]
+    assert jrng.normal(key) == jrng.normal(skey)
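The tests above cover in-memory round trips; the file-based path (which depends on the `open(data, "r")` fix in `from_json`) would look like this sketch, with `som.json` a hypothetical filename:

```python
import numpy as np
from mattspy.som._jax_impl import SOMap

X = np.random.RandomState(seed=7).normal(size=(200, 3))
som = SOMap(n_clusters=9)
som.fit(X)

som.to_json(out="som.json")         # string `out` is treated as a path
som2 = SOMap.from_json("som.json")  # existing paths are read and parsed

assert np.array_equal(np.asarray(som.predict(X)), np.asarray(som2.predict(X)))
```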