
Commit be2f10f

feat: fm to json
1 parent 85df546 commit be2f10f


4 files changed: +232 -99 lines changed


mattspy/fm/_jax_impl.py

Lines changed: 129 additions & 76 deletions
@@ -3,6 +3,7 @@
 from jax import numpy as jnp
 from jax.tree_util import Partial as partial
 
+import numpy as np
 from optax.losses import softmax_cross_entropy_with_integer_labels
 from sklearn.base import ClassifierMixin, BaseEstimator
 from sklearn.utils import check_random_state
@@ -11,6 +12,8 @@
 from sklearn.exceptions import NotFittedError
 from sklearn.preprocessing import LabelEncoder
 
+from mattspy.json import EstimatorToFromJSONMixin
+
 
 @jax.jit
 def _lowrank_twoway_term(x, vmat):
@@ -121,7 +124,11 @@ def _call_in_batches_maybe(self, func, X):
         return func(self.params_, X)
 
 
-class FMClassifier(ClassifierMixin, BaseEstimator):
+class _LabelEncoder(EstimatorToFromJSONMixin, LabelEncoder):
+    json_attributes_ = ("classes_",)
+
+
+class FMClassifier(EstimatorToFromJSONMixin, ClassifierMixin, BaseEstimator):
     r"""A Factorization Machine classifier.
 
     The FM model for the logits for class c is
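
The commit does not include mattspy/json.py, so the mixin's API is not visible in this diff. As a rough sketch of the pattern a json_attributes_-driven mixin typically follows (the to_json/from_json method names and the dict handling below are assumptions, not the actual mattspy implementation):

import json
import numpy as np

class JSONMixinSketch:
    # hypothetical stand-in for mattspy.json.EstimatorToFromJSONMixin
    json_attributes_ = ()

    def to_json(self):
        # serialize only the attributes the subclass declares
        state = {}
        for name in self.json_attributes_:
            value = getattr(self, name, None)
            if isinstance(value, np.ndarray):
                value = value.tolist()
            state[name] = value
        return json.dumps(state)

    @classmethod
    def from_json(cls, blob, **init_params):
        # rebuild an estimator and reattach the stored attributes
        est = cls(**init_params)
        for name, value in json.loads(blob).items():
            setattr(est, name, value)
        return est

The real mixin presumably also has to cope with nested estimators (the new _LabelEncoder is itself listed in FMClassifier.json_attributes_ below), JAX arrays, the NumPy RandomState, and the JAX RNG key, which a plain json.dumps cannot handle.
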
@@ -138,7 +145,7 @@ class FMClassifier(ClassifierMixin, BaseEstimator):
     batch_size : int, optional
         The number of examples to use when fitting the estimator
         and making predictions. The value None indicates to use all
-        examples.
+        examples. This parameter is ignored if the solver is set to `lbfgs`.
     lambda_v : float, optional
         The L2 regularization strength to use for the low-rank embedding
         matrix.
@@ -172,6 +179,18 @@ class FMClassifier(ClassifierMixin, BaseEstimator):
         otherwise.
     """
 
+    json_attributes_ = (
+        "_is_fit",
+        "_rng",
+        "_jax_rng_key",
+        "classes_",
+        "n_classes_",
+        "params_",
+        "converged_",
+        "n_iter_",
+        "_label_encoder",
+    )
+
     def __init__(
         self,
         rank=8,
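
With the fitted state enumerated in json_attributes_, the classifier can in principle be round-tripped through JSON without refitting. A hedged usage sketch; the to_json/from_json method names, the from_json call style, and the mattspy.fm import path are assumptions based on the mixin's name rather than anything shown in this diff:

import numpy as np
from mattspy.fm import FMClassifier  # import path assumed

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 8))
y = (X[:, 0] > 0).astype(int)

clf = FMClassifier(rank=4).fit(X, y)
blob = clf.to_json()                 # hypothetical: dump the attributes above
clf2 = FMClassifier.from_json(blob)  # hypothetical: restore without refitting
assert np.array_equal(clf.predict(X), clf2.predict(X))
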
@@ -202,7 +221,10 @@ def __init__(
 
     def fit(self, X, y):
         self._is_fit = False
-        return self.partial_fit(X, y)
+        return self._partial_fit(self.max_iter, X, y)
+
+    def partial_fit(self, X, y, classes=None):
+        return self._partial_fit(1, X, y, classes=classes)
 
     def _init_numpy(self, X, y, classes=None):
         X, y = validate_data(self, X=X, y=y, reset=True)
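
After this change fit drives the new _partial_fit helper for max_iter epochs, while partial_fit runs a single epoch per call, matching the usual scikit-learn contract for incremental training. A sketch of how that could be used on streamed chunks of data (the chunked data and the import path are made up for illustration):

import numpy as np
from mattspy.fm import FMClassifier  # import path assumed

rng = np.random.default_rng(0)
clf = FMClassifier(rank=4)
classes = np.array([0, 1])

for _ in range(10):                           # ten incoming chunks of data
    Xb = rng.normal(size=(64, 8))
    yb = (Xb[:, 0] > 0).astype(int)
    clf.partial_fit(Xb, yb, classes=classes)  # one epoch over this chunk
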
@@ -215,9 +237,9 @@ def _init_numpy(self, X, y, classes=None):
         )
 
         if classes is not None:
-            self._label_encoder = LabelEncoder().fit(classes)
+            self._label_encoder = _LabelEncoder().fit(classes)
         else:
-            self._label_encoder = LabelEncoder().fit(y)
+            self._label_encoder = _LabelEncoder().fit(y)
         self.classes_ = self._label_encoder.classes_
         self.n_classes_ = len(self.classes_)
         return X, y
@@ -229,7 +251,13 @@ def _init_jax(self, X, y, classes=None):
         else:
             self.classes_ = jnp.unique(y)
         self.n_classes_ = len(self.classes_)
-        self.n_features_in_ = X.shape[1]
+
+        validate_data(
+            self,
+            X=np.ones((1, X.shape[1])),
+            y=np.ones(1, dtype=np.int32),
+            reset=True,
+        )
 
         if not jnp.array_equal(jnp.arange(self.n_classes_), self.classes_):
             raise ValueError(
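
Instead of assigning n_features_in_ by hand, the JAX input path now routes a tiny dummy NumPy batch through validate_data so that scikit-learn records the feature count (and the rest of the reset-time bookkeeping) itself. A standalone sketch of that side effect, assuming a scikit-learn release where validate_data is importable from sklearn.utils.validation:

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import validate_data

class _Demo(ClassifierMixin, BaseEstimator):
    def fit(self, X, y):
        # reset=True makes scikit-learn store n_features_in_ on the estimator
        validate_data(self, X=X, y=y, reset=True)
        return self

est = _Demo().fit(np.ones((3, 5)), np.zeros(3, dtype=np.int32))
print(est.n_features_in_)  # 5
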
@@ -239,28 +267,36 @@
 
         return X, y
 
-    def partial_fit(self, X, y, classes=None):
-        if not getattr(self, "_is_fit", False):
-            self._rng = check_random_state(self.random_state)
+    def _init_from_json(self, X=None, y=None, classes=None, **kwargs):
+        self.n_iter_ = kwargs.get("n_iter_", 0)
+        self._rng = kwargs.get("_rng", check_random_state(self.random_state))
+        if "_jax_rng_key" in kwargs:
+            self._jax_rng_key = kwargs["_jax_rng_key"]
+        else:
             self._jax_rng_key = jax.random.key(
                 self._rng.randint(low=1, high=int(2**31))
             )
+        self.converged_ = kwargs.get(
+            "converged_",
+            False,
+        )
+        self._is_fit = kwargs.get("_is_fit", True)
+
+        if X is None and y is None:
+            # restore strictly from JSON
+            if "classes_" in kwargs:
+                self.classes_ = kwargs["classes_"]
+            if "n_classes_" in kwargs:
+                self.n_classes_ = kwargs["n_classes_"]
+            if "_label_encoder" in kwargs:
+                self._label_encoder = kwargs["_label_encoder"]
+        else:
             if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
                 X, y = self._init_numpy(X, y, classes=classes)
             else:
                 X, y = self._init_jax(X, y, classes=classes)
-        else:
-            if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
-                X, y = validate_data(self, X=X, y=y, reset=False)
-            else:
-                y = jnp.rint(y).astype(jnp.int32)
 
-        if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
-            y = self._label_encoder.transform(y)
-        X = jnp.array(X)
-        y = jnp.array(y)
-
-        if not getattr(self, "_is_fit", False):
+        if "params_" not in kwargs:
             self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
             w0 = jax.random.normal(subkey, shape=(self.n_classes_))
             self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
@@ -269,52 +305,70 @@ def partial_fit(self, X, y, classes=None):
             vmat = jax.random.normal(
                 subkey, shape=(self.n_features_in_, self.rank, self.n_classes_)
             )
-            params = (w0, w, vmat)
+            self.params_ = (w0, w, vmat)
         else:
-            params = tuple(p.copy() for p in self.params_)
+            self.params_ = kwargs["params_"]
 
-        kwargs = {k: v for k, v in (self.solver_kwargs or tuple())}
-        optimizer = getattr(optax, self.solver)(**kwargs)
-        opt_state = optimizer.init(params)
-        converged = False
-
-        if self.batch_size is not None:
-            self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
-            inds = jax.random.permutation(subkey, X.shape[0])
-            for start in range(0, X.shape[0], self.batch_size):
-                end = min(start + self.batch_size, X.shape[0])
-                Xb = X[inds[start:end], :]
-                yb = y[inds[start:end]]
-                grads = _grad_jax_loss_func(
-                    params, Xb, yb, self.lambda_v, self.lambda_w
-                )
-                updates, opt_state = optimizer.update(grads, opt_state, params)
-                params = optax.apply_updates(params, updates)
+        return X, y
 
-            self.n_iter_ = 1
+    def _partial_fit(self, n_epochs, X, y, classes=None):
+        if not getattr(self, "_is_fit", False):
+            X, y = self._init_from_json(X=X, y=y, classes=classes)
         else:
-            new_value = None
-            for i in range(self.max_iter):
-                value = new_value
-
-                if self.solver in ["lbfgs"]:
-                    new_value, grads = _value_and_grad_from_state_jax_loss_func(
-                        params,
-                        X,
-                        y,
-                        self.lambda_v,
-                        self.lambda_w,
-                        state=opt_state,
-                    )
+            if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
+                X, y = validate_data(self, X=X, y=y, reset=False)
+            else:
+                y = jnp.rint(y).astype(jnp.int32)
+
+        if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
+            y = self._label_encoder.transform(y)
+        X = jnp.array(X)
+        y = jnp.array(y)
+
+        kwargs = {k: v for k, v in (self.solver_kwargs or tuple())}
+        optimizer = getattr(optax, self.solver)(**kwargs)
+        opt_state = optimizer.init(self.params_)
+        new_value = None
+
+        for _ in range(n_epochs):
+            value = new_value
+
+            if self.solver not in ["lbfgs"]:
+                if self.batch_size is not None:
+                    self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
+                    inds = jax.random.permutation(subkey, X.shape[0])
+                    for start in range(0, X.shape[0], self.batch_size):
+                        end = min(start + self.batch_size, X.shape[0])
+                        Xb = X[inds[start:end], :]
+                        yb = y[inds[start:end]]
+                        grads = _grad_jax_loss_func(
+                            self.params_, Xb, yb, self.lambda_v, self.lambda_w
+                        )
+                        updates, opt_state = optimizer.update(
+                            grads, opt_state, self.params_
+                        )
+                        new_params = optax.apply_updates(self.params_, updates)
                 else:
-                    new_value, grads = _value_and_grad_jax_loss_func(
-                        params, X, y, self.lambda_v, self.lambda_w
+                    grads = _grad_jax_loss_func(
+                        self.params_, X, y, self.lambda_v, self.lambda_w
                     )
-
+                    updates, opt_state = optimizer.update(
+                        grads, opt_state, self.params_
+                    )
+                    new_params = optax.apply_updates(self.params_, updates)
+            else:
+                new_value, grads = _value_and_grad_from_state_jax_loss_func(
+                    self.params_,
+                    X,
+                    y,
+                    self.lambda_v,
+                    self.lambda_w,
+                    state=opt_state,
+                )
                 updates, opt_state = optimizer.update(
                     grads,
                     opt_state,
-                    params,
+                    self.params_,
                     value=new_value,
                     grad=grads,
                     value_fn=partial(
@@ -325,26 +379,25 @@ def partial_fit(self, X, y, classes=None):
                         lambda_w=self.lambda_w,
                     ),
                 )
-                new_params = optax.apply_updates(params, updates)
-
-                if i > 0 and (
-                    all(
-                        [
-                            jnp.allclose(new_p, p, atol=self.atol, rtol=self.rtol)
-                            for new_p, p in zip(new_params, params)
-                        ]
-                    )
-                    or jnp.allclose(value, new_value, atol=self.atol, rtol=self.rtol)
-                ):
-                    converged = True
-                    break
-                params = new_params
-
-            self.n_iter_ = i
+                new_params = optax.apply_updates(self.params_, updates)
+
+            self.n_iter_ += 1
+            if self.n_iter_ > 1 and (
+                all(
+                    [
+                        jnp.allclose(new_p, p, atol=self.atol, rtol=self.rtol)
+                        for new_p, p in zip(new_params, self.params_)
+                    ]
+                )
+                or (
+                    self.solver in ["lbfgs"]
+                    and jnp.allclose(value, new_value, atol=self.atol, rtol=self.rtol)
+                )
+            ):
+                self.converged_ = True
+                break
 
-        self.params_ = params
-        self.converged_ = converged
-        self._is_fit = True
+            self.params_ = new_params
 
         return self
 
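
The rewritten epoch loop keeps the standard optax pattern that was already present: build the optimizer by name with getattr(optax, self.solver), initialize its state from the parameters, then repeatedly compute gradients, ask the optimizer for updates, and apply them, with lbfgs additionally being fed the loss value and a value_fn. A self-contained sketch of the non-lbfgs part of that pattern on a toy quadratic loss (the solver choice and loss here are illustrative only, not taken from the diff):

import jax
import optax
from jax import numpy as jnp

params = (jnp.ones(3), jnp.ones((3, 2)))        # any pytree of arrays works

def loss_fn(params):
    # toy quadratic loss standing in for the FM cross-entropy objective
    return sum(jnp.sum(p**2) for p in params)

optimizer = getattr(optax, "sgd")(learning_rate=0.1)  # mirrors getattr(optax, self.solver)
opt_state = optimizer.init(params)

for _ in range(5):                              # a few epochs
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
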