Commit 3627360

feat: fm to json, docs (#46)
* feat: fm to json
* fix: more iter for auc
* fix: docs, remove print
1 parent 85df546 commit 3627360

5 files changed (+329, -102 lines)

mattspy/fm/_jax_impl.py

Lines changed: 203 additions & 76 deletions
@@ -3,6 +3,7 @@
 from jax import numpy as jnp
 from jax.tree_util import Partial as partial
 
+import numpy as np
 from optax.losses import softmax_cross_entropy_with_integer_labels
 from sklearn.base import ClassifierMixin, BaseEstimator
 from sklearn.utils import check_random_state
@@ -11,6 +12,8 @@
 from sklearn.exceptions import NotFittedError
 from sklearn.preprocessing import LabelEncoder
 
+from mattspy.json import EstimatorToFromJSONMixin
+
 
 @jax.jit
 def _lowrank_twoway_term(x, vmat):
@@ -121,7 +124,11 @@ def _call_in_batches_maybe(self, func, X):
         return func(self.params_, X)
 
 
-class FMClassifier(ClassifierMixin, BaseEstimator):
+class _LabelEncoder(EstimatorToFromJSONMixin, LabelEncoder):
+    json_attributes_ = ("classes_",)
+
+
+class FMClassifier(EstimatorToFromJSONMixin, ClassifierMixin, BaseEstimator):
     r"""A Factorization Machine classifier.
 
     The FM model for the logits for class c is
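Both classes now inherit `EstimatorToFromJSONMixin` and declare a class-level `json_attributes_` tuple naming the fitted state to serialize. The mixin itself lives in `mattspy.json` and is not part of this diff; a minimal sketch of the contract it implies, with hypothetical method names, might look like:

    # Hypothetical sketch of the contract implied by `json_attributes_`;
    # the real mixin in `mattspy.json` is not shown in this diff and must
    # also handle JAX arrays, RNG state, and nested estimators.
    import json


    class ToFromJSONSketch:
        json_attributes_ = ()  # subclasses list their fitted attributes

        def to_json(self):
            # Serialize whichever of the listed attributes exist so far.
            state = {
                k: getattr(self, k)
                for k in self.json_attributes_
                if hasattr(self, k)
            }
            return json.dumps(state, default=repr)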
@@ -138,7 +145,7 @@ class FMClassifier(ClassifierMixin, BaseEstimator):
     batch_size : int, optional
         The number of examples to use when fitting the estimator
         and making predictions. The value None indicates to use all
-        examples.
+        examples. This parameter is ignored if the solver is set to `lbfgs`.
     lambda_v : float, optional
         The L2 regularization strength to use for the low-rank embedding
         matrix.
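The new sentence documents that `lbfgs` is a full-batch solver, so `batch_size` only matters for the first-order optimizers. A hypothetical construction (the `mattspy.fm` import path is an assumption; solver names resolve to optax optimizers via `getattr(optax, self.solver)` later in this file):

    from mattspy.fm import FMClassifier  # assumed import path

    clf_sgd = FMClassifier(solver="adam", batch_size=256)     # batch_size honored
    clf_lbfgs = FMClassifier(solver="lbfgs", batch_size=256)  # batch_size ignored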
@@ -172,6 +179,18 @@ class FMClassifier(ClassifierMixin, BaseEstimator):
         otherwise.
     """
 
+    json_attributes_ = (
+        "_is_fit",
+        "_rng",
+        "_jax_rng_key",
+        "classes_",
+        "n_classes_",
+        "params_",
+        "converged_",
+        "n_iter_",
+        "_label_encoder",
+    )
+
     def __init__(
         self,
         rank=8,
@@ -201,8 +220,44 @@ def __init__(
         self.backend = backend
 
     def fit(self, X, y):
+        """Fit the FM to data `X` and `y`.
+
+        Parameters
+        ----------
+        X : array-like
+            An array of shape `(n_samples, n_features)`.
+        y : array-like
+            An array of labels of shape `(n_samples)`.
+
+        Returns
+        -------
+        self : object
+            The fit estimator.
+        """
+
         self._is_fit = False
-        return self.partial_fit(X, y)
+        return self._partial_fit(self.max_iter, X, y)
+
+    def partial_fit(self, X, y, classes=None):
+        """Fit the FM to data `X` and `y` for a single epoch.
+
+        Parameters
+        ----------
+        X : array-like
+            An array of shape `(n_samples, n_features)`.
+        y : array-like
+            An array of labels of shape `(n_samples)`.
+        classes : array-like, optional
+            If given, an optional array of unique class labels
+            that is used instead of extracting them from the input
+            `y`.
+
+        Returns
+        -------
+        self : object
+            The fit estimator.
+        """
+        return self._partial_fit(1, X, y, classes=classes)
 
     def _init_numpy(self, X, y, classes=None):
         X, y = validate_data(self, X=X, y=y, reset=True)
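With this split, `fit` resets the estimator and runs up to `max_iter` epochs through the shared `_partial_fit`, while `partial_fit` performs exactly one epoch, which enables out-of-core training loops. A hedged sketch (import path and data are illustrative):

    import numpy as np
    from mattspy.fm import FMClassifier  # assumed import path

    rng = np.random.RandomState(0)
    classes = np.array([0, 1])  # supply the full label set on the first call

    clf = FMClassifier(solver="adam", batch_size=64)
    for _ in range(5):  # five epochs, one data chunk per call
        X_chunk = rng.normal(size=(128, 10))
        y_chunk = rng.randint(0, 2, size=128)
        clf = clf.partial_fit(X_chunk, y_chunk, classes=classes)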
@@ -215,9 +270,9 @@ def _init_numpy(self, X, y, classes=None):
             )
 
         if classes is not None:
-            self._label_encoder = LabelEncoder().fit(classes)
+            self._label_encoder = _LabelEncoder().fit(classes)
         else:
-            self._label_encoder = LabelEncoder().fit(y)
+            self._label_encoder = _LabelEncoder().fit(y)
         self.classes_ = self._label_encoder.classes_
         self.n_classes_ = len(self.classes_)
         return X, y
@@ -229,7 +284,13 @@ def _init_jax(self, X, y, classes=None):
         else:
             self.classes_ = jnp.unique(y)
         self.n_classes_ = len(self.classes_)
-        self.n_features_in_ = X.shape[1]
+
+        validate_data(
+            self,
+            X=np.ones((1, X.shape[1])),
+            y=np.ones(1, dtype=np.int32),
+            reset=True,
+        )
 
         if not jnp.array_equal(jnp.arange(self.n_classes_), self.classes_):
             raise ValueError(
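The `np.ones` stand-in replaces the old manual `self.n_features_in_ = X.shape[1]`: scikit-learn's `validate_data` records bookkeeping such as `n_features_in_` on the estimator as a side effect, so a dummy array with the right column count satisfies the sklearn API while the real JAX arrays stay on-device. A minimal demonstration of that side effect (standalone estimator for illustration; `validate_data` is public as of scikit-learn 1.6):

    import numpy as np
    from sklearn.base import BaseEstimator, ClassifierMixin
    from sklearn.utils.validation import validate_data  # sklearn >= 1.6


    class Demo(ClassifierMixin, BaseEstimator):
        def record_shape(self, n_features):
            # validate_data sets n_features_in_ (and friends) on `self`.
            validate_data(
                self,
                X=np.ones((1, n_features)),
                y=np.ones(1, dtype=np.int32),
                reset=True,
            )
            return self


    print(Demo().record_shape(7).n_features_in_)  # -> 7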
@@ -239,28 +300,36 @@
 
         return X, y
 
-    def partial_fit(self, X, y, classes=None):
-        if not getattr(self, "_is_fit", False):
-            self._rng = check_random_state(self.random_state)
+    def _init_from_json(self, X=None, y=None, classes=None, **kwargs):
+        self.n_iter_ = kwargs.get("n_iter_", 0)
+        self._rng = kwargs.get("_rng", check_random_state(self.random_state))
+        if "_jax_rng_key" in kwargs:
+            self._jax_rng_key = kwargs["_jax_rng_key"]
+        else:
             self._jax_rng_key = jax.random.key(
                 self._rng.randint(low=1, high=int(2**31))
             )
+        self.converged_ = kwargs.get(
+            "converged_",
+            False,
+        )
+        self._is_fit = kwargs.get("_is_fit", True)
+
+        if X is None and y is None:
+            # restore strictly from JSON
+            if "classes_" in kwargs:
+                self.classes_ = kwargs["classes_"]
+            if "n_classes_" in kwargs:
+                self.n_classes_ = kwargs["n_classes_"]
+            if "_label_encoder" in kwargs:
+                self._label_encoder = kwargs["_label_encoder"]
+        else:
             if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
                 X, y = self._init_numpy(X, y, classes=classes)
             else:
                 X, y = self._init_jax(X, y, classes=classes)
-        else:
-            if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
-                X, y = validate_data(self, X=X, y=y, reset=False)
-            else:
-                y = jnp.rint(y).astype(jnp.int32)
-
-        if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
-            y = self._label_encoder.transform(y)
-        X = jnp.array(X)
-        y = jnp.array(y)
 
-        if not getattr(self, "_is_fit", False):
+        if "params_" not in kwargs:
             self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
             w0 = jax.random.normal(subkey, shape=(self.n_classes_))
             self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
@@ -269,52 +338,70 @@ def partial_fit(self, X, y, classes=None):
             vmat = jax.random.normal(
                 subkey, shape=(self.n_features_in_, self.rank, self.n_classes_)
             )
-            params = (w0, w, vmat)
+            self.params_ = (w0, w, vmat)
         else:
-            params = tuple(p.copy() for p in self.params_)
-
-        kwargs = {k: v for k, v in (self.solver_kwargs or tuple())}
-        optimizer = getattr(optax, self.solver)(**kwargs)
-        opt_state = optimizer.init(params)
-        converged = False
+            self.params_ = kwargs["params_"]
 
-        if self.batch_size is not None:
-            self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
-            inds = jax.random.permutation(subkey, X.shape[0])
-            for start in range(0, X.shape[0], self.batch_size):
-                end = min(start + self.batch_size, X.shape[0])
-                Xb = X[inds[start:end], :]
-                yb = y[inds[start:end]]
-                grads = _grad_jax_loss_func(
-                    params, Xb, yb, self.lambda_v, self.lambda_w
-                )
-                updates, opt_state = optimizer.update(grads, opt_state, params)
-                params = optax.apply_updates(params, updates)
+        return X, y
 
-            self.n_iter_ = 1
+    def _partial_fit(self, n_epochs, X, y, classes=None):
+        if not getattr(self, "_is_fit", False):
+            X, y = self._init_from_json(X=X, y=y, classes=classes)
         else:
-            new_value = None
-            for i in range(self.max_iter):
-                value = new_value
-
-                if self.solver in ["lbfgs"]:
-                    new_value, grads = _value_and_grad_from_state_jax_loss_func(
-                        params,
-                        X,
-                        y,
-                        self.lambda_v,
-                        self.lambda_w,
-                        state=opt_state,
-                    )
+            if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
+                X, y = validate_data(self, X=X, y=y, reset=False)
+            else:
+                y = jnp.rint(y).astype(jnp.int32)
+
+        if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
+            y = self._label_encoder.transform(y)
+        X = jnp.array(X)
+        y = jnp.array(y)
+
+        kwargs = {k: v for k, v in (self.solver_kwargs or tuple())}
+        optimizer = getattr(optax, self.solver)(**kwargs)
+        opt_state = optimizer.init(self.params_)
+        new_value = None
+
+        for _ in range(n_epochs):
+            value = new_value
+
+            if self.solver not in ["lbfgs"]:
+                if self.batch_size is not None:
+                    self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
+                    inds = jax.random.permutation(subkey, X.shape[0])
+                    for start in range(0, X.shape[0], self.batch_size):
+                        end = min(start + self.batch_size, X.shape[0])
+                        Xb = X[inds[start:end], :]
+                        yb = y[inds[start:end]]
+                        grads = _grad_jax_loss_func(
+                            self.params_, Xb, yb, self.lambda_v, self.lambda_w
+                        )
+                        updates, opt_state = optimizer.update(
+                            grads, opt_state, self.params_
+                        )
+                        new_params = optax.apply_updates(self.params_, updates)
                 else:
-                    new_value, grads = _value_and_grad_jax_loss_func(
-                        params, X, y, self.lambda_v, self.lambda_w
+                    grads = _grad_jax_loss_func(
+                        self.params_, X, y, self.lambda_v, self.lambda_w
                     )
-
+                    updates, opt_state = optimizer.update(
+                        grads, opt_state, self.params_
+                    )
+                    new_params = optax.apply_updates(self.params_, updates)
+            else:
+                new_value, grads = _value_and_grad_from_state_jax_loss_func(
+                    self.params_,
+                    X,
+                    y,
+                    self.lambda_v,
+                    self.lambda_w,
+                    state=opt_state,
+                )
                 updates, opt_state = optimizer.update(
                     grads,
                     opt_state,
-                    params,
+                    self.params_,
                     value=new_value,
                     grad=grads,
                     value_fn=partial(
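`_partial_fit` funnels both entry points through `_init_from_json`, which doubles as the fresh-fit initializer and the JSON-restore path; every key it reads matches an entry in `json_attributes_`. A hypothetical restore call (the wire format is defined by the mixin in `mattspy.json`, not shown in this diff):

    from mattspy.fm import FMClassifier  # assumed import path

    # Hypothetical deserialized state; "params_", "_rng", "_jax_rng_key",
    # and "_label_encoder" would be rebuilt from JSON as well.
    state = {
        "_is_fit": True,
        "n_iter_": 12,
        "converged_": True,
        "classes_": [0, 1],
        "n_classes_": 2,
    }

    clf = FMClassifier()
    clf._init_from_json(**state)  # X=None, y=None -> restore strictly from JSON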
@@ -325,30 +412,43 @@ def partial_fit(self, X, y, classes=None):
                         lambda_w=self.lambda_w,
                     ),
                 )
-                new_params = optax.apply_updates(params, updates)
-
-                if i > 0 and (
-                    all(
-                        [
-                            jnp.allclose(new_p, p, atol=self.atol, rtol=self.rtol)
-                            for new_p, p in zip(new_params, params)
-                        ]
-                    )
-                    or jnp.allclose(value, new_value, atol=self.atol, rtol=self.rtol)
-                ):
-                    converged = True
-                    break
-                params = new_params
-
-            self.n_iter_ = i
+                new_params = optax.apply_updates(self.params_, updates)
+
+            self.n_iter_ += 1
+            if self.n_iter_ > 1 and (
+                all(
+                    [
+                        jnp.allclose(new_p, p, atol=self.atol, rtol=self.rtol)
+                        for new_p, p in zip(new_params, self.params_)
+                    ]
+                )
+                or (
+                    self.solver in ["lbfgs"]
+                    and jnp.allclose(value, new_value, atol=self.atol, rtol=self.rtol)
+                )
+            ):
+                self.converged_ = True
+                break
 
-        self.params_ = params
-        self.converged_ = converged
-        self._is_fit = True
+            self.params_ = new_params
 
         return self
 
     def predict_log_proba(self, X):
+        """Predict the log-probability of each class for data `X`.
+
+        Parameters
+        ----------
+        X : array-like
+            An array of shape `(n_samples, n_features)`.
+
+        Returns
+        -------
+        log_proba : array-like
+            An array of log-probabilities of shape `(n_samples, n_classes_)`
+            if `n_classes_` > 2, else `(n_samples)`.
+        """
+
         if not isinstance(X, jnp.ndarray):
             X = validate_data(self, X=X, reset=False)
         if not getattr(self, "_is_fit", False):
@@ -358,6 +458,20 @@ def predict_log_proba(self, X):
         return _call_in_batches_maybe(self, _jax_log_proba, X)
 
     def predict_proba(self, X):
+        """Predict the probability of each class for data `X`.
+
+        Parameters
+        ----------
+        X : array-like
+            An array of shape `(n_samples, n_features)`.
+
+        Returns
+        -------
+        proba : array-like
+            An array of probabilities of shape `(n_samples, n_classes_)`
+            if `n_classes_` > 2, else `(n_samples)`.
+        """
+
         if not isinstance(X, jnp.ndarray):
             X = validate_data(self, X=X, reset=False)
         if not getattr(self, "_is_fit", False):
@@ -367,6 +481,19 @@ def predict_proba(self, X):
         return _call_in_batches_maybe(self, _jax_proba, X)
 
     def predict(self, X):
+        """Predict the class for data `X`.
+
+        Parameters
+        ----------
+        X : array-like
+            An array of shape `(n_samples, n_features)`.
+
+        Returns
+        -------
+        y : array-like
+            An array of labels of shape `(n_samples)`.
+        """
+
         if not isinstance(X, jnp.ndarray):
             X = validate_data(self, X=X, reset=False)
         if not getattr(self, "_is_fit", False):
