Skip to content

Commit ac9cc12

Browse files
authored
Merge pull request #240 from smartIU/master
Reset distance function in __setstate__
2 parents 22a21bd + 1c0d2e9 commit ac9cc12

File tree

1 file changed

+30
-26
lines changed

1 file changed

+30
-26
lines changed

pynndescent/pynndescent_.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -751,32 +751,9 @@ def __init__(
751751
current_random_state = check_random_state(self.random_state)
752752

753753
self._distance_correction = None
754-
755-
if callable(metric):
756-
_distance_func = metric
757-
elif metric in pynnd_dist.named_distances:
758-
if metric in pynnd_dist.fast_distance_alternatives:
759-
_distance_func = pynnd_dist.fast_distance_alternatives[metric]["dist"]
760-
self._distance_correction = pynnd_dist.fast_distance_alternatives[
761-
metric
762-
]["correction"]
763-
else:
764-
_distance_func = pynnd_dist.named_distances[metric]
765-
else:
766-
raise ValueError("Metric is neither callable, " + "nor a recognised string")
767-
768-
# Create a partial function for distances with arguments
769-
if len(self._dist_args) > 0:
770-
dist_args = self._dist_args
771-
772-
@numba.njit()
773-
def _partial_dist_func(x, y):
774-
return _distance_func(x, y, *dist_args)
775-
776-
self._distance_func = _partial_dist_func
777-
else:
778-
self._distance_func = _distance_func
779-
754+
755+
self._set_distance_func()
756+
780757
if metric in (
781758
"cosine",
782759
"dot",
@@ -967,6 +944,32 @@ def _partial_dist_func(ind1, data1, ind2, data2):
967944

968945
numba.set_num_threads(self._original_num_threads)
969946

947+
def _set_distance_func(self):
948+
if callable(self.metric):
949+
_distance_func = self.metric
950+
elif self.metric in pynnd_dist.named_distances:
951+
if self.metric in pynnd_dist.fast_distance_alternatives:
952+
_distance_func = pynnd_dist.fast_distance_alternatives[self.metric]["dist"]
953+
self._distance_correction = pynnd_dist.fast_distance_alternatives[
954+
self.metric
955+
]["correction"]
956+
else:
957+
_distance_func = pynnd_dist.named_distances[self.metric]
958+
else:
959+
raise ValueError("Metric is neither callable, " + "nor a recognised string")
960+
961+
# Create a partial function for distances with arguments
962+
if len(self._dist_args) > 0:
963+
dist_args = self._dist_args
964+
965+
@numba.njit()
966+
def _partial_dist_func(x, y):
967+
return _distance_func(x, y, *dist_args)
968+
969+
self._distance_func = _partial_dist_func
970+
else:
971+
self._distance_func = _distance_func
972+
970973
def __getstate__(self):
971974
if not hasattr(self, "_search_graph"):
972975
self._init_search_graph()
@@ -985,6 +988,7 @@ def __getstate__(self):
985988

986989
def __setstate__(self, d):
987990
self.__dict__ = d
991+
self._set_distance_func()
988992
self._search_forest = tuple(
989993
[renumbaify_tree(tree) for tree in d["_search_forest"]]
990994
)

0 commit comments

Comments
 (0)