 from jax import numpy as jnp
 from jax.tree_util import Partial as partial
 
+import numpy as np
 from optax.losses import softmax_cross_entropy_with_integer_labels
 from sklearn.base import ClassifierMixin, BaseEstimator
 from sklearn.utils import check_random_state
 from sklearn.exceptions import NotFittedError
 from sklearn.preprocessing import LabelEncoder
 
+from mattspy.json import EstimatorToFromJSONMixin
+
 
 @jax.jit
 def _lowrank_twoway_term(x, vmat):
@@ -121,7 +124,11 @@ def _call_in_batches_maybe(self, func, X):
         return func(self.params_, X)
 
 
-class FMClassifier(ClassifierMixin, BaseEstimator):
+class _LabelEncoder(EstimatorToFromJSONMixin, LabelEncoder):
+    json_attributes_ = ("classes_",)
+
+
+class FMClassifier(EstimatorToFromJSONMixin, ClassifierMixin, BaseEstimator):
     r"""A Factorization Machine classifier.
 
     The FM model for the logits for class c is
@@ -138,7 +145,7 @@ class FMClassifier(ClassifierMixin, BaseEstimator):
     batch_size : int, optional
         The number of examples to use when fitting the estimator
         and making predictions. The value None indicates to use all
-        examples.
+        examples. This parameter is ignored if the solver is set to `lbfgs`.
     lambda_v : float, optional
         The L2 regularization strength to use for the low-rank embedding
         matrix.
@@ -172,6 +179,18 @@ class FMClassifier(ClassifierMixin, BaseEstimator):
         otherwise.
     """
 
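+    # Fitted state persisted by the JSON serialization mixin.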
+    json_attributes_ = (
+        "_is_fit",
+        "_rng",
+        "_jax_rng_key",
+        "classes_",
+        "n_classes_",
+        "params_",
+        "converged_",
+        "n_iter_",
+        "_label_encoder",
+    )
+
     def __init__(
         self,
         rank=8,
@@ -202,7 +221,10 @@ def __init__(
 
     def fit(self, X, y):
         self._is_fit = False
-        return self.partial_fit(X, y)
+        return self._partial_fit(self.max_iter, X, y)
+
+    def partial_fit(self, X, y, classes=None):
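+        # One epoch per call; fit() instead runs up to max_iter epochs.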
+        return self._partial_fit(1, X, y, classes=classes)
 
     def _init_numpy(self, X, y, classes=None):
         X, y = validate_data(self, X=X, y=y, reset=True)
@@ -215,9 +237,9 @@ def _init_numpy(self, X, y, classes=None):
         )
 
         if classes is not None:
-            self._label_encoder = LabelEncoder().fit(classes)
+            self._label_encoder = _LabelEncoder().fit(classes)
         else:
-            self._label_encoder = LabelEncoder().fit(y)
+            self._label_encoder = _LabelEncoder().fit(y)
         self.classes_ = self._label_encoder.classes_
         self.n_classes_ = len(self.classes_)
         return X, y
@@ -229,7 +251,13 @@ def _init_jax(self, X, y, classes=None):
         else:
             self.classes_ = jnp.unique(y)
         self.n_classes_ = len(self.classes_)
-        self.n_features_in_ = X.shape[1]
+
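+        # Validate a tiny NumPy stand-in so sklearn records n_features_in_
+        # and related metadata without round-tripping the JAX arrays.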
+        validate_data(
+            self,
+            X=np.ones((1, X.shape[1])),
+            y=np.ones(1, dtype=np.int32),
+            reset=True,
+        )
 
         if not jnp.array_equal(jnp.arange(self.n_classes_), self.classes_):
             raise ValueError(
@@ -239,28 +267,36 @@ def _init_jax(self, X, y, classes=None):
             )
 
         return X, y
-    def partial_fit(self, X, y, classes=None):
-        if not getattr(self, "_is_fit", False):
-            self._rng = check_random_state(self.random_state)
+    def _init_from_json(self, X=None, y=None, classes=None, **kwargs):
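+        # Restore state saved by the JSON mixin; anything absent falls back
+        # to a freshly initialized value.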
+        self.n_iter_ = kwargs.get("n_iter_", 0)
+        self._rng = kwargs.get("_rng", check_random_state(self.random_state))
+        if "_jax_rng_key" in kwargs:
+            self._jax_rng_key = kwargs["_jax_rng_key"]
+        else:
             self._jax_rng_key = jax.random.key(
                 self._rng.randint(low=1, high=int(2**31))
             )
+        self.converged_ = kwargs.get(
+            "converged_",
+            False,
+        )
+        self._is_fit = kwargs.get("_is_fit", True)
+
+        if X is None and y is None:
+            # restore strictly from JSON
+            if "classes_" in kwargs:
+                self.classes_ = kwargs["classes_"]
+            if "n_classes_" in kwargs:
+                self.n_classes_ = kwargs["n_classes_"]
+            if "_label_encoder" in kwargs:
+                self._label_encoder = kwargs["_label_encoder"]
+        else:
             if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
                 X, y = self._init_numpy(X, y, classes=classes)
             else:
                 X, y = self._init_jax(X, y, classes=classes)
-        else:
-            if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
-                X, y = validate_data(self, X=X, y=y, reset=False)
-            else:
-                y = jnp.rint(y).astype(jnp.int32)

-        if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
-            y = self._label_encoder.transform(y)
-        X = jnp.array(X)
-        y = jnp.array(y)
-
-        if not getattr(self, "_is_fit", False):
+        if "params_" not in kwargs:
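+            # No serialized parameters: initialize w0, w, and vmat from the
+            # JAX PRNG.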
             self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
             w0 = jax.random.normal(subkey, shape=(self.n_classes_,))
             self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
@@ -269,52 +305,70 @@ def partial_fit(self, X, y, classes=None):
             vmat = jax.random.normal(
                 subkey, shape=(self.n_features_in_, self.rank, self.n_classes_)
             )
-            params = (w0, w, vmat)
+            self.params_ = (w0, w, vmat)
         else:
-            params = tuple(p.copy() for p in self.params_)
+            self.params_ = kwargs["params_"]
 
-        kwargs = {k: v for k, v in (self.solver_kwargs or tuple())}
-        optimizer = getattr(optax, self.solver)(**kwargs)
-        opt_state = optimizer.init(params)
-        converged = False
-
-        if self.batch_size is not None:
-            self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
-            inds = jax.random.permutation(subkey, X.shape[0])
-            for start in range(0, X.shape[0], self.batch_size):
-                end = min(start + self.batch_size, X.shape[0])
-                Xb = X[inds[start:end], :]
-                yb = y[inds[start:end]]
-                grads = _grad_jax_loss_func(
-                    params, Xb, yb, self.lambda_v, self.lambda_w
-                )
-                updates, opt_state = optimizer.update(grads, opt_state, params)
-                params = optax.apply_updates(params, updates)
+        return X, y
 
-        self.n_iter_ = 1
+    def _partial_fit(self, n_epochs, X, y, classes=None):
+        if not getattr(self, "_is_fit", False):
+            X, y = self._init_from_json(X=X, y=y, classes=classes)
         else:
-            new_value = None
-            for i in range(self.max_iter):
-                value = new_value
-
-                if self.solver in ["lbfgs"]:
-                    new_value, grads = _value_and_grad_from_state_jax_loss_func(
-                        params,
-                        X,
-                        y,
-                        self.lambda_v,
-                        self.lambda_w,
-                        state=opt_state,
-                    )
+            if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
+                X, y = validate_data(self, X=X, y=y, reset=False)
+            else:
+                y = jnp.rint(y).astype(jnp.int32)
+
+        if not (isinstance(X, jnp.ndarray) and isinstance(y, jnp.ndarray)):
+            y = self._label_encoder.transform(y)
+        X = jnp.array(X)
+        y = jnp.array(y)
+
+        kwargs = {k: v for k, v in (self.solver_kwargs or tuple())}
+        optimizer = getattr(optax, self.solver)(**kwargs)
+        opt_state = optimizer.init(self.params_)
+        new_value = None
+
+        for _ in range(n_epochs):
+            value = new_value
+
+            if self.solver not in ["lbfgs"]:
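+                # First-order optax solvers take plain gradient steps,
+                # optionally sweeping shuffled mini-batches; new_params is
+                # advanced per step and committed at the end of the epoch.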
+                if self.batch_size is not None:
+                    self._jax_rng_key, subkey = jax.random.split(self._jax_rng_key)
+                    inds = jax.random.permutation(subkey, X.shape[0])
+                    new_params = self.params_
+                    for start in range(0, X.shape[0], self.batch_size):
+                        end = min(start + self.batch_size, X.shape[0])
+                        Xb = X[inds[start:end], :]
+                        yb = y[inds[start:end]]
+                        grads = _grad_jax_loss_func(
+                            new_params, Xb, yb, self.lambda_v, self.lambda_w
+                        )
+                        updates, opt_state = optimizer.update(
+                            grads, opt_state, new_params
+                        )
+                        new_params = optax.apply_updates(new_params, updates)
                 else:
-                    new_value, grads = _value_and_grad_jax_loss_func(
-                        params, X, y, self.lambda_v, self.lambda_w
+                    grads = _grad_jax_loss_func(
+                        self.params_, X, y, self.lambda_v, self.lambda_w
                     )
-
+                    updates, opt_state = optimizer.update(
+                        grads, opt_state, self.params_
+                    )
+                    new_params = optax.apply_updates(self.params_, updates)
+            else:
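+                # lbfgs line searches need the loss value, gradient, and a
+                # value_fn, so optimizer.update gets extra arguments here.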
+                new_value, grads = _value_and_grad_from_state_jax_loss_func(
+                    self.params_,
+                    X,
+                    y,
+                    self.lambda_v,
+                    self.lambda_w,
+                    state=opt_state,
+                )
                 updates, opt_state = optimizer.update(
                     grads,
                     opt_state,
-                    params,
+                    self.params_,
                     value=new_value,
                     grad=grads,
                     value_fn=partial(
@@ -325,26 +379,25 @@ def partial_fit(self, X, y, classes=None):
                         lambda_w=self.lambda_w,
                     ),
                 )
-                new_params = optax.apply_updates(params, updates)
-
-                if i > 0 and (
-                    all(
-                        [
-                            jnp.allclose(new_p, p, atol=self.atol, rtol=self.rtol)
-                            for new_p, p in zip(new_params, params)
-                        ]
-                    )
-                    or jnp.allclose(value, new_value, atol=self.atol, rtol=self.rtol)
-                ):
-                    converged = True
-                    break
-                params = new_params
-
-            self.n_iter_ = i
+                new_params = optax.apply_updates(self.params_, updates)
+
+            self.n_iter_ += 1
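+            # Converged once parameters stop changing between epochs or, for
+            # lbfgs, once the loss value plateaus.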
+            if self.n_iter_ > 1 and (
+                all(
+                    [
+                        jnp.allclose(new_p, p, atol=self.atol, rtol=self.rtol)
+                        for new_p, p in zip(new_params, self.params_)
+                    ]
+                )
+                or (
+                    self.solver in ["lbfgs"]
+                    and value is not None
+                    and jnp.allclose(value, new_value, atol=self.atol, rtol=self.rtol)
+                )
+            ):
+                self.converged_ = True
+                break
 
-        self.params_ = params
-        self.converged_ = converged
-        self._is_fit = True
+            self.params_ = new_params
 
         return self
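A minimal usage sketch of the estimator API in this diff: `fit` runs up to `max_iter` epochs and each `partial_fit` call runs one more, warm-started from the current parameters. The import path below is an assumption for illustration; everything else is taken from the code above.

import numpy as np

from mattspy import FMClassifier  # assumed import path

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 10))
y = (X[:, 0] > 0).astype(np.int32)

clf = FMClassifier().fit(X, y)  # up to max_iter epochs
clf = clf.partial_fit(X, y)     # one extra warm-started epoch
print(clf.n_iter_, clf.converged_)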