Merged
12 changes: 12 additions & 0 deletions metric_learn/_util.py
@@ -405,3 +405,15 @@ def validate_vector(u, dtype=None):
  if u.ndim > 1:
    raise ValueError("Input vector should be 1-D.")
  return u


def _check_num_dims(n_features, num_dims):
  """Checks that num_dims is less than n_features and deals with the None

Contributor:
typo: "less than"

Member Author:
Thanks, done

  case"""
  if num_dims is None:
    dim = n_features

Contributor:
As a style nitpick, I prefer early-return for cases like this:

if num_dims is None:
  return n_features
if 0 < num_dims <= n_features:
  return num_dims
raise ValueError(...)

Member Author (@wdevazelhes, Apr 19, 2019):
That's better, I agree. Done.

  else:
    if not 0 < num_dims <= n_features:
      raise ValueError('Invalid num_dims, must be in [1, %d]' % n_features)

Contributor:
Some existing code would only warn if num_dims > n_features. I think it's probably better to error out here, but we should keep in mind that this is technically a back-compat break.

Member Author:
That's true. Maybe adding it to the changelog is enough? Something like: "for all the algorithms that have a num_dims parameter, it will now be checked to be between 1 and n_features, where n_features is the number of dimensions of the input space."

Member Author:
I just added it to the release draft.

    dim = num_dims
  return dim
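
For reference, a self-contained sketch of the early-return version suggested in the thread above and adopted by the author (the exact merged code may differ):

def _check_num_dims(n_features, num_dims):
  """Checks that num_dims is less than n_features and deals with the None
  case"""
  if num_dims is None:
    return n_features
  if 0 < num_dims <= n_features:
    return num_dims
  raise ValueError('Invalid num_dims, must be in [1, %d]' % n_features)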
2 changes: 1 addition & 1 deletion metric_learn/base_metric.py
@@ -288,7 +288,7 @@ def get_mahalanobis_matrix(self):

    Returns
    -------
-    M : `numpy.ndarray`, shape=(n_components, n_features)
+    M : `numpy.ndarray`, shape=(n_features, n_features)
        The copy of the learned Mahalanobis matrix.
    """
    return self.transformer_.T.dot(self.transformer_)
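
The corrected shape follows directly from the return line above: if transformer_ (call it L) has shape (n_components, n_features), then L.T.dot(L) has shape (n_features, n_features). A quick check with hypothetical sizes:

import numpy as np

L = np.random.randn(2, 5)  # stand-in for transformer_: (n_components, n_features)
M = L.T.dot(L)
assert M.shape == (5, 5)   # always (n_features, n_features)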
9 changes: 3 additions & 6 deletions metric_learn/lfda.py
@@ -17,6 +17,8 @@
from six.moves import xrange
from sklearn.metrics import pairwise_distances
from sklearn.base import TransformerMixin

from ._util import _check_num_dims
from .base_metric import MahalanobisMixin


@@ -78,12 +80,7 @@ def fit(self, X, y):
    n, d = X.shape
    num_classes = len(unique_classes)

-    if self.num_dims is None:
-      dim = d
-    else:
-      if not 0 < self.num_dims <= d:
-        raise ValueError('Invalid num_dims, must be in [1,%d]' % d)
-      dim = self.num_dims
+    dim = _check_num_dims(d, self.num_dims)

    if self.k is None:
      k = min(7, d - 1)
8 changes: 6 additions & 2 deletions metric_learn/lmnn.py
@@ -16,14 +16,16 @@
from six.moves import xrange
from sklearn.metrics import euclidean_distances
from sklearn.base import TransformerMixin

from ._util import _check_num_dims
from .base_metric import MahalanobisMixin


# commonality between LMNN implementations
class _base_LMNN(MahalanobisMixin, TransformerMixin):
  def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
               regularization=0.5, convergence_tol=0.001, use_pca=True,
-               verbose=False, preprocessor=None):
+               verbose=False, preprocessor=None, num_dims=None):
    """Initialize the LMNN object.

    Parameters
@@ -46,6 +48,7 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
    self.convergence_tol = convergence_tol
    self.use_pca = use_pca
    self.verbose = verbose
+    self.num_dims = num_dims
    super(_base_LMNN, self).__init__(preprocessor)


@@ -60,13 +63,14 @@ def fit(self, X, y):
    X, y = self._prepare_inputs(X, y, dtype=float,
                                ensure_min_samples=2)
    num_pts, num_dims = X.shape
+    output_dim = _check_num_dims(num_dims, self.num_dims)
    unique_labels, label_inds = np.unique(y, return_inverse=True)
    if len(label_inds) != num_pts:
      raise ValueError('Must have one label per point.')
    self.labels_ = np.arange(len(unique_labels))
    if self.use_pca:
      warnings.warn('use_pca does nothing for the python_LMNN implementation')
-    self.transformer_ = np.eye(num_dims)
+    self.transformer_ = np.eye(output_dim, num_dims)
    required_k = np.bincount(label_inds).min()
    if self.k > required_k:
      raise ValueError('not enough class labels for specified k'
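
A note on the np.eye change above: with two arguments, np.eye(N, M) returns a rectangular N x M matrix with ones on the main diagonal, so this initialization gives a projection to output_dim dimensions rather than a square identity. A quick illustration with hypothetical sizes:

import numpy as np

print(np.eye(2, 4))  # 2 output dims, 4 input features
# [[1. 0. 0. 0.]
#  [0. 1. 0. 0.]]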
5 changes: 2 additions & 3 deletions metric_learn/nca.py
@@ -14,6 +14,7 @@
from sklearn.utils.fixes import logsumexp
from sklearn.base import TransformerMixin

from ._util import _check_num_dims
from .base_metric import MahalanobisMixin

EPS = np.finfo(float).eps
@@ -63,9 +64,7 @@ def fit(self, X, y):
    """
    X, labels = self._prepare_inputs(X, y, ensure_min_samples=2)
    n, d = X.shape
-    num_dims = self.num_dims
-    if num_dims is None:
-      num_dims = d
+    num_dims = _check_num_dims(d, self.num_dims)

    # Measure the total training time
    train_time = time.time()
12 changes: 2 additions & 10 deletions metric_learn/rca.py
@@ -18,6 +18,7 @@
from sklearn import decomposition
from sklearn.base import TransformerMixin

from ._util import _check_num_dims
from .base_metric import MahalanobisMixin
from .constraints import Constraints

@@ -75,16 +76,7 @@ def _check_dimension(self, rank, X):
                    'You should adjust pca_comps to remove noise and '
                    'redundant information.')

-    if self.num_dims is None:
-      dim = d
-    elif self.num_dims <= 0:
-      raise ValueError('Invalid embedding dimension: must be greater than 0.')
-    elif self.num_dims > d:
-      dim = d
-      warnings.warn('num_dims (%d) must be smaller than '
-                    'the data dimension (%d)' % (self.num_dims, d))
-    else:
-      dim = self.num_dims
+    dim = _check_num_dims(d, self.num_dims)
    return dim
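
As flagged in the review thread on _util.py, RCA is where the back-compat break is most visible: the removed code warned and fell back to d when num_dims > d, whereas the shared check now raises. A small sketch of the new behavior (hypothetical sizes, assuming this branch is installed):

from metric_learn._util import _check_num_dims

# data with d=5 features, user requests 10 output dimensions
try:
  _check_num_dims(5, 10)
except ValueError as e:
  print(e)  # Invalid num_dims, must be in [1, 5]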

  def fit(self, X, chunks):
42 changes: 40 additions & 2 deletions test/test_base_metric.py
@@ -17,8 +17,9 @@ def test_lmnn(self):
    self.assertRegexpMatches(
        str(metric_learn.LMNN()),
        r"(python_)?LMNN\(convergence_tol=0.001, k=3, learn_rate=1e-07, "
-        r"max_iter=1000,\n min_iter=50, preprocessor=None, "
-        r"regularization=0.5, use_pca=True,\n verbose=False\)")
+        r"max_iter=1000,\n min_iter=50, num_dims=None, "
+        r"preprocessor=None, regularization=0.5,\n use_pca=True, "
+        r"verbose=False\)")

  def test_nca(self):
    self.assertEqual(str(metric_learn.NCA()),
@@ -163,5 +164,42 @@ def test_get_metric_works_does_not_raise(estimator, build_dataset):
  assert len(record) == 0


@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
                         ids=ids_metric_learners)
def test_num_dims(estimator, build_dataset):
  """Check that estimators that have a num_dims parameter can use it
  and that it actually works as expected"""
  input_data, labels, _, X = build_dataset()
  model = clone(estimator)

  if hasattr(model, 'num_dims'):
    set_random_state(model)
    model.set_params(num_dims=None)
    model.fit(input_data, labels)
    assert model.transformer_.shape == (X.shape[1], X.shape[1])

    model = clone(estimator)
    set_random_state(model)
    model.set_params(num_dims=X.shape[1] - 1)
    model.fit(input_data, labels)
    assert model.transformer_.shape == (X.shape[1] - 1, X.shape[1])

    model = clone(estimator)
    set_random_state(model)
    model.set_params(num_dims=X.shape[1] + 1)
    with pytest.raises(ValueError) as expected_err:
      model.fit(input_data, labels)
    assert str(expected_err.value) == ('Invalid num_dims, must be in [1, {}]'
                                       .format(X.shape[1]))

    model = clone(estimator)
    set_random_state(model)
    model.set_params(num_dims=0)
    with pytest.raises(ValueError) as expected_err:
      model.fit(input_data, labels)
    assert str(expected_err.value) == ('Invalid num_dims, must be in [1, {}]'
                                       .format(X.shape[1]))

if __name__ == '__main__':
  unittest.main()
5 changes: 1 addition & 4 deletions test/test_mahalanobis_mixin.py
@@ -137,10 +137,7 @@ def test_embed_dim(estimator, build_dataset):
    model.score_pairs(model.transform(X[0, :]))
  assert str(raised_error.value) == err_msg
  # we test that the shape is also OK when doing dimensionality reduction
-  if type(model).__name__ in {'LFDA', 'MLKR', 'NCA', 'RCA'}:
-    # TODO:
-    # avoid this enumeration and rather test if hasattr n_components
-    # as soon as we have made the arguments names as such (issue #167)
+  if hasattr(model, 'num_dims'):
    model.set_params(num_dims=2)
    model.fit(*remove_y_quadruplets(estimator, input_data, labels))
    assert model.transform(X).shape == (X.shape[0], 2)
19 changes: 18 additions & 1 deletion test/test_utils.py
@@ -10,7 +10,7 @@
from metric_learn._util import (check_input, make_context, preprocess_tuples,
                                make_name, preprocess_points,
                                check_collapsed_pairs, validate_vector,
-                                _check_sdp_from_eigen)
+                                _check_sdp_from_eigen, _check_num_dims)
from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
                          LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
                          MMC_Supervised, RCA_Supervised, SDML_Supervised,
@@ -1067,3 +1067,20 @@ def test_check_sdp_from_eigen_positive_err_messages():
  _check_sdp_from_eigen(w, 1.)
  _check_sdp_from_eigen(w, 0.)
  _check_sdp_from_eigen(w, None)


def test__check_num_dims():
  """Checks that _check_num_dims returns what is expected (including the
  errors)"""
  dim = _check_num_dims(5, None)
  assert dim == 5

  dim = _check_num_dims(5, 3)
  assert dim == 3

  with pytest.raises(ValueError) as expected_err:
    _check_num_dims(5, 10)
  assert str(expected_err.value) == 'Invalid num_dims, must be in [1, 5]'

  with pytest.raises(ValueError) as expected_err:
    _check_num_dims(5, 0)
  assert str(expected_err.value) == 'Invalid num_dims, must be in [1, 5]'
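
Taken together, the intended user-facing behavior might look like this (a sketch with hypothetical data; any estimator exposing num_dims should behave the same way, per test_num_dims above):

import numpy as np
from metric_learn import LMNN

X = np.random.randn(100, 5)        # 100 points, 5 features
y = np.random.randint(0, 2, 100)   # two classes
lmnn = LMNN(k=3, num_dims=2).fit(X, y)
assert lmnn.transformer_.shape == (2, X.shape[1])  # (num_dims, n_features)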