diff --git a/aeon/clustering/tests/test_base.py b/aeon/clustering/tests/test_base.py index 362ce21e52..3d968946e4 100644 --- a/aeon/clustering/tests/test_base.py +++ b/aeon/clustering/tests/test_base.py @@ -2,8 +2,43 @@ import numpy as np import numpy.random +import pytest from aeon.clustering.base import BaseClusterer +from aeon.testing.mock_estimators import MockCluster + + +def test_correct_input(): + """Tests errors raised with wrong inputs: X and/or y.""" + dummy = MockCluster() + + # list of strings X + X = ["list", "of", "invalid", "test", "strings"] + msg1 = r"ERROR passed a list containing " + with pytest.raises(TypeError, match=msg1): + dummy.fit(X) + + # dict X + X = { + 0: "invalid", + 1: "input", + 2: "dict", + } + msg2 = r"ERROR passed input of type " + with pytest.raises(TypeError, match=msg2): + dummy.fit(X) + + # 2d list of int X + X = [[1, 1, 1], [1, 3, 4]] + msg3 = r"lists should either 2D numpy arrays or pd.DataFrames" + with pytest.raises(TypeError, match=msg3): + dummy.fit(X) + + # correct X + X = np.random.randn(5, 5) + dummy.fit(X) + assert (dummy.predict(X)).shape == (5,) + assert (dummy.predict_proba(X)).shape == (5,) class _TestClusterer(BaseClusterer): diff --git a/aeon/testing/mock_estimators/__init__.py b/aeon/testing/mock_estimators/__init__.py index 2bd1a8727d..e81543bdf7 100644 --- a/aeon/testing/mock_estimators/__init__.py +++ b/aeon/testing/mock_estimators/__init__.py @@ -8,6 +8,7 @@ "MockClassifierPredictProba", "MockClassifierFullTags", "MockClassifierMultiTestParams", + "MockCluster", "MockDeepClusterer", "MockSegmenter", "SupervisedMockSegmenter", @@ -27,7 +28,7 @@ MockClassifierMultiTestParams, MockClassifierPredictProba, ) -from aeon.testing.mock_estimators._mock_clusterers import MockDeepClusterer +from aeon.testing.mock_estimators._mock_clusterers import MockCluster, MockDeepClusterer from aeon.testing.mock_estimators._mock_collection_transformers import ( MockCollectionTransformer, ) diff --git a/aeon/testing/mock_estimators/_mock_clusterers.py b/aeon/testing/mock_estimators/_mock_clusterers.py index d846e1ba18..ff79c192e7 100644 --- a/aeon/testing/mock_estimators/_mock_clusterers.py +++ b/aeon/testing/mock_estimators/_mock_clusterers.py @@ -1,8 +1,32 @@ import numpy as np +from aeon.clustering.base import BaseClusterer from aeon.clustering.deep_learning.base import BaseDeepClusterer +class MockCluster(BaseClusterer): + """Mock Cluster for testing base class fit/predict.""" + + def __init__(self, n_clusters: int = None): + super().__init__(n_clusters) + + def _fit(self, X): + """Mock fit.""" + return self + + def _predict(self, X): + """Mock predict.""" + return np.zeros(len(X)) + + def _predict_proba(self, X): + """Mock predict proba.""" + y = np.random.rand(len(X)) + return y + + def _score(self, X, y): + return np.random.randn(1) + + class MockDeepClusterer(BaseDeepClusterer): """Mock Deep Clusterer for testing empty base deep class save utilities."""