Skip to content
31 changes: 23 additions & 8 deletions metric_learn/lmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,16 @@ def fit(self, X, y):
L = self.L_
objective = np.inf

# we initialize the roll back
L_old = L.copy()
G_old = G.copy()
df_old = df.copy()
a1_old = [a.copy() for a in a1]
a2_old = [a.copy() for a in a2]
objective_old = objective

# main loop
for it in xrange(1, self.max_iter):
df_old = df.copy()
a1_old = [a.copy() for a in a1]
a2_old = [a.copy() for a in a2]
objective_old = objective
# Compute pairwise distances under current metric
Lx = L.dot(self.X_.T).T
g0 = _inplace_paired_L2(*Lx[impostors])
Expand Down Expand Up @@ -158,14 +162,25 @@ def fit(self, X, y):
if delta_obj > 0:
# we're getting worse... roll back!
learn_rate /= 2.0
L = L_old
G = G_old
df = df_old
a1 = a1_old
a2 = a2_old
objective = objective_old
else:
# update L
L -= learn_rate * 2 * L.dot(G)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

learn_rate *= 1.01
# We did good. We store this point as reference in case we do
# worse next time.
objective_old = objective
L_old = L.copy()
G_old = G.copy()
df_old = df.copy()
a1_old = [a.copy() for a in a1]
a2_old = [a.copy() for a in a2]
# we update L and will see in the next iteration if it does indeed
# better
L -= learn_rate * 2 * L.dot(G)
learn_rate *= 1.01
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems wrong to me, though it was also wrong before this PR.

Reading through the Shogun implementation here and here, they don't do any rolling back.
They do the L update unconditionally in gradient_step, then compute the objective value for the current iteration, then do the learning rate update based on the change in objective.

In one of the reference Matlab implementations here they do the L update first, then optionally roll back to a saved state when updating the step size.

--
So I think the correct fix would be to move the L update to the # do the gradient update section, after computing the new G, using the existing learning rate. Then, if the objective didn't improve we can halve the learning rate and roll back to the last good state (including L and G). Otherwise, we just grow the learning rate by 1% and carry on.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inverting the order of things as suggested would also improve the readability of the code I think

Copy link
Member Author

@wdevazelhes wdevazelhes Jul 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just submitted a new commit inverting the order of things. I have commented the code to make it clearer: basically, it starts from a reference point and tries the next candidate update, retrying with a smaller learning rate if needed, until it finds a new reference point that has a better objective.


# check for convergence
if it > self.min_iter and abs(delta_obj) < self.convergence_tol:
Expand All @@ -177,7 +192,7 @@ def fit(self, X, y):
print("LMNN didn't converge in %d steps." % self.max_iter)

# store the last L
self.L_ = L
self.L_ = L_old
self.n_iter_ = it
return self

Expand Down
39 changes: 38 additions & 1 deletion test/metric_learn_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import unittest
import re
import sys
import numpy as np
from six.moves import xrange
from sklearn.externals.six import StringIO
from sklearn.metrics import pairwise_distances
from sklearn.datasets import load_iris
from sklearn.datasets import load_iris, make_classification
from numpy.testing import assert_array_almost_equal

from metric_learn import (
Expand Down Expand Up @@ -70,6 +73,40 @@ def test_iris(self):
csep = class_separation(lmnn.transform(), self.iris_labels)
self.assertLess(csep, 0.25)

def test_convergence_simple_example(self):
# Regression test for https://github.com/metric-learn/metric-learn/issues/88:
# before the roll-back fix, LMNN oscillated and hit max_iter instead of
# converging on this easy synthetic dataset. We capture the verbose fit
# log and check it ends with the "converged" message.
# LMNN should converge on this simple example, which it did not with
# this issue: https://github.com/metric-learn/metric-learn/issues/88
X, y = make_classification(random_state=0)
# Temporarily redirect stdout so the verbose output of fit() can be read back.
old_stdout = sys.stdout
sys.stdout = StringIO()
lmnn = LMNN(verbose=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use python_LMNN here and the other test, to not fail when the shogun version is available.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

try:
lmnn.fit(X, y)
finally:
# Always restore stdout, even if fit() raises, so later tests see real stdout.
out = sys.stdout.getvalue()
sys.stdout.close()
sys.stdout = old_stdout
assert ("LMNN converged with objective" in out)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style nit: the parens here aren't needed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def test_no_twice_same_objective(self):
# Regression test for https://github.com/metric-learn/metric-learn/issues/88:
# an oscillating optimizer revisits the same objective value; after the
# fix, every iteration printed in verbose mode must have a distinct
# objective. We parse the captured verbose log to check uniqueness.
# test that the objective function never has twice the same value
# see https://github.com/metric-learn/metric-learn/issues/88
X, y = make_classification(random_state=0)
# Temporarily redirect stdout to capture the per-iteration verbose log.
old_stdout = sys.stdout
sys.stdout = StringIO()
lmnn = LMNN(verbose=True)
try:
lmnn.fit(X, y)
finally:
# Always restore stdout, even if fit() raises.
out = sys.stdout.getvalue()
sys.stdout.close()
sys.stdout = old_stdout
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nice to have this logic in a context manager. See https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. Also I just found there exist something in pytest that seem to do the job quite nicely: https://docs.pytest.org/en/3.2.1/capture.html
But it breaks a bit the structure of unittest classes... If it is important to keep the previous structure I'll use the context manager
Tell me what you think

lines = re.split("\n+", out)
# NOTE(review): the regex below assumes each verbose iteration line looks
# like "<it> <objective> <something>" separated by spaces/pipes/dashes;
# group(1) captures the objective column. Confirm against LMNN's actual
# verbose print format.
objectives = [re.search("\d* (?:(\d*.\d*))[ | -]\d*.\d*", s)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment explaining this regular expression, with an example of what it should be matching.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

for s in lines]
# Keep only the lines that actually contained an objective value.
objectives = [match.group(1) for match in objectives if match is not None]
# No duplicate objective values => the optimizer never revisited a state.
assert len(objectives) == len(set(objectives))


class TestSDML(MetricTestCase):
def test_iris(self):
Expand Down