@@ -751,32 +751,9 @@ def __init__(
751751 current_random_state = check_random_state (self .random_state )
752752
753753 self ._distance_correction = None
754-
755- if callable (metric ):
756- _distance_func = metric
757- elif metric in pynnd_dist .named_distances :
758- if metric in pynnd_dist .fast_distance_alternatives :
759- _distance_func = pynnd_dist .fast_distance_alternatives [metric ]["dist" ]
760- self ._distance_correction = pynnd_dist .fast_distance_alternatives [
761- metric
762- ]["correction" ]
763- else :
764- _distance_func = pynnd_dist .named_distances [metric ]
765- else :
766- raise ValueError ("Metric is neither callable, " + "nor a recognised string" )
767-
768- # Create a partial function for distances with arguments
769- if len (self ._dist_args ) > 0 :
770- dist_args = self ._dist_args
771-
772- @numba .njit ()
773- def _partial_dist_func (x , y ):
774- return _distance_func (x , y , * dist_args )
775-
776- self ._distance_func = _partial_dist_func
777- else :
778- self ._distance_func = _distance_func
779-
754+
755+ self ._set_distance_func ()
756+
780757 if metric in (
781758 "cosine" ,
782759 "dot" ,
@@ -967,6 +944,32 @@ def _partial_dist_func(ind1, data1, ind2, data2):
967944
968945 numba .set_num_threads (self ._original_num_threads )
969946
947+ def _set_distance_func (self ):
948+ if callable (self .metric ):
949+ _distance_func = self .metric
950+ elif self .metric in pynnd_dist .named_distances :
951+ if self .metric in pynnd_dist .fast_distance_alternatives :
952+ _distance_func = pynnd_dist .fast_distance_alternatives [self .metric ]["dist" ]
953+ self ._distance_correction = pynnd_dist .fast_distance_alternatives [
954+ self .metric
955+ ]["correction" ]
956+ else :
957+ _distance_func = pynnd_dist .named_distances [self .metric ]
958+ else :
959+ raise ValueError ("Metric is neither callable, " + "nor a recognised string" )
960+
961+ # Create a partial function for distances with arguments
962+ if len (self ._dist_args ) > 0 :
963+ dist_args = self ._dist_args
964+
965+ @numba .njit ()
966+ def _partial_dist_func (x , y ):
967+ return _distance_func (x , y , * dist_args )
968+
969+ self ._distance_func = _partial_dist_func
970+ else :
971+ self ._distance_func = _distance_func
972+
970973 def __getstate__ (self ):
971974 if not hasattr (self , "_search_graph" ):
972975 self ._init_search_graph ()
@@ -985,6 +988,7 @@ def __getstate__(self):
985988
986989 def __setstate__ (self , d ):
987990 self .__dict__ = d
991+ self ._set_distance_func ()
988992 self ._search_forest = tuple (
989993 [renumbaify_tree (tree ) for tree in d ["_search_forest" ]]
990994 )
0 commit comments