Skip to content

Commit 44d8e0f

Browse files
authored
Fix unleveled bug with bert (#138)
1 parent a344014 commit 44d8e0f

File tree

4 files changed

+42
-18
lines changed

4 files changed

+42
-18
lines changed

hiclass/HierarchicalClassifier.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,9 @@ def _pre_fit(self, X, y, sample_weight):
161161
)
162162
else:
163163
self.X_ = np.array(X)
164-
self.y_ = np.array(y)
164+
self.y_ = check_array(
165+
make_leveled(y), dtype=None, ensure_2d=False, allow_nd=True
166+
)
165167

166168
if sample_weight is not None:
167169
self.sample_weight_ = _check_sample_weight(sample_weight, X)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"ray",
4646
"shap==0.44.1",
4747
"xarray==2023.1.0",
48+
"bert-sklearn @ git+https://github.com/charles9n/bert-sklearn.git#egg=bert-sklearn",
4849
],
4950
}
5051

tests/test_LocalClassifierPerParentNode.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
import networkx as nx
55
import numpy as np
66
import pytest
7-
from numpy.testing import assert_array_equal, assert_array_almost_equal
7+
from bert_sklearn import BertClassifier
8+
from numpy.testing import assert_array_almost_equal, assert_array_equal
89
from scipy.sparse import csr_matrix
910
from sklearn.exceptions import NotFittedError
1011
from sklearn.linear_model import LogisticRegression
1112
from sklearn.utils.estimator_checks import parametrize_with_checks
1213
from sklearn.utils.validation import check_is_fitted
14+
1315
from hiclass import LocalClassifierPerParentNode
1416
from hiclass._calibration.Calibrator import _Calibrator
1517
from hiclass.HierarchicalClassifier import make_leveled
@@ -393,3 +395,37 @@ def test_fit_calibrate_predict_predict_proba_bert():
393395
classifier.calibrate(x, y)
394396
classifier.predict(x)
395397
classifier.predict_proba(x)
398+
399+
400+
# Note: bert only works with the local classifier per parent node
401+
# It does not have the attribute classes_, which is necessary
402+
# for the local classifiers per level and per node
403+
def test_fit_bert():
404+
bert = BertClassifier()
405+
clf = LocalClassifierPerParentNode(
406+
local_classifier=bert,
407+
bert=True,
408+
)
409+
x = ["Batman", "rorschach"]
410+
y = [
411+
["Action", "The Dark Night"],
412+
["Action", "Watchmen"],
413+
]
414+
clf.fit(x, y)
415+
check_is_fitted(clf)
416+
predictions = clf.predict(x)
417+
assert_array_equal(y, predictions)
418+
419+
420+
def test_bert_unleveled():
421+
clf = LocalClassifierPerParentNode(
422+
local_classifier=BertClassifier(),
423+
bert=True,
424+
)
425+
x = ["Batman", "Jaws"]
426+
y = [["Action", "The Dark Night"], ["Thriller"]]
427+
ground_truth = [["Action", "The Dark Night"], ["Action", "The Dark Night"]]
428+
clf.fit(x, y)
429+
check_is_fitted(clf)
430+
predictions = clf.predict(x)
431+
assert_array_equal(ground_truth, predictions)

tests/test_LocalClassifiers.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from sklearn.utils.validation import check_is_fitted
1111

1212
from hiclass import (
13-
LocalClassifierPerNode,
1413
LocalClassifierPerLevel,
14+
LocalClassifierPerNode,
1515
LocalClassifierPerParentNode,
1616
)
1717
from hiclass.ConstantClassifier import ConstantClassifier
@@ -75,21 +75,6 @@ def test_empty_levels(empty_levels, classifier):
7575
assert_array_equal(ground_truth, predictions)
7676

7777

78-
@pytest.mark.parametrize("classifier", classifiers)
79-
def test_fit_bert(classifier):
80-
bert = ConstantClassifier()
81-
clf = classifier(
82-
local_classifier=bert,
83-
bert=True,
84-
)
85-
X = ["Text 1", "Text 2"]
86-
y = ["a", "a"]
87-
clf.fit(X, y)
88-
check_is_fitted(clf)
89-
predictions = clf.predict(X)
90-
assert_array_equal(y, predictions)
91-
92-
9378
@pytest.mark.parametrize("classifier", classifiers)
9479
def test_knn(classifier):
9580
knn = KNeighborsClassifier(

0 commit comments

Comments
 (0)