Skip to content
156 changes: 85 additions & 71 deletions metric_learn/lmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,83 +90,49 @@ def fit(self, X, y):
a1[nn_idx] = np.array([])
a2[nn_idx] = np.array([])

# initialize gradient and L
G = dfG * reg + df * (1-reg)
# initialize L
L = self.L_
objective = np.inf

# main loop
for it in xrange(1, self.max_iter):
df_old = df.copy()
a1_old = [a.copy() for a in a1]
a2_old = [a.copy() for a in a2]
objective_old = objective
# Compute pairwise distances under current metric
Lx = L.dot(self.X_.T).T
g0 = _inplace_paired_L2(*Lx[impostors])
Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:,None,:])
g1,g2 = Ni[impostors]

# compute the gradient
total_active = 0
for nn_idx in reversed(xrange(k)):
act1 = g0 < g1[:,nn_idx]
act2 = g0 < g2[:,nn_idx]
total_active += act1.sum() + act2.sum()

if it > 1:
plus1 = act1 & ~a1[nn_idx]
minus1 = a1[nn_idx] & ~act1
plus2 = act2 & ~a2[nn_idx]
minus2 = a2[nn_idx] & ~act2
else:
plus1 = act1
plus2 = act2
minus1 = np.zeros(0, dtype=int)
minus2 = np.zeros(0, dtype=int)

targets = target_neighbors[:,nn_idx]
PLUS, pweight = _count_edges(plus1, plus2, impostors, targets)
df += _sum_outer_products(self.X_, PLUS[:,0], PLUS[:,1], pweight)
MINUS, mweight = _count_edges(minus1, minus2, impostors, targets)
df -= _sum_outer_products(self.X_, MINUS[:,0], MINUS[:,1], mweight)

in_imp, out_imp = impostors
df += _sum_outer_products(self.X_, in_imp[minus1], out_imp[minus1])
df += _sum_outer_products(self.X_, in_imp[minus2], out_imp[minus2])

df -= _sum_outer_products(self.X_, in_imp[plus1], out_imp[plus1])
df -= _sum_outer_products(self.X_, in_imp[plus2], out_imp[plus2])

a1[nn_idx] = act1
a2[nn_idx] = act2

# do the gradient update
assert not np.isnan(df).any()
G = dfG * reg + df * (1-reg)

# compute the objective function
objective = total_active * (1-reg)
objective += G.flatten().dot(L.T.dot(L).flatten())
assert not np.isnan(objective)
delta_obj = objective - objective_old
# first iteration: we compute variables (including objective and gradient)
# at initialization point
G, objective, total_active, df, a1, a2 = (
self._loss_grad(L, dfG, impostors, 1, k, reg, target_neighbors, df, a1,
a2))

for it in xrange(2, self.max_iter):
# then at each iteration, we try to find a value of L that has better
# objective than the previous L, following the gradient:
while True:
# the next point next_L to try out is found by a gradient step
L_next = L - 2 * learn_rate * G
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here it should be 2*learn_rate*L.dot(G), not 2*learn_rate*G... (see #201)

# we compute the objective at next point
# we copy variables that can be modified by _loss_grad, because if we
# retry we don't want to modify them several times
(G_next, objective_next, total_active_next, df_next, a1_next,
a2_next) = (
self._loss_grad(L_next, dfG, impostors, it, k, reg,
target_neighbors, df.copy(), list(a1), list(a2)))
assert not np.isnan(objective)
delta_obj = objective_next - objective
if delta_obj > 0:
# if we did not find a better objective, we retry with an L closer to
# the starting point, by decreasing the learning rate (making the
# gradient step smaller)
learn_rate /= 2
else:
# otherwise, if we indeed found a better obj, we get out of the loop
break
# when the better L is found (and the related variables), we set the
# old variables to these new ones before next iteration and we
# slightly increase the learning rate
L = L_next
G, df, objective, total_active, a1, a2 = (
G_next, df_next, objective_next, total_active_next, a1_next, a2_next)
learn_rate *= 1.01

if self.verbose:
print(it, objective, delta_obj, total_active, learn_rate)

# update step size
if delta_obj > 0:
# we're getting worse... roll back!
learn_rate /= 2.0
df = df_old
a1 = a1_old
a2 = a2_old
objective = objective_old
else:
# update L
L -= learn_rate * 2 * L.dot(G)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

learn_rate *= 1.01

# check for convergence
if it > self.min_iter and abs(delta_obj) < self.convergence_tol:
if self.verbose:
Expand All @@ -181,6 +147,54 @@ def fit(self, X, y):
self.n_iter_ = it
return self

def _loss_grad(self, L, dfG, impostors, it, k, reg, target_neighbors, df, a1,
               a2):
  """Evaluate the objective and the gradient matrix G at the transform L.

  Parameters
  ----------
  L : transformation applied to ``self.X_`` before distances are computed.
  dfG : term blended with ``df`` as ``dfG * reg + df * (1 - reg)``.
  impostors : indexable that unpacks into two index arrays
      (presumably shape (2, n_impostors) -- confirm against the caller).
  it : iteration number. When ``it > 1`` only the entries whose active
      status changed since the previous call are applied to ``df``;
      otherwise every currently active entry is added.
  k : number of neighbor columns visited (from ``k - 1`` down to ``0``).
  reg : weight of ``dfG``; ``1 - reg`` weights ``df`` and the active count.
  target_neighbors : index array whose column ``nn_idx`` gives the targets
      used for that neighbor rank.
  df : accumulator updated with ``+=`` / ``-=``; callers that may retry
      should pass a copy.
  a1, a2 : per-neighbor lists of the previous active masks; entry
      ``nn_idx`` is overwritten with the new mask.

  Returns
  -------
  tuple ``(G, objective, total_active, df, a1, a2)``.
  """
  # Distances under the current transform (helper semantics assumed from
  # its name: paired L2 distances -- confirm in _inplace_paired_L2).
  projected = L.dot(self.X_.T).T
  imp_dist = _inplace_paired_L2(*projected[impostors])
  margins = 1 + _inplace_paired_L2(projected[target_neighbors],
                                   projected[:, None, :])
  margin_in, margin_out = margins[impostors]

  in_imp, out_imp = impostors
  incremental = it > 1
  data_weight = 1 - reg

  # compute the gradient
  total_active = 0
  for nn_idx in reversed(xrange(k)):
    # An entry is active when its impostor distance is below the margin.
    act1 = imp_dist < margin_in[:, nn_idx]
    act2 = imp_dist < margin_out[:, nn_idx]
    total_active += act1.sum() + act2.sum()

    if incremental:
      # Only apply what changed relative to the stored masks.
      prev1, prev2 = a1[nn_idx], a2[nn_idx]
      plus1 = act1 & ~prev1
      minus1 = prev1 & ~act1
      plus2 = act2 & ~prev2
      minus2 = prev2 & ~act2
    else:
      # First evaluation: everything active is new, nothing is removed.
      plus1, plus2 = act1, act2
      minus1 = np.zeros(0, dtype=int)
      minus2 = np.zeros(0, dtype=int)

    targets = target_neighbors[:, nn_idx]
    gained, gained_w = _count_edges(plus1, plus2, impostors, targets)
    df += _sum_outer_products(self.X_, gained[:, 0], gained[:, 1], gained_w)
    lost, lost_w = _count_edges(minus1, minus2, impostors, targets)
    df -= _sum_outer_products(self.X_, lost[:, 0], lost[:, 1], lost_w)

    # Impostor-pair terms: deactivated entries are added back, newly
    # activated ones are subtracted (same order as the accumulations above
    # so floating-point results are unchanged).
    for mask in (minus1, minus2):
      df += _sum_outer_products(self.X_, in_imp[mask], out_imp[mask])
    for mask in (plus1, plus2):
      df -= _sum_outer_products(self.X_, in_imp[mask], out_imp[mask])

    # Remember the masks for the next incremental update.
    a1[nn_idx] = act1
    a2[nn_idx] = act2

  # do the gradient update
  assert not np.isnan(df).any()
  G = dfG * reg + df * data_weight

  # compute the objective function
  metric = L.T.dot(L)
  objective = total_active * data_weight + G.flatten().dot(metric.flatten())
  return G, objective, total_active, df, a1, a2

def _select_targets(self):
target_neighbors = np.empty((self.X_.shape[0], self.k), dtype=int)
for label in self.labels_:
Expand Down
33 changes: 32 additions & 1 deletion test/metric_learn_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
import unittest
import re
import pytest
import numpy as np
from scipy.optimize import check_grad
Expand Down Expand Up @@ -76,6 +76,37 @@ def test_iris(self):
self.assertLess(csep, 0.25)


def test_convergence_simple_example(capsys):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does capsys get passed in from?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be found automatically by pytest (it is one of the built-in fixtures) and run when test_convergence_simple_example runs. I verified this by running pytest with -v (verbose): these tests pass when they should (and fail when the error message is modified).

# LMNN should converge on this simple example, which it did not with
# this issue: https://github.com/metric-learn/metric-learn/issues/88
X, y = make_classification(random_state=0)
lmnn = python_LMNN(verbose=True)
lmnn.fit(X, y)
out, _ = capsys.readouterr()
assert "LMNN converged with objective" in out


def test_no_twice_same_objective(capsys):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these be methods of TestLMNN?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I read here (pytest-dev/pytest#2504 (comment)) that pytest fixtures cannot be integrated with unittest classes, so I extracted the tests from the class hierarchy. But I agree that it is not ideal. They propose a workaround in the link, so maybe that would be better? (Add these lines to TestLMNN, move the test into TestLMNN, and replace capsys with self.capsys in the test.)

@pytest.fixture(autouse=True)
def capsys(self, capsys):
  self.capsys = capsys

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. The current solution is fine, then.

# test that the objective function never has twice the same value
# see https://github.com/metric-learn/metric-learn/issues/88
X, y = make_classification(random_state=0)
lmnn = python_LMNN(verbose=True)
lmnn.fit(X, y)
out, _ = capsys.readouterr()
lines = re.split("\n+", out)
# we get only objectives from each line:
# the regexp matches a float that follows an integer (the iteration
# number), and which is followed by a (signed) float (delta obj). It
# matches for instance:
# 3 **1113.7665747189938** -3.182774197440267 46431.0200999999999998e-06
objectives = [re.search("\d* (?:(\d*.\d*))[ | -]\d*.\d*", s)
for s in lines]
objectives = [match.group(1) for match in objectives if match is not None]
# we remove the last element because it can be equal to the penultimate
# if the last gradient update is null
assert len(objectives[:-1]) == len(set(objectives[:-1]))


class TestSDML(MetricTestCase):
def test_iris(self):
# Note: this is a flaky test, which fails for certain seeds.
Expand Down