Skip to content

Commit 85df546

Browse files
authored
feat: write estimators to json (#44)
* feat: write estimators to json * style: pre-the-commit * refactor: ensure we do not need a subclass * fix: bugs in from_json * fix: need to use X after init * refactor: make things more automatic * feat: use an actual method * fix: private functions
1 parent 167a136 commit 85df546

File tree

4 files changed

+443
-32
lines changed

4 files changed

+443
-32
lines changed

mattspy/json.py

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
"""Code for numpy arrays from json-numpy under MIT
2+
3+
MIT License
4+
5+
Copyright (c) 2021-2025 Crimson-Crow <[email protected]>
6+
7+
Permission is hereby granted, free of charge, to any person obtaining a copy
8+
of this software and associated documentation files (the "Software"), to deal
9+
in the Software without restriction, including without limitation the rights
10+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
copies of the Software, and to permit persons to whom the Software is
12+
furnished to do so, subject to the following conditions:
13+
14+
The above copyright notice and this permission notice shall be included in all
15+
copies or substantial portions of the Software.
16+
17+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23+
SOFTWARE.
24+
"""
25+
26+
import os
27+
import json
28+
from base64 import b64decode, b64encode
29+
30+
from numpy import frombuffer, generic, ndarray
31+
from numpy.lib.format import descr_to_dtype, dtype_to_descr
32+
33+
34+
def _hint_tuples(item):
35+
"""See https://stackoverflow.com/a/15721641/1745538"""
36+
if isinstance(item, tuple):
37+
return {"__tuple__": [_hint_tuples(e) for e in item]}
38+
if isinstance(item, list):
39+
return [_hint_tuples(e) for e in item]
40+
if isinstance(item, dict):
41+
return {key: _hint_tuples(value) for key, value in item.items()}
42+
return item
43+
44+
45+
def _dehint_tuples(item):
46+
"""See https://stackoverflow.com/a/15721641/1745538"""
47+
if isinstance(item, tuple):
48+
return tuple([_dehint_tuples(e) for e in item])
49+
if isinstance(item, list):
50+
return [_dehint_tuples(e) for e in item]
51+
if isinstance(item, dict) and "__tuple__" in item:
52+
return tuple([_dehint_tuples(e) for e in item["__tuple__"]])
53+
return item
54+
55+
56+
class _CustomEncoder(json.JSONEncoder):
57+
"""
58+
See https://stackoverflow.com/a/15721641/1745538
59+
"""
60+
61+
def encode(self, obj):
62+
return super().encode(_hint_tuples(obj))
63+
64+
def default(self, o):
65+
from jax import dtypes
66+
import jax.random as jrng
67+
import jax.numpy as jnp
68+
import numpy as np
69+
70+
if isinstance(o, jnp.ndarray) and dtypes.issubdtype(o.dtype, dtypes.prng_key):
71+
o = jrng.key_data(o)
72+
o = np.array(o)
73+
data = o.data if o.flags["C_CONTIGUOUS"] else o.tobytes()
74+
return {
75+
"__jax_rng_key__": b64encode(data).decode(),
76+
"dtype": dtype_to_descr(o.dtype),
77+
"shape": _hint_tuples(o.shape),
78+
}
79+
80+
if isinstance(o, jnp.ndarray):
81+
o = np.array(o)
82+
data = o.data if o.flags["C_CONTIGUOUS"] else o.tobytes()
83+
return {
84+
"__jax__": b64encode(data).decode(),
85+
"dtype": dtype_to_descr(o.dtype),
86+
"shape": _hint_tuples(o.shape),
87+
}
88+
89+
if isinstance(o, (ndarray, generic)):
90+
data = o.data if o.flags["C_CONTIGUOUS"] else o.tobytes()
91+
return {
92+
"__numpy__": b64encode(data).decode(),
93+
"dtype": dtype_to_descr(o.dtype),
94+
"shape": _hint_tuples(o.shape),
95+
}
96+
97+
if isinstance(o, np.random.RandomState):
98+
return {"__numpy_random_state__": _hint_tuples(o.get_state())}
99+
100+
if isinstance(o, np.random.Generator):
101+
return {"__numpy_random_generator__": _hint_tuples(o.bit_generator.state)}
102+
103+
raise TypeError(
104+
f"Object of type {o.__class__.__name__} is not JSON serializable"
105+
)
106+
107+
108+
def _object_hook(dct):
109+
import jax.random as jrng
110+
import jax.numpy as jnp
111+
import numpy as np
112+
113+
if "__jax_rng_key__" in dct:
114+
np_obj = frombuffer(
115+
b64decode(dct["__jax_rng_key__"]), descr_to_dtype(dct["dtype"])
116+
)
117+
arr = (
118+
np_obj.reshape(shape)
119+
if (shape := _dehint_tuples(dct["shape"]))
120+
else np_obj[0]
121+
)
122+
key = jnp.array(arr)
123+
return jrng.wrap_key_data(key)
124+
125+
if "__jax__" in dct:
126+
np_obj = frombuffer(b64decode(dct["__jax__"]), descr_to_dtype(dct["dtype"]))
127+
arr = (
128+
np_obj.reshape(shape)
129+
if (shape := _dehint_tuples(dct["shape"]))
130+
else np_obj[0]
131+
)
132+
return jnp.array(arr)
133+
134+
if "__numpy__" in dct:
135+
np_obj = frombuffer(b64decode(dct["__numpy__"]), descr_to_dtype(dct["dtype"]))
136+
return (
137+
np_obj.reshape(shape)
138+
if (shape := _dehint_tuples(dct["shape"]))
139+
else np_obj[0]
140+
)
141+
142+
if "__tuple__" in dct:
143+
return _dehint_tuples(dct)
144+
145+
if "__numpy_random_state__" in dct:
146+
rng = np.random.RandomState()
147+
rng.set_state(_dehint_tuples(dct["__numpy_random_state__"]))
148+
return rng
149+
150+
if "__numpy_random_generator__" in dct:
151+
data = _dehint_tuples(dct["__numpy_random_generator__"])
152+
bg = getattr(np.random, data["bit_generator"])()
153+
bg.state = data
154+
return np.random.Generator(bg)
155+
156+
return dct
157+
158+
159+
def dump(*args, **kwargs):
160+
return json.dump(*args, cls=_CustomEncoder, **kwargs)
161+
162+
163+
def dumps(*args, **kwargs):
164+
return json.dumps(*args, cls=_CustomEncoder, **kwargs)
165+
166+
167+
def load(*args, **kwargs):
168+
return json.load(*args, object_hook=_object_hook, **kwargs)
169+
170+
171+
def loads(*args, **kwargs):
172+
return json.loads(*args, object_hook=_object_hook, **kwargs)
173+
174+
175+
class EstimatorToFromJSONMixin:
176+
def _init_from_json(self, **data):
177+
for k, v in data.items():
178+
setattr(self, k, v)
179+
180+
def to_json(self, out=None):
181+
"""Serialize this estimator to JSON.
182+
183+
Parameters
184+
----------
185+
out : file-like object, string, or None, optional
186+
If a file-like object or a string, the data is written
187+
using the `write` method, creating / overwriting a file
188+
if a string is given. If None, then only the JSON string
189+
is returned.
190+
191+
Returns
192+
-------
193+
data : str
194+
The JSON-serialized data as a string.
195+
"""
196+
data = {}
197+
for attr in set(self.json_attributes_) | set(self.get_params().keys()):
198+
if hasattr(self, attr):
199+
data[attr] = getattr(self, attr)
200+
data = dumps(data)
201+
202+
if out is None:
203+
pass
204+
elif hasattr(out, "write"):
205+
out.write(data)
206+
else:
207+
with open(out, "w") as fp:
208+
fp.write(data)
209+
210+
return data
211+
212+
@classmethod
213+
def from_json(cls, data):
214+
"""Load an estimator from JSON data.
215+
216+
Parameters
217+
----------
218+
data : str or file-like
219+
The JSON data.
220+
221+
Returns
222+
-------
223+
estimator
224+
"""
225+
if hasattr(data, "read"):
226+
data = load(data)
227+
else:
228+
if os.path.exists(data):
229+
with open(str, "r") as fp:
230+
data = loads(fp.read())
231+
else:
232+
data = loads(data)
233+
234+
obj = cls()
235+
params = {k: data[k] for k in obj.get_params() if k in data}
236+
obj.set_params(**params)
237+
for k in obj.get_params():
238+
if k in data:
239+
del data[k]
240+
241+
obj._init_from_json(**data)
242+
243+
return obj

0 commit comments

Comments
 (0)