 from jax import numpy as jnp
 from jax.tree_util import Partial as partial
 
+import numpy as np
 from optax.losses import softmax_cross_entropy_with_integer_labels
 from sklearn.base import ClassifierMixin, BaseEstimator
 from sklearn.utils import check_random_state
 from sklearn.exceptions import NotFittedError
 from sklearn.preprocessing import LabelEncoder
 
+from mattspy.json import EstimatorToFromJSONMixin
+
 
 @jax.jit
 def _lowrank_twoway_term(x, vmat):
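The body of `_lowrank_twoway_term` falls outside this hunk. For orientation, a minimal sketch of the standard O(n_features * rank) formulation (Rendle's identity), assuming `x` has shape `(n_samples, n_features)` and `vmat` has shape `(n_features, rank, n_classes)` as initialized in `_partial_fit` below; this is the usual formulation, not necessarily the file's exact body:

```python
import jax
from jax import numpy as jnp

@jax.jit
def _lowrank_twoway_term_sketch(x, vmat):
    # Per class: sum_{j<k} <V_j, V_k> x_j x_k, computed cheaply as
    # 0.5 * [ (x V)^2 - (x^2)(V^2) ] summed over the rank axis.
    xv = jnp.einsum("ij,jfc->ifc", x, vmat)          # (n_samples, rank, n_classes)
    x2v2 = jnp.einsum("ij,jfc->ifc", x**2, vmat**2)  # (n_samples, rank, n_classes)
    return 0.5 * jnp.sum(xv**2 - x2v2, axis=1)       # (n_samples, n_classes)
```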
@@ -121,7 +124,11 @@ def _call_in_batches_maybe(self, func, X):
         return func(self.params_, X)
 
 
-class FMClassifier(ClassifierMixin, BaseEstimator):
+class _LabelEncoder(EstimatorToFromJSONMixin, LabelEncoder):
+    json_attributes_ = ("classes_",)
+
+
+class FMClassifier(EstimatorToFromJSONMixin, ClassifierMixin, BaseEstimator):
     r"""A Factorization Machine classifier.
 
     The FM model for the logits for class c is
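The rest of the docstring lies outside the hunk. For reference, the standard multiclass FM logit consistent with the parameter shapes used below (the shape of `w` is an assumption; only the `w0` and `vmat` shapes appear in this diff) is

    logit_c(x) = w0_c + sum_j w_{j,c} x_j
                 + 0.5 * sum_f [ (sum_j V_{j,f,c} x_j)^2 - sum_j V_{j,f,c}^2 x_j^2 ]

with `w0` of shape `(n_classes,)`, `w` of shape `(n_features, n_classes)`, and `V = vmat` of shape `(n_features, rank, n_classes)`.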
@@ -138,7 +145,7 @@ class FMClassifier(ClassifierMixin, BaseEstimator):
     batch_size : int, optional
         The number of examples to use when fitting the estimator
         and making predictions. The value None indicates to use all
-        examples.
+        examples. This parameter is ignored if the solver is set to `lbfgs`.
     lambda_v : float, optional
         The L2 regularization strength to use for the low-rank embedding
         matrix.
@@ -172,6 +179,18 @@ class FMClassifier(ClassifierMixin, BaseEstimator):
         otherwise.
     """
 
+    json_attributes_ = (
+        "_is_fit",
+        "_rng",
+        "_jax_rng_key",
+        "classes_",
+        "n_classes_",
+        "params_",
+        "converged_",
+        "n_iter_",
+        "_label_encoder",
+    )
+
     def __init__(
         self,
         rank=8,
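The `EstimatorToFromJSONMixin` uses `json_attributes_` to decide which fitted attributes get serialized. A hypothetical round-trip sketch; the `to_json`/`from_json` method names are assumptions about the mattspy API, not confirmed by this diff:

```python
# Hypothetical usage; method names assumed, not taken from this diff.
clf = FMClassifier(rank=8).fit(X_train, y_train)
blob = clf.to_json()                 # serialize params_, classes_, RNG state, ...
clf2 = FMClassifier.from_json(blob)  # restored via _init_from_json below
assert (clf2.predict(X_test) == clf.predict(X_test)).all()
```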
@@ -201,8 +220,44 @@ def __init__(
         self.backend = backend
 
     def fit(self, X, y):
+        """Fit the FM to data `X` and `y`.
+
+        Parameters
+        ----------
+        X : array-like
+            An array of shape `(n_samples, n_features)`.
+        y : array-like
+            An array of labels of shape `(n_samples,)`.
+
+        Returns
+        -------
+        self : object
+            The fit estimator.
+        """
+
         self._is_fit = False
-        return self.partial_fit(X, y)
+        return self._partial_fit(self.max_iter, X, y)
+
+    def partial_fit(self, X, y, classes=None):
+        """Fit the FM to data `X` and `y` for a single epoch.
+
+        Parameters
+        ----------
+        X : array-like
+            An array of shape `(n_samples, n_features)`.
+        y : array-like
+            An array of labels of shape `(n_samples,)`.
+        classes : array-like, optional
+            If given, an array of unique class labels to use instead of
+            extracting them from the input `y`.
+
+        Returns
+        -------
+        self : object
+            The fit estimator.
+        """
+        return self._partial_fit(1, X, y, classes=classes)
 
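A minimal usage sketch of the two entry points above (data and hyperparameters are illustrative, not from this diff):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))
y = rng.integers(0, 3, size=200)

# fit() resets the estimator and trains for up to max_iter epochs.
clf = FMClassifier(rank=4, max_iter=50).fit(X, y)

# partial_fit() advances exactly one epoch per call; passing `classes`
# up front allows later calls to contain only a subset of the labels.
clf2 = FMClassifier(rank=4)
for _ in range(10):
    clf2.partial_fit(X, y, classes=np.unique(y))
```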
     def _init_numpy(self, X, y, classes=None):
         X, y = validate_data(self, X=X, y=y, reset=True)
@@ -215,9 +270,9 @@ def _init_numpy(self, X, y, classes=None):
         )
 
         if classes is not None:
-            self._label_encoder = LabelEncoder().fit(classes)
+            self._label_encoder = _LabelEncoder().fit(classes)
         else:
-            self._label_encoder = LabelEncoder().fit(y)
+            self._label_encoder = _LabelEncoder().fit(y)
         self.classes_ = self._label_encoder.classes_
         self.n_classes_ = len(self.classes_)
         return X, y
@@ -229,7 +284,13 @@ def _init_jax(self, X, y, classes=None):
         else:
             self.classes_ = jnp.unique(y)
         self.n_classes_ = len(self.classes_)
-        self.n_features_in_ = X.shape[1]
+
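+        # Validating a small NumPy stand-in sets `n_features_in_` (and the
+        # rest of sklearn's validation state) without converting the JAX
+        # arrays themselves.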
+        validate_data(
+            self,
+            X=np.ones((1, X.shape[1])),
+            y=np.ones(1, dtype=np.int32),
+            reset=True,
+        )
 
         if not jnp.array_equal(jnp.arange(self.n_classes_), self.classes_):
             raise ValueError(
@@ -239,28 +300,36 @@ def _init_jax(self, X, y, classes=None):
 
         return X, y
 
-    def partial_fit(self, X, y, classes=None):
-        if not getattr(self, "_is_fit", False):
-            self._rng = check_random_state(self.random_state)
+    def _init_from_json(self, X=None, y=None, classes=None, **kwargs):
+        self.n_iter_ = kwargs.get("n_iter_", 0)
+        self._rng = kwargs.get("_rng", check_random_state(self.random_state))
+        if "_jax_rng_key" in kwargs:
+            self._jax_rng_key = kwargs["_jax_rng_key"]
+        else:
             self._jax_rng_key = jax.random.key(
                 self._rng.randint(low=1, high=int(2**31))
             )
+        self.converged_ = kwargs.get(
+            "converged_",
+            False,
+        )
+        self._is_fit = kwargs.get("_is_fit", True)
+
+        if X is None and y is None:
+            # restore strictly from JSON
+            if "classes_" in kwargs:
+                self.classes_ = kwargs["classes_"]
+            if "n_classes_" in kwargs:
+                self.n_classes_ = kwargs["n_classes_"]
+            if "_label_encoder" in kwargs:
+                self._label_encoder = kwargs["_label_encoder"]
+        else:
             if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
                 X, y = self._init_numpy(X, y, classes=classes)
             else:
                 X, y = self._init_jax(X, y, classes=classes)
-        else:
-            if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
-                X, y = validate_data(self, X=X, y=y, reset=False)
-            else:
-                y = jnp.rint(y).astype(jnp.int32)
-
-        if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
-            y = self._label_encoder.transform(y)
-        X = jnp.array(X)
-        y = jnp.array(y)
 
-        if not getattr(self, "_is_fit", False):
+        if "params_" not in kwargs:
             self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
             w0 = jax.random.normal(subkey, shape=(self.n_classes_))
             self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
@@ -269,52 +338,70 @@ def partial_fit(self, X, y, classes=None):
             vmat = jax.random.normal(
                 subkey, shape=(self.n_features_in_, self.rank, self.n_classes_)
             )
-            params = (w0, w, vmat)
+            self.params_ = (w0, w, vmat)
         else:
-            params = tuple(p.copy() for p in self.params_)
-
-        kwargs = {k: v for k, v in (self.solver_kwargs or tuple())}
-        optimizer = getattr(optax, self.solver)(**kwargs)
-        opt_state = optimizer.init(params)
-        converged = False
+            self.params_ = kwargs["params_"]
 
-        if self.batch_size is not None:
-            self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
-            inds = jax.random.permutation(subkey, X.shape[0])
-            for start in range(0, X.shape[0], self.batch_size):
-                end = min(start + self.batch_size, X.shape[0])
-                Xb = X[inds[start:end], :]
-                yb = y[inds[start:end]]
-                grads = _grad_jax_loss_func(
-                    params, Xb, yb, self.lambda_v, self.lambda_w
-                )
-                updates, opt_state = optimizer.update(grads, opt_state, params)
-                params = optax.apply_updates(params, updates)
+        return X, y
 
-            self.n_iter_ = 1
+    def _partial_fit(self, n_epochs, X, y, classes=None):
+        if not getattr(self, "_is_fit", False):
+            X, y = self._init_from_json(X=X, y=y, classes=classes)
         else:
-            new_value = None
-            for i in range(self.max_iter):
-                value = new_value
-
-                if self.solver in ["lbfgs"]:
-                    new_value, grads = _value_and_grad_from_state_jax_loss_func(
-                        params,
-                        X,
-                        y,
-                        self.lambda_v,
-                        self.lambda_w,
-                        state=opt_state,
-                    )
+            if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
+                X, y = validate_data(self, X=X, y=y, reset=False)
+            else:
+                y = jnp.rint(y).astype(jnp.int32)
+
+        if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
+            y = self._label_encoder.transform(y)
+        X = jnp.array(X)
+        y = jnp.array(y)
+
+        kwargs = {k: v for k, v in (self.solver_kwargs or tuple())}
+        optimizer = getattr(optax, self.solver)(**kwargs)
+        opt_state = optimizer.init(self.params_)
+        new_value = None
+
+        for _ in range(n_epochs):
+            value = new_value
+
+            if self.solver not in ["lbfgs"]:
+                if self.batch_size is not None:
+                    self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
+                    inds = jax.random.permutation(subkey, X.shape[0])
+                    # Carry the updates across batches so each mini-batch
+                    # gradient is evaluated at the latest parameters rather
+                    # than at the stale start-of-epoch `self.params_`.
+                    new_params = self.params_
+                    for start in range(0, X.shape[0], self.batch_size):
+                        end = min(start + self.batch_size, X.shape[0])
+                        Xb = X[inds[start:end], :]
+                        yb = y[inds[start:end]]
+                        grads = _grad_jax_loss_func(
+                            new_params, Xb, yb, self.lambda_v, self.lambda_w
+                        )
+                        updates, opt_state = optimizer.update(
+                            grads, opt_state, new_params
+                        )
+                        new_params = optax.apply_updates(new_params, updates)
                 else:
-                    new_value, grads = _value_and_grad_jax_loss_func(
-                        params, X, y, self.lambda_v, self.lambda_w
+                    grads = _grad_jax_loss_func(
+                        self.params_, X, y, self.lambda_v, self.lambda_w
                     )
-
+                    updates, opt_state = optimizer.update(
+                        grads, opt_state, self.params_
+                    )
+                    new_params = optax.apply_updates(self.params_, updates)
+            else:
+                new_value, grads = _value_and_grad_from_state_jax_loss_func(
+                    self.params_,
+                    X,
+                    y,
+                    self.lambda_v,
+                    self.lambda_w,
+                    state=opt_state,
+                )
                 updates, opt_state = optimizer.update(
                     grads,
                     opt_state,
-                    params,
+                    self.params_,
                     value=new_value,
                     grad=grads,
                     value_fn=partial(
@@ -325,30 +412,43 @@ def partial_fit(self, X, y, classes=None):
                         lambda_w=self.lambda_w,
                     ),
                 )
-                new_params = optax.apply_updates(params, updates)
-
-                if i > 0 and (
-                    all(
-                        [
-                            jnp.allclose(new_p, p, atol=self.atol, rtol=self.rtol)
-                            for new_p, p in zip(new_params, params)
-                        ]
-                    )
-                    or jnp.allclose(value, new_value, atol=self.atol, rtol=self.rtol)
-                ):
-                    converged = True
-                    break
-                params = new_params
-
-            self.n_iter_ = i
+                new_params = optax.apply_updates(self.params_, updates)
+
+            self.n_iter_ += 1
+            if self.n_iter_ > 1 and (
+                all(
+                    [
+                        jnp.allclose(new_p, p, atol=self.atol, rtol=self.rtol)
+                        for new_p, p in zip(new_params, self.params_)
+                    ]
+                )
+                or (
+                    self.solver in ["lbfgs"]
+                    and jnp.allclose(value, new_value, atol=self.atol, rtol=self.rtol)
+                )
+            ):
+                self.converged_ = True
+                break
 
-        self.params_ = params
-        self.converged_ = converged
-        self._is_fit = True
+            self.params_ = new_params
 
         return self
 
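For readers less familiar with optax: the epoch loop above is the standard `init` → `update` → `apply_updates` pattern, and the `lbfgs` branch threads `value`, `grad`, and `value_fn` through `update` because optax's L-BFGS line search re-evaluates the loss. A self-contained sketch of both patterns on a toy quadratic (all names here are illustrative, not from this file):

```python
import jax
import optax
from jax import numpy as jnp

def loss_fn(params):
    # Toy quadratic standing in for the regularized FM loss.
    return jnp.sum((params - 3.0) ** 2)

# First-order solvers need only the gradients.
params = jnp.zeros(5)
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)
for _ in range(100):
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

# L-BFGS additionally needs value, grad, and value_fn for its line search.
params = jnp.zeros(5)
optimizer = optax.lbfgs()
opt_state = optimizer.init(params)
value_and_grad = optax.value_and_grad_from_state(loss_fn)
for _ in range(10):
    value, grads = value_and_grad(params, state=opt_state)
    updates, opt_state = optimizer.update(
        grads, opt_state, params, value=value, grad=grads, value_fn=loss_fn
    )
    params = optax.apply_updates(params, updates)
```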
     def predict_log_proba(self, X):
+        """Predict the log-probability of each class for data `X`.
+
+        Parameters
+        ----------
+        X : array-like
+            An array of shape `(n_samples, n_features)`.
+
+        Returns
+        -------
+        log_proba : array-like
+            An array of log-probabilities of shape `(n_samples, n_classes_)`
+            if `n_classes_` > 2, else `(n_samples,)`.
+        """
+
         if not isinstance(X, jnp.ndarray):
             X = validate_data(self, X=X, reset=False)
         if not getattr(self, "_is_fit", False):
@@ -358,6 +458,20 @@ def predict_log_proba(self, X):
         return _call_in_batches_maybe(self, _jax_log_proba, X)
 
     def predict_proba(self, X):
+        """Predict the probability of each class for data `X`.
+
+        Parameters
+        ----------
+        X : array-like
+            An array of shape `(n_samples, n_features)`.
+
+        Returns
+        -------
+        proba : array-like
+            An array of probabilities of shape `(n_samples, n_classes_)`
+            if `n_classes_` > 2, else `(n_samples,)`.
+        """
+
         if not isinstance(X, jnp.ndarray):
             X = validate_data(self, X=X, reset=False)
         if not getattr(self, "_is_fit", False):
@@ -367,6 +481,19 @@ def predict_proba(self, X):
         return _call_in_batches_maybe(self, _jax_proba, X)
 
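A short prediction sketch, continuing the fitted `clf` from the earlier example; when `batch_size` is set, `_call_in_batches_maybe` evaluates these in batches:

```python
proba = clf.predict_proba(X)          # rows sum to 1 for n_classes_ > 2
log_proba = clf.predict_log_proba(X)  # elementwise log of the above
labels = clf.predict(X)               # hard class labels, shape (n_samples,)
```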
     def predict(self, X):
+        """Predict the class for data `X`.
+
+        Parameters
+        ----------
+        X : array-like
+            An array of shape `(n_samples, n_features)`.
+
+        Returns
+        -------
+        y : array-like
+            An array of labels of shape `(n_samples,)`.
+        """
+
         if not isinstance(X, jnp.ndarray):
             X = validate_data(self, X=X, reset=False)
         if not getattr(self, "_is_fit", False):