3 changes: 3 additions & 0 deletions .gitignore
@@ -47,3 +47,6 @@ ENV/
 env.bak/
 venv.bak/
 
+# Coverage Report
+.coverage
+/htmlcov
7 changes: 7 additions & 0 deletions Makefile
@@ -1,3 +1,10 @@
+test-coverage:
+	pytest --cov=llmtune tests/
+
+fix-format:
+	ruff check --fix
+	ruff format
+
 build-release:
 	rm -rf dist
 	rm -rf build
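The two new targets wrap the dev workflow: `make test-coverage` runs the suite under pytest with coverage collection for the `llmtune` package, and `make fix-format` applies ruff's autofixes before reformatting.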
14 changes: 7 additions & 7 deletions llmtune/qa/generics.py
@@ -48,7 +48,7 @@ def __init__(
         self.ground_truths = ground_truths
         self.model_preds = model_preds
 
-        self.test_results = {}
+        self._results = {}
 
     @staticmethod
     def from_csv(file_path: str, tests: List[LLMQaTest]) -> "LLMTestSuite":
@@ -60,29 +60,29 @@ def from_csv(file_path: str, tests: List[LLMQaTest]) -> "LLMTestSuite":
 
     def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]:
         test_results = {}
-        for test in zip(self.tests):
+        for test in self.tests:
             metrics = []
             for prompt, ground_truth, model_pred in zip(self.prompts, self.ground_truths, self.model_preds):
                 metrics.append(test.get_metric(prompt, ground_truth, model_pred))
             test_results[test.test_name] = metrics
 
-        self.test_results = test_results
+        self._results = test_results
        return test_results
 
     @property
     def test_results(self):
-        return self.test_results if self.test_results else self.run_tests()
+        return self._results if self._results else self.run_tests()
 
     def print_test_results(self):
-        result_dictionary = self.test_results()
+        result_dictionary = self.test_results
         column_data = {key: list(result_dictionary[key]) for key in result_dictionary}
         mean_values = {key: statistics.mean(column_data[key]) for key in column_data}
         median_values = {key: statistics.median(column_data[key]) for key in column_data}
         stdev_values = {key: statistics.stdev(column_data[key]) for key in column_data}
         # Use the RichUI class to display the table
-        RichUI.display_table(result_dictionary, mean_values, median_values, stdev_values)
+        RichUI.qa_display_table(result_dictionary, mean_values, median_values, stdev_values)
 
     def save_test_results(self, path: str):
         # TODO: save these!
-        resultant_dataframe = pd.DataFrame(self.test_results())
+        resultant_dataframe = pd.DataFrame(self.test_results)
         resultant_dataframe.to_csv(path, index=False)
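Two of these changes are behavioral fixes rather than renames for style. `for test in zip(self.tests)` wrapped each test in a 1-tuple, so `test.test_name` and `test.get_metric(...)` would raise `AttributeError`; iterating `self.tests` directly restores the intended objects. The `_results` rename is also load-bearing: once `test_results` is a read-only property, assigning `self.test_results = {}` in `__init__` would itself raise `AttributeError`, so the cache must live under a different name. A minimal sketch of the resulting lazy-compute-then-cache pattern (illustrative names only, not the full `LLMTestSuite`):

```python
class SuiteSketch:
    """Illustrative sketch of the caching-property pattern used above."""

    def __init__(self):
        self._results = {}  # private cache; cannot share the property's name

    def run_tests(self):
        # Stand-in for the real per-test metric loop.
        self._results = {"summary_length": [3.0, 5.0]}
        return self._results

    @property
    def test_results(self):
        # First access computes and caches; later accesses hit the cache.
        return self._results if self._results else self.run_tests()


suite = SuiteSketch()
print(suite.test_results)  # computes once
print(suite.test_results)  # served from the cache
```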
24 changes: 12 additions & 12 deletions llmtune/qa/qa_tests.py
@@ -9,7 +9,7 @@
 from rouge_score import rouge_scorer
 from transformers import DistilBertModel, DistilBertTokenizer
 
-from llmtune.qa.generics import LLMQaTest, TestRegistry
+from llmtune.qa.generics import LLMQaTest, QaTestRegistry
 
 
 model_name = "distilbert-base-uncased"
@@ -21,7 +21,7 @@
 nltk.download("averaged_perceptron_tagger")
 
 
-@TestRegistry.register("summary_length")
+@QaTestRegistry.register("summary_length")
 class LengthTest(LLMQaTest):
     @property
     def test_name(self) -> str:
@@ -31,7 +31,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
         return abs(len(ground_truth) - len(model_prediction))
 
 
-@TestRegistry.register("jaccard_similarity")
+@QaTestRegistry.register("jaccard_similarity")
 class JaccardSimilarityTest(LLMQaTest):
     @property
     def test_name(self) -> str:
@@ -45,10 +45,10 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
         union_size = len(set_ground_truth.union(set_model_prediction))
 
         similarity = intersection_size / union_size if union_size != 0 else 0
-        return similarity
+        return float(similarity)
 
 
-@TestRegistry.register("dot_product")
+@QaTestRegistry.register("dot_product")
 class DotProductSimilarityTest(LLMQaTest):
     @property
     def test_name(self) -> str:
@@ -64,10 +64,10 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
         embedding_ground_truth = self._encode_sentence(ground_truth)
         embedding_model_prediction = self._encode_sentence(model_prediction)
         dot_product_similarity = np.dot(embedding_ground_truth, embedding_model_prediction)
-        return dot_product_similarity
+        return float(dot_product_similarity)
 
 
-@TestRegistry.register("rouge_score")
+@QaTestRegistry.register("rouge_score")
 class RougeScoreTest(LLMQaTest):
     @property
     def test_name(self) -> str:
@@ -79,7 +79,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
         return float(scores["rouge1"].precision)
 
 
-@TestRegistry.register("word_overlap")
+@QaTestRegistry.register("word_overlap")
 class WordOverlapTest(LLMQaTest):
     @property
     def test_name(self) -> str:
@@ -100,7 +100,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
 
         common_words = words_model_prediction.intersection(words_ground_truth)
         overlap_percentage = (len(common_words) / len(words_ground_truth)) * 100
-        return overlap_percentage
+        return float(overlap_percentage)
 
 
 class PosCompositionTest(LLMQaTest):
@@ -112,7 +112,7 @@ def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
         return round(len(pos_words) / total_words, 2)
 
 
-@TestRegistry.register("verb_percent")
+@QaTestRegistry.register("verb_percent")
 class VerbPercent(PosCompositionTest):
     @property
     def test_name(self) -> str:
@@ -122,7 +122,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> f
         return self._get_pos_percent(model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"])
 
 
-@TestRegistry.register("adjective_percent")
+@QaTestRegistry.register("adjective_percent")
 class AdjectivePercent(PosCompositionTest):
     @property
     def test_name(self) -> str:
@@ -132,7 +132,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> f
         return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"])
 
 
-@TestRegistry.register("noun_percent")
+@QaTestRegistry.register("noun_percent")
 class NounPercent(PosCompositionTest):
     @property
     def test_name(self) -> str:
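All registrations switch from `TestRegistry` to `QaTestRegistry`, matching the renamed class in `llmtune/qa/generics.py`, and the similarity-style tests now coerce their results with `float(...)`, so NumPy scalars (e.g. from `np.dot`) and ints are normalized to plain Python floats before aggregation and CSV export. The registry implementation itself is not part of this diff; a hypothetical sketch of a decorator-based registry with this shape:

```python
from typing import Callable, Dict, Type


class QaTestRegistry:
    """Hypothetical sketch; the real class lives in llmtune/qa/generics.py."""

    registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        def decorator(test_cls: Type) -> Type:
            cls.registry[name] = test_cls  # map a config name to a test class
            return test_cls  # the class itself is returned unmodified

        return decorator
```

Registering a test then only requires decorating its class, and a lookup such as `QaTestRegistry.registry["rouge_score"]()` can instantiate tests by name.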
2 changes: 1 addition & 1 deletion llmtune/ui/rich_ui.py
@@ -182,7 +182,7 @@ def qa_found():
         pass
 
     @staticmethod
-    def qa_display_table(self, result_dictionary, mean_values, median_values, stdev_values):
+    def qa_display_table(result_dictionary, mean_values, median_values, stdev_values):
         # Create a table
         table = Table(show_header=True, header_style="bold", title="Test Results")
 
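Dropping `self` here is a real fix: a `@staticmethod` receives no implicit instance, so with the old signature a call like `RichUI.qa_display_table(results, means, medians, stdevs)` would have bound `results` to `self` and failed with a `TypeError` about a missing `stdev_values` argument. A tiny illustration:

```python
class Demo:
    @staticmethod
    def show(a, b):
        # No implicit first argument: `a` is the caller's first positional arg.
        print(a, b)


Demo.show(1, 2)  # OK; with a stray `self` parameter this same call would
                 # raise TypeError: missing 1 required positional argument
```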
152 changes: 151 additions & 1 deletion poetry.lock


17 changes: 17 additions & 0 deletions pyproject.toml
@@ -67,6 +67,9 @@ shellingham = "^1.5.4"
 
 [tool.poetry.group.dev.dependencies]
 ruff = "~0.3.5"
+pytest = "^8.1.1"
+pytest-cov = "^5.0.0"
+pytest-mock = "^3.14.0"
 
 [build-system]
 requires = ["poetry-core", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
@@ -92,4 +95,18 @@ indent-style = "space"
 skip-magic-trailing-comma = false
 line-ending = "auto"
 
+[tool.coverage.run]
+omit = [
+    # Ignore UI for now as this might change quite often
+    "llmtune/ui/*",
+    "llmtune/utils/rich_print_utils.py"
+]
+
+[tool.coverage.report]
+skip_empty = true
+exclude_also = [
+    "pass",
+]
+
+[tool.pytest.ini_options]
+addopts = "--cov=llmtune --cov-report term-missing"
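With `addopts` in place, a bare `pytest` run is equivalent to `pytest --cov=llmtune --cov-report term-missing`, so both the `make test-coverage` target and CI get a per-file list of uncovered lines. In the coverage config, `skip_empty` hides files with no executable statements (such as bare `__init__.py` files) from the report, and `exclude_also = ["pass"]` keeps placeholder `pass` bodies from counting as uncovered; note that `exclude_also` is honored by coverage.py 7.2 and later.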
Empty file added test_utils/__init__.py