-
Notifications
You must be signed in to change notification settings - Fork 229
[MRG] FIX Fix LMNN rollback #101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
1f5b74e
a18fa13
5fb5ec2
8f5f62d
2a8b8b5
ec7e497
6cc3984
754de7a
dbb1ec8
9a61bc7
8b3f584
4e88dcf
a60a1b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -95,12 +95,16 @@ def fit(self, X, y): | |
L = self.L_ | ||
objective = np.inf | ||
|
||
# we initialize the roll back | ||
L_old = L.copy() | ||
G_old = G.copy() | ||
df_old = df.copy() | ||
a1_old = [a.copy() for a in a1] | ||
a2_old = [a.copy() for a in a2] | ||
objective_old = objective | ||
|
||
# main loop | ||
for it in xrange(1, self.max_iter): | ||
df_old = df.copy() | ||
a1_old = [a.copy() for a in a1] | ||
a2_old = [a.copy() for a in a2] | ||
objective_old = objective | ||
# Compute pairwise distances under current metric | ||
Lx = L.dot(self.X_.T).T | ||
g0 = _inplace_paired_L2(*Lx[impostors]) | ||
|
@@ -158,14 +162,25 @@ def fit(self, X, y): | |
if delta_obj > 0: | ||
# we're getting worse... roll back! | ||
learn_rate /= 2.0 | ||
L = L_old | ||
G = G_old | ||
df = df_old | ||
a1 = a1_old | ||
a2 = a2_old | ||
objective = objective_old | ||
else: | ||
# update L | ||
L -= learn_rate * 2 * L.dot(G) | ||
learn_rate *= 1.01 | ||
# We did good. We store this point as reference in case we do | ||
# worse next time. | ||
objective_old = objective | ||
L_old = L.copy() | ||
G_old = G.copy() | ||
df_old = df.copy() | ||
a1_old = [a.copy() for a in a1] | ||
a2_old = [a.copy() for a in a2] | ||
# we update L and will see in the next iteration if it does indeed | ||
# better | ||
L -= learn_rate * 2 * L.dot(G) | ||
learn_rate *= 1.01 | ||
[Review comment 1] This seems wrong to me, though it was also wrong before this PR. Reading through the Shogun implementation (here and here), they don't do any rolling back. In one of the reference Matlab implementations (here) they do the — [comment truncated in original page]

[Review comment 2] Inverting the order of things as suggested would also improve the readability of the code, I think.

[Review comment 3] I just submitted a new commit inverting the order of things. I have commented the code to make it clearer: basically it starts from a reference point and tries the next possible updated point, retrying with a smaller learning rate if needed, until it finds a new reference point which has a better objective. |
||
|
||
# check for convergence | ||
if it > self.min_iter and abs(delta_obj) < self.convergence_tol: | ||
|
@@ -177,7 +192,7 @@ def fit(self, X, y): | |
print("LMNN didn't converge in %d steps." % self.max_iter) | ||
|
||
# store the last L | ||
self.L_ = L | ||
self.L_ = L_old | ||
self.n_iter_ = it | ||
return self | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,11 @@ | ||
import unittest | ||
import re | ||
import sys | ||
import numpy as np | ||
from six.moves import xrange | ||
from sklearn.externals.six import StringIO | ||
from sklearn.metrics import pairwise_distances | ||
from sklearn.datasets import load_iris | ||
from sklearn.datasets import load_iris, make_classification | ||
from numpy.testing import assert_array_almost_equal | ||
|
||
from metric_learn import ( | ||
|
@@ -70,6 +73,40 @@ def test_iris(self): | |
csep = class_separation(lmnn.transform(), self.iris_labels) | ||
self.assertLess(csep, 0.25) | ||
|
||
def test_convergence_simple_example(self): | ||
# LMNN should converge on this simple example, which it did not with | ||
# this issue: https://github.com/metric-learn/metric-learn/issues/88 | ||
X, y = make_classification(random_state=0) | ||
old_stdout = sys.stdout | ||
sys.stdout = StringIO() | ||
lmnn = LMNN(verbose=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
try: | ||
lmnn.fit(X, y) | ||
finally: | ||
out = sys.stdout.getvalue() | ||
sys.stdout.close() | ||
sys.stdout = old_stdout | ||
assert ("LMNN converged with objective" in out) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Style nit: the parens here aren't needed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
def test_no_twice_same_objective(self): | ||
# test that the objective function never has twice the same value | ||
# see https://github.com/metric-learn/metric-learn/issues/88 | ||
X, y = make_classification(random_state=0) | ||
old_stdout = sys.stdout | ||
sys.stdout = StringIO() | ||
lmnn = LMNN(verbose=True) | ||
try: | ||
lmnn.fit(X, y) | ||
finally: | ||
out = sys.stdout.getvalue() | ||
sys.stdout.close() | ||
sys.stdout = old_stdout | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might be nice to have this logic in a context manager. See https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed. Also I just found there exist something in pytest that seem to do the job quite nicely: https://docs.pytest.org/en/3.2.1/capture.html |
||
lines = re.split("\n+", out) | ||
objectives = [re.search("\d* (?:(\d*.\d*))[ | -]\d*.\d*", s) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a comment explaining this regular expression, with an example of what it should be matching. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
for s in lines] | ||
objectives = [match.group(1) for match in objectives if match is not None] | ||
assert len(objectives) == len(set(objectives)) | ||
|
||
|
||
class TestSDML(MetricTestCase): | ||
def test_iris(self): | ||
|
[Review comment] See comment https://github.com/metric-learn/metric-learn/pull/101/files#r288573007 and #201.