diff --git a/.github/workflows/lint_and_format.yml b/.github/workflows/lint_and_format.yml new file mode 100644 index 0000000..fcaf4f4 --- /dev/null +++ b/.github/workflows/lint_and_format.yml @@ -0,0 +1,18 @@ +name: Ruff +on: pull_request +jobs: + lint: + name: Lint, Format, and Commit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: chartboost/ruff-action@v1 + name: Lint + with: + version: 0.3.5 + args: "check --output-format=full --statistics" + - uses: chartboost/ruff-action@v1 + name: Format + with: + version: 0.3.5 + args: "format --check" diff --git a/README.md b/README.md index 9a1c569..bc04d18 100644 --- a/README.md +++ b/README.md @@ -255,3 +255,10 @@ If you would like to contribute to this project, we recommend following the "for 5. Submit a **Pull request** so that we can review your changes NOTE: Be sure to merge the latest from "upstream" before making a pull request! + +### Checklist Before Pull Request (Optional) + +1. Use `ruff check --fix` to check and fix lint errors +2. Use `ruff format` to apply formatting + +NOTE: Ruff linting and formatting checks run automatically via GitHub Actions when a PR is raised. Before raising a PR, it is good practice to check and fix lint errors and to apply formatting locally. diff --git a/llmtune/data/dataset_generator.py b/llmtune/data/dataset_generator.py index 6986289..e8eec92 100644 --- a/llmtune/data/dataset_generator.py +++ b/llmtune/data/dataset_generator.py @@ -1,10 +1,10 @@ import os -from os.path import join, exists +import pickle +import re from functools import partial +from os.path import exists, join from typing import Tuple, Union -import pickle -import re from datasets import Dataset from llmtune.data.ingestor import Ingestor, get_ingestor @@ -61,12 +61,8 @@ def _format_one_prompt(self, example, is_test: bool = False): return example def _format_prompts(self): - self.dataset["train"] = self.dataset["train"].map( - partial(self._format_one_prompt, is_test=False) - ) - self.dataset["test"] = self.dataset["test"].map( - partial(self._format_one_prompt, is_test=True) - ) + self.dataset["train"] = self.dataset["train"].map(partial(self._format_one_prompt, is_test=False)) + self.dataset["test"] = self.dataset["test"].map(partial(self._format_one_prompt, is_test=True)) def get_dataset(self) -> Tuple[Dataset, Dataset]: self._train_test_split() diff --git a/llmtune/data/ingestor.py b/llmtune/data/ingestor.py index 227e4d7..3f06c33 100644 --- a/llmtune/data/ingestor.py +++ b/llmtune/data/ingestor.py @@ -1,9 +1,8 @@ +import csv from abc import ABC, abstractmethod -from functools import partial import ijson -import csv -from datasets import Dataset, load_dataset, concatenate_datasets +from datasets import Dataset, concatenate_datasets, load_dataset def get_ingestor(data_type: str): @@ -14,9 +13,7 @@ def get_ingestor(data_type: str): elif data_type == "huggingface": return HuggingfaceIngestor else: - raise ValueError( - f"'type' must be one of 'json', 'csv', or 'huggingface', you have {data_type}" - ) + raise ValueError(f"'type' must be one of 'json', 'csv', or 'huggingface', you have {data_type}") class Ingestor(ABC): diff --git a/llmtune/finetune/lora.py b/llmtune/finetune/lora.py index da751fe..e74667c 100644 --- a/llmtune/finetune/lora.py +++ b/llmtune/finetune/lora.py @@ -1,31 +1,26 @@ -from os.path import join, exists -from typing import Tuple - -import torch import bitsandbytes as bnb +import torch from datasets import Dataset +from peft import ( + LoraConfig, + get_peft_model, + 
prepare_model_for_kbit_training, +) from transformers import ( - AutoTokenizer, AutoModelForCausalLM, - BitsAndBytesConfig, - TrainingArguments, AutoTokenizer, + BitsAndBytesConfig, ProgressCallback, -) -from peft import ( - prepare_model_for_kbit_training, - get_peft_model, - LoraConfig, + TrainingArguments, ) from trl import SFTTrainer -from rich.console import Console - -from llmtune.pydantic_models.config_model import Config -from llmtune.utils.save_utils import DirectoryHelper from llmtune.finetune.generics import Finetune +from llmtune.pydantic_models.config_model import Config from llmtune.ui.rich_ui import RichUI +from llmtune.utils.save_utils import DirectoryHelper class LoRAFinetune(Finetune): @@ -99,9 +94,7 @@ def _inject_lora(self): self.model = get_peft_model(self.model, self._lora_config) if not self.config.accelerate: - self.optimizer = bnb.optim.Adam8bit( - self.model.parameters(), lr=self._training_args.learning_rate - ) + self.optimizer = bnb.optim.Adam8bit(self.model.parameters(), lr=self._training_args.learning_rate) self.lr_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer) if self.config.accelerate: self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( @@ -132,7 +125,7 @@ def finetune(self, train_dataset: Dataset): **self._sft_args.model_dump(), ) - trainer_stats = self._trainer.train() + self._trainer.train() def save_model(self) -> None: self._trainer.model.save_pretrained(self._weights_path) diff --git a/llmtune/inference/lora.py b/llmtune/inference/lora.py index c585f35..720822c 100644 --- a/llmtune/inference/lora.py +++ b/llmtune/inference/lora.py @@ -1,20 +1,18 @@ +import csv import os from os.path import join from threading import Thread -import csv -from transformers import TextIteratorStreamer -from rich.text import Text +import torch from datasets import Dataset -from transformers import AutoTokenizer, BitsAndBytesConfig from peft import AutoPeftModelForCausalLM -import torch - +from rich.text import Text +from transformers import AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer -from llmtune.pydantic_models.config_model import Config -from llmtune.utils.save_utils import DirectoryHelper from llmtune.inference.generics import Inference +from llmtune.pydantic_models.config_model import Config from llmtune.ui.rich_ui import RichUI +from llmtune.utils.save_utils import DirectoryHelper # TODO: Add type hints please! 
@@ -35,9 +33,7 @@ def __init__( self.device_map = self.config.model.device_map self._weights_path = dir_helper.save_paths.weights - self.model, self.tokenizer = self._get_merged_model( - dir_helper.save_paths.weights - ) + self.model, self.tokenizer = self._get_merged_model(dir_helper.save_paths.weights) def _get_merged_model(self, weights_path: str): # purge VRAM @@ -47,20 +43,14 @@ def _get_merged_model(self, weights_path: str): dtype = ( torch.float16 if self.config.training.training_args.fp16 - else ( - torch.bfloat16 - if self.config.training.training_args.bf16 - else torch.float32 - ) + else (torch.bfloat16 if self.config.training.training_args.bf16 else torch.float32) ) self.model = AutoPeftModelForCausalLM.from_pretrained( weights_path, torch_dtype=dtype, device_map=self.device_map, - quantization_config=( - BitsAndBytesConfig(**self.config.model.bitsandbytes.model_dump()) - ), + quantization_config=(BitsAndBytesConfig(**self.config.model.bitsandbytes.model_dump())), ) """TODO: figure out multi-gpu @@ -70,9 +60,7 @@ def _get_merged_model(self, weights_path: str): model = self.model.merge_and_unload() - tokenizer = AutoTokenizer.from_pretrained( - self._weights_path, device_map=self.device_map - ) + tokenizer = AutoTokenizer.from_pretrained(self._weights_path, device_map=self.device_map) return model, tokenizer @@ -83,13 +71,11 @@ def infer_all(self): # inference loop for idx, (prompt, label) in enumerate(zip(prompts, labels)): - RichUI.inference_ground_truth_display( - f"Generating on test set: {idx+1}/{len(prompts)}", prompt, label - ) + RichUI.inference_ground_truth_display(f"Generating on test set: {idx+1}/{len(prompts)}", prompt, label) try: result = self.infer_one(prompt) - except: + except Exception: continue results.append((prompt, label, result)) @@ -103,9 +89,7 @@ def infer_all(self): writer.writerow(row) def infer_one(self, prompt: str) -> str: - input_ids = self.tokenizer( - prompt, return_tensors="pt", truncation=True - ).input_ids.cuda() + input_ids = self.tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda() # stream processor streamer = TextIteratorStreamer( @@ -115,9 +99,7 @@ def infer_one(self, prompt: str) -> str: timeout=60, # 60 sec timeout for generation; to handle OOM errors ) - generation_kwargs = dict( - input_ids=input_ids, streamer=streamer, **self.config.inference.model_dump() - ) + generation_kwargs = dict(input_ids=input_ids, streamer=streamer, **self.config.inference.model_dump()) thread = Thread(target=self.model.generate, kwargs=generation_kwargs) thread.start() diff --git a/llmtune/pydantic_models/config_model.py b/llmtune/pydantic_models/config_model.py index e2f5617..755c60c 100644 --- a/llmtune/pydantic_models/config_model.py +++ b/llmtune/pydantic_models/config_model.py @@ -1,27 +1,21 @@ -from typing import Literal, Union, List, Dict, Optional -from pydantic import BaseModel, FilePath, validator, Field - -from huggingface_hub.utils import validate_repo_id +from typing import List, Literal, Optional, Union import torch +from pydantic import BaseModel, Field, FilePath, validator + # TODO: Refactor this into multiple files... 
HfModelPath = str + class QaConfig(BaseModel): - llm_tests: Optional[List[str]] = Field([], description = "list of tests that needs to be connected") - + llm_tests: Optional[List[str]] = Field([], description="list of tests that needs to be connected") + class DataConfig(BaseModel): - file_type: Literal["json", "csv", "huggingface"] = Field( - None, description="File type" - ) - path: Union[FilePath, HfModelPath] = Field( - None, description="Path to the file or HuggingFace model" - ) - prompt: str = Field( - None, description="Prompt for the model. Use {} brackets for column name" - ) + file_type: Literal["json", "csv", "huggingface"] = Field(None, description="File type") + path: Union[FilePath, HfModelPath] = Field(None, description="Path to the file or HuggingFace model") + prompt: str = Field(None, description="Prompt for the model. Use {} brackets for column name") prompt_stub: str = Field( None, description="Stub for the prompt; this is injected during training. Use {} brackets for column name", @@ -48,9 +42,7 @@ class DataConfig(BaseModel): class BitsAndBytesConfig(BaseModel): - load_in_8bit: Optional[bool] = Field( - False, description="Enable 8-bit quantization with LLM.int8()" - ) + load_in_8bit: Optional[bool] = Field(False, description="Enable 8-bit quantization with LLM.int8()") llm_int8_threshold: Optional[float] = Field( 6.0, description="Outlier threshold for outlier detection in 8-bit quantization" ) @@ -61,9 +53,7 @@ class BitsAndBytesConfig(BaseModel): False, description="Enable splitting model parts between int8 on GPU and fp32 on CPU", ) - llm_int8_has_fp16_weight: Optional[bool] = Field( - False, description="Run LLM.int8() with 16-bit main weights" - ) + llm_int8_has_fp16_weight: Optional[bool] = Field(False, description="Run LLM.int8() with 16-bit main weights") load_in_4bit: Optional[bool] = Field( True, @@ -86,14 +76,10 @@ class ModelConfig(BaseModel): "NousResearch/Llama-2-7b-hf", description="Path to the model (huggingface repo or local path)", ) - device_map: Optional[str] = Field( - "auto", description="device onto which to load the model" - ) + device_map: Optional[str] = Field("auto", description="device onto which to load the model") quantize: Optional[bool] = Field(False, description="Flag to enable quantization") - bitsandbytes: BitsAndBytesConfig = Field( - None, description="Bits and Bytes configuration" - ) + bitsandbytes: BitsAndBytesConfig = Field(None, description="Bits and Bytes configuration") # @validator("hf_model_ckpt") # def validate_model(cls, v, **kwargs): @@ -116,22 +102,12 @@ def set_device_map_to_none(cls, v, values, **kwargs): class LoraConfig(BaseModel): r: Optional[int] = Field(8, description="Lora rank") - task_type: Optional[str] = Field( - "CAUSAL_LM", description="Base Model task type during training" - ) + task_type: Optional[str] = Field("CAUSAL_LM", description="Base Model task type during training") - lora_alpha: Optional[int] = Field( - 16, description="The alpha parameter for Lora scaling" - ) - bias: Optional[str] = Field( - "none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'" - ) - lora_dropout: Optional[float] = Field( - 0.1, description="The dropout probability for Lora layers" - ) - target_modules: Optional[List[str]] = Field( - None, description="The names of the modules to apply Lora to" - ) + lora_alpha: Optional[int] = Field(16, description="The alpha parameter for Lora scaling") + bias: Optional[str] = Field("none", description="Bias type for Lora. 
Can be 'none', 'all' or 'lora_only'") + lora_dropout: Optional[float] = Field(0.1, description="The dropout probability for Lora layers") + target_modules: Optional[List[str]] = Field(None, description="The names of the modules to apply Lora to") fan_in_fan_out: Optional[bool] = Field( False, description="Flag to indicate if the layer to replace stores weight like (fan_in, fan_out)", @@ -140,9 +116,7 @@ class LoraConfig(BaseModel): None, description="List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint", ) - layers_to_transform: Optional[Union[List[int], int]] = Field( - None, description="The layer indexes to transform" - ) + layers_to_transform: Optional[Union[List[int], int]] = Field(None, description="The layer indexes to transform") layers_pattern: Optional[str] = Field(None, description="The layer pattern name") # rank_pattern: Optional[Dict[str, int]] = Field( # {}, description="The mapping from layer names or regexp expression to ranks" @@ -155,15 +129,9 @@ class LoraConfig(BaseModel): # TODO: Get comprehensive Args! class TrainingArgs(BaseModel): num_train_epochs: Optional[int] = Field(1, description="Number of training epochs") - per_device_train_batch_size: Optional[int] = Field( - 1, description="Batch size per training device" - ) - gradient_accumulation_steps: Optional[int] = Field( - 1, description="Number of steps for gradient accumulation" - ) - gradient_checkpointing: Optional[bool] = Field( - True, description="Flag to enable gradient checkpointing" - ) + per_device_train_batch_size: Optional[int] = Field(1, description="Batch size per training device") + gradient_accumulation_steps: Optional[int] = Field(1, description="Number of steps for gradient accumulation") + gradient_checkpointing: Optional[bool] = Field(True, description="Flag to enable gradient checkpointing") optim: Optional[str] = Field("paged_adamw_32bit", description="Optimizer") logging_steps: Optional[int] = Field(100, description="Number of logging steps") learning_rate: Optional[float] = Field(2.0e-4, description="Learning rate") @@ -172,9 +140,7 @@ class TrainingArgs(BaseModel): fp16: Optional[bool] = Field(False, description="Flag to enable fp16") max_grad_norm: Optional[float] = Field(0.3, description="Maximum gradient norm") warmup_ratio: Optional[float] = Field(0.03, description="Warmup ratio") - lr_scheduler_type: Optional[str] = Field( - "constant", description="Learning rate scheduler type" - ) + lr_scheduler_type: Optional[str] = Field("constant", description="Learning rate scheduler type") # TODO: Get comprehensive Args! 
diff --git a/llmtune/qa/generics.py b/llmtune/qa/generics.py index e44639c..4bd9b0e 100644 --- a/llmtune/qa/generics.py +++ b/llmtune/qa/generics.py @@ -1,9 +1,10 @@ +import statistics from abc import ABC, abstractmethod -from typing import Union, List, Tuple, Dict +from typing import Dict, List, Union + import pandas as pd + from llmtune.ui.rich_ui import RichUI -import statistics -from llmtune.qa.qa_tests import * class LLMQaTest(ABC): @@ -13,9 +14,7 @@ def test_name(self) -> str: pass @abstractmethod - def get_metric( - self, prompt: str, grount_truth: str, model_pred: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[float, int, bool]: pass @@ -32,7 +31,7 @@ def inner_wrapper(wrapped_class): return inner_wrapper @classmethod - def create_tests_from_list(cls, test_name: str) -> List[LLMQaTest]: + def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]: return [cls.create_test(test) for test in test_names] @@ -44,7 +43,6 @@ def __init__( ground_truths: List[str], model_preds: List[str], ) -> None: - self.tests = tests self.prompts = prompts self.ground_truths = ground_truths @@ -56,9 +54,7 @@ def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]: test_results = {} for test in zip(self.tests): metrics = [] - for prompt, ground_truth, model_pred in zip( - self.prompts, self.ground_truths, self.model_preds - ): + for prompt, ground_truth, model_pred in zip(self.prompts, self.ground_truths, self.model_preds): metrics.append(test.get_metric(prompt, ground_truth, model_pred)) test_results[test.test_name] = metrics @@ -71,19 +67,12 @@ def test_results(self): def print_test_results(self): result_dictionary = self.test_results() - column_data = { - key: [value for value in result_dictionary[key]] - for key in result_dictionary - } + column_data = {key: list(result_dictionary[key]) for key in result_dictionary} mean_values = {key: statistics.mean(column_data[key]) for key in column_data} - median_values = { - key: statistics.median(column_data[key]) for key in column_data - } + median_values = {key: statistics.median(column_data[key]) for key in column_data} stdev_values = {key: statistics.stdev(column_data[key]) for key in column_data} # Use the RichUI class to display the table - RichUI.display_table( - result_dictionary, mean_values, median_values, stdev_values - ) + RichUI.display_table(result_dictionary, mean_values, median_values, stdev_values) def save_test_results(self, path: str): # TODO: save these! 
diff --git a/llmtune/qa/qa_tests.py b/llmtune/qa/qa_tests.py index bee45f8..c9f3f6f 100644 --- a/llmtune/qa/qa_tests.py +++ b/llmtune/qa/qa_tests.py @@ -1,14 +1,16 @@ -from llmtune.qa.generics import LLMQaTest -from typing import Union, List, Tuple, Dict -import torch -from transformers import DistilBertModel, DistilBertTokenizer +from typing import List, Union + import nltk import numpy as np -from rouge_score import rouge_scorer +import torch +from nltk import pos_tag from nltk.corpus import stopwords from nltk.tokenize import word_tokenize -from nltk import pos_tag -from llmtune.qa.generics import TestRegistry +from rouge_score import rouge_scorer +from transformers import DistilBertModel, DistilBertTokenizer + +from llmtune.qa.generics import LLMQaTest, TestRegistry + model_name = "distilbert-base-uncased" tokenizer = DistilBertTokenizer.from_pretrained(model_name) @@ -25,9 +27,7 @@ class LengthTest(LLMQaTest): def test_name(self) -> str: return "summary_length" - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: return abs(len(ground_truth) - len(model_prediction)) @@ -37,9 +37,7 @@ class JaccardSimilarityTest(LLMQaTest): def test_name(self) -> str: return "jaccard_similarity" - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: set_ground_truth = set(ground_truth.lower()) set_model_prediction = set(model_prediction.lower()) @@ -62,14 +60,10 @@ def _encode_sentence(self, sentence): outputs = model(**tokens) return outputs.last_hidden_state.mean(dim=1).squeeze().numpy() - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: embedding_ground_truth = self._encode_sentence(ground_truth) embedding_model_prediction = self._encode_sentence(model_prediction) - dot_product_similarity = np.dot( - embedding_ground_truth, embedding_model_prediction - ) + dot_product_similarity = np.dot(embedding_ground_truth, embedding_model_prediction) return dot_product_similarity @@ -79,9 +73,7 @@ class RougeScoreTest(LLMQaTest): def test_name(self) -> str: return "rouge_score" - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True) scores = scorer.score(model_prediction, ground_truth) return float(scores["rouge1"].precision) @@ -99,9 +91,7 @@ def _remove_stopwords(self, text: str) -> str: filtered_text = [word for word in word_tokens if word.lower() not in stop_words] return " ".join(filtered_text) - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: cleaned_model_prediction = self._remove_stopwords(model_prediction) cleaned_ground_truth = self._remove_stopwords(ground_truth) @@ -128,12 +118,8 @@ class VerbPercent(PosCompositionTest): def test_name(self) -> str: return "verb_percent" - def get_metric( - self, prompt: str, ground_truth: str, 
model_prediction: str - ) -> float: - return self._get_pos_percent( - model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"] - ) + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float: + return self._get_pos_percent(model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"]) @TestRegistry.register("adjective_percent") @@ -142,9 +128,7 @@ class AdjectivePercent(PosCompositionTest): def test_name(self) -> str: return "adjective_percent" - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> float: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float: return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"]) @@ -154,9 +138,7 @@ class NounPercent(PosCompositionTest): def test_name(self) -> str: return "noun_percent" - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> float: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float: return self._get_pos_percent(model_prediction, ["NN", "NNS", "NNP", "NNPS"]) diff --git a/llmtune/ui/generics.py b/llmtune/ui/generics.py index 59d5997..87d8fee 100644 --- a/llmtune/ui/generics.py +++ b/llmtune/ui/generics.py @@ -61,7 +61,7 @@ def finetune_found(weights_path: str): pass """ - INFERENCE + INFERENCE """ # Lifecycle functions @@ -91,7 +91,7 @@ def inference_stream_display(text: Text): pass """ - QA + QA """ # Lifecycle functions diff --git a/llmtune/ui/rich_ui.py b/llmtune/ui/rich_ui.py index f138f4c..b8d71bc 100644 --- a/llmtune/ui/rich_ui.py +++ b/llmtune/ui/rich_ui.py @@ -1,15 +1,15 @@ from datasets import Dataset - from rich.console import Console from rich.layout import Layout +from rich.live import Live from rich.panel import Panel from rich.table import Table -from rich.live import Live from rich.text import Text from llmtune.ui.generics import UI from llmtune.utils.rich_print_utils import inject_example_to_rich_layout + console = Console() @@ -25,9 +25,7 @@ def __enter__(self): return self # This allows you to use variables from this context if needed def __exit__(self, exc_type, exc_val, exc_tb): - self.task.__exit__( - exc_type, exc_val, exc_tb - ) # Cleanly exit the console status context + self.task.__exit__(exc_type, exc_val, exc_tb) # Cleanly exit the console status context class LiveContext: @@ -47,9 +45,7 @@ def __enter__(self): return self # This allows you to use variables from this context if needed def __exit__(self, exc_type, exc_val, exc_tb): - self.task.__exit__( - exc_type, exc_val, exc_tb - ) # Cleanly exit the console status context + self.task.__exit__(exc_type, exc_val, exc_tb) # Cleanly exit the console status context def update(self, new_text: Text): self.task.update(new_text) @@ -72,7 +68,7 @@ def during_dataset_creation(message: str, spinner: str): @staticmethod def after_dataset_creation(save_dir: str, train: Dataset, test: Dataset): console.print(f"Dataset Saved at {save_dir}") - console.print(f"Post-Split data size:") + console.print("Post-Split data size:") console.print(f"Train: {len(train)}") console.print(f"Test: {len(test)}") @@ -93,9 +89,7 @@ def dataset_display_one_example(train_row: dict, test_row: dict): ) inject_example_to_rich_layout(layout["train"], "Train Example", train_row) - inject_example_to_rich_layout( - layout["inference"], "Inference Example", test_row - ) + inject_example_to_rich_layout(layout["inference"], "Inference Example", test_row) console.print(layout) @@ -122,14 +116,14 @@ def during_finetune(): @staticmethod def 
after_finetune(): - console.print(f"Finetuning complete!") + console.print("Finetuning complete!") @staticmethod def finetune_found(weights_path: str): console.print(f"Fine-Tuned Model Found at {weights_path}... skipping training") """ - INFERENCE + INFERENCE """ # Lifecycle functions @@ -167,7 +161,7 @@ def inference_stream_display(text: Text): return LiveContext(text) """ - QA + QA """ # Lifecycle functions @@ -188,10 +182,7 @@ def qa_found(): pass @staticmethod - def qa_display_table( - self, result_dictionary, mean_values, median_values, stdev_values - ): - + def qa_display_table(self, result_dictionary, mean_values, median_values, stdev_values): # Create a table table = Table(show_header=True, header_style="bold", title="Test Results") diff --git a/llmtune/utils/ablation_utils.py b/llmtune/utils/ablation_utils.py index 37d9d80..062892e 100644 --- a/llmtune/utils/ablation_utils.py +++ b/llmtune/utils/ablation_utils.py @@ -1,9 +1,6 @@ import copy import itertools -from typing import List, Type, Any, Dict, Optional, Union, Tuple -from typing import get_args, get_origin, get_type_hints - -import yaml +from typing import Dict, Tuple, Union, get_args, get_origin # TODO: organize this a little bit. It's a bit of a mess rn. @@ -14,17 +11,11 @@ """ -def get_types_from_dict( - source_dict: dict, root="", type_dict={} -) -> Dict[str, Tuple[type, type]]: +def get_types_from_dict(source_dict: dict, root="", type_dict={}) -> Dict[str, Tuple[type, type]]: for key, val in source_dict.items(): - if type(val) is not dict: + if not isinstance(val, dict): attr = f"{root}.{key}" if root else key - tp = ( - (type(val), None) - if type(val) is not list - else (type(val), type(val[0])) - ) + tp = (type(val), None) if not isinstance(val, list) else (type(val), type(val[0])) type_dict[attr] = tp else: join_array = [root, key] if root else [key] diff --git a/llmtune/utils/rich_print_utils.py b/llmtune/utils/rich_print_utils.py index 371f742..d39cc4c 100644 --- a/llmtune/utils/rich_print_utils.py +++ b/llmtune/utils/rich_print_utils.py @@ -1,7 +1,7 @@ -from rich.panel import Panel from rich.layout import Layout -from rich.text import Text +from rich.panel import Panel from rich.table import Table +from rich.text import Text def inject_example_to_rich_layout(layout: Layout, layout_name: str, example: dict): diff --git a/llmtune/utils/save_utils.py b/llmtune/utils/save_utils.py index dd8879b..ef5d0e6 100644 --- a/llmtune/utils/save_utils.py +++ b/llmtune/utils/save_utils.py @@ -4,20 +4,19 @@ 2. Check if files are present at various experiment stages """ -import shutil +import hashlib import os -from os.path import exists -import yaml - import re -import hashlib -from functools import cached_property from dataclasses import dataclass +from functools import cached_property +from os.path import exists +import yaml from sqids import Sqids from llmtune.pydantic_models.config_model import Config + NUM_MD5_DIGITS_FOR_SQIDS = 5 # TODO: maybe move consts to a dedicated folder diff --git a/poetry.lock b/poetry.lock index c65c484..db361f2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -291,52 +291,6 @@ files = [ [package.dependencies] scipy = "*" -[[package]] -name = "black" -version = "24.3.0" -description = "The uncompromising code formatter." 
-optional = false -python-versions = ">=3.8" -files = [ - {file = "black-24.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7d5e026f8da0322b5662fa7a8e752b3fa2dac1c1cbc213c3d7ff9bdd0ab12395"}, - {file = "black-24.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9f50ea1132e2189d8dff0115ab75b65590a3e97de1e143795adb4ce317934995"}, - {file = "black-24.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2af80566f43c85f5797365077fb64a393861a3730bd110971ab7a0c94e873e7"}, - {file = "black-24.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:4be5bb28e090456adfc1255e03967fb67ca846a03be7aadf6249096100ee32d0"}, - {file = "black-24.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f1373a7808a8f135b774039f61d59e4be7eb56b2513d3d2f02a8b9365b8a8a9"}, - {file = "black-24.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aadf7a02d947936ee418777e0247ea114f78aff0d0959461057cae8a04f20597"}, - {file = "black-24.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c02e4ea2ae09d16314d30912a58ada9a5c4fdfedf9512d23326128ac08ac3d"}, - {file = "black-24.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:bf21b7b230718a5f08bd32d5e4f1db7fc8788345c8aea1d155fc17852b3410f5"}, - {file = "black-24.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2818cf72dfd5d289e48f37ccfa08b460bf469e67fb7c4abb07edc2e9f16fb63f"}, - {file = "black-24.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4acf672def7eb1725f41f38bf6bf425c8237248bb0804faa3965c036f7672d11"}, - {file = "black-24.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7ed6668cbbfcd231fa0dc1b137d3e40c04c7f786e626b405c62bcd5db5857e4"}, - {file = "black-24.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:56f52cfbd3dabe2798d76dbdd299faa046a901041faf2cf33288bc4e6dae57b5"}, - {file = "black-24.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:79dcf34b33e38ed1b17434693763301d7ccbd1c5860674a8f871bd15139e7837"}, - {file = "black-24.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e19cb1c6365fd6dc38a6eae2dcb691d7d83935c10215aef8e6c38edee3f77abd"}, - {file = "black-24.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65b76c275e4c1c5ce6e9870911384bff5ca31ab63d19c76811cb1fb162678213"}, - {file = "black-24.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:b5991d523eee14756f3c8d5df5231550ae8993e2286b8014e2fdea7156ed0959"}, - {file = "black-24.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c45f8dff244b3c431b36e3224b6be4a127c6aca780853574c00faf99258041eb"}, - {file = "black-24.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6905238a754ceb7788a73f02b45637d820b2f5478b20fec82ea865e4f5d4d9f7"}, - {file = "black-24.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7de8d330763c66663661a1ffd432274a2f92f07feeddd89ffd085b5744f85e7"}, - {file = "black-24.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:7bb041dca0d784697af4646d3b62ba4a6b028276ae878e53f6b4f74ddd6db99f"}, - {file = "black-24.3.0-py3-none-any.whl", hash = "sha256:41622020d7120e01d377f74249e677039d20e6344ff5851de8a10f11f513bf93"}, - {file = "black-24.3.0.tar.gz", hash = "sha256:a0c9c4a0771afc6919578cec71ce82a3e31e054904e7197deacbc9382671c41f"}, -] - -[package.dependencies] -click = ">=8.0.0" -mypy-extensions = ">=0.4.3" -packaging = ">=22.0" -pathspec = ">=0.9.0" -platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} - -[package.extras] 
-colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -uvloop = ["uvloop (>=0.15.2)"] - [[package]] name = "brotli" version = "1.1.0" @@ -2082,17 +2036,6 @@ files = [ qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] testing = ["docopt", "pytest (<6.0.0)"] -[[package]] -name = "pathspec" -version = "0.12.1" -description = "Utility library for gitignore style pattern matching of file paths." -optional = false -python-versions = ">=3.8" -files = [ - {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, - {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, -] - [[package]] name = "peft" version = "0.8.2" @@ -2136,21 +2079,6 @@ files = [ [package.dependencies] ptyprocess = ">=0.5" -[[package]] -name = "platformdirs" -version = "4.2.0" -description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -optional = false -python-versions = ">=3.8" -files = [ - {file = "platformdirs-4.2.0-py3-none-any.whl", hash = "sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068"}, - {file = "platformdirs-4.2.0.tar.gz", hash = "sha256:ef0cc731df711022c174543cb70a9b5bd22e5a9337c8624ef2c2ceb8ddad8768"}, -] - -[package.extras] -docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] - [[package]] name = "prompt-toolkit" version = "3.0.43" @@ -3026,6 +2954,32 @@ nltk = "*" numpy = "*" six = ">=1.14.0" +[[package]] +name = "ruff" +version = "0.3.5" +description = "An extremely fast Python linter and code formatter, written in Rust." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:aef5bd3b89e657007e1be6b16553c8813b221ff6d92c7526b7e0227450981eac"}, + {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:89b1e92b3bd9fca249153a97d23f29bed3992cff414b222fcd361d763fc53f12"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e55771559c89272c3ebab23326dc23e7f813e492052391fe7950c1a5a139d89"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dabc62195bf54b8a7876add6e789caae0268f34582333cda340497c886111c39"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a05f3793ba25f194f395578579c546ca5d83e0195f992edc32e5907d142bfa3"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:dfd3504e881082959b4160ab02f7a205f0fadc0a9619cc481982b6837b2fd4c0"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87258e0d4b04046cf1d6cc1c56fadbf7a880cc3de1f7294938e923234cf9e498"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:712e71283fc7d9f95047ed5f793bc019b0b0a29849b14664a60fd66c23b96da1"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a532a90b4a18d3f722c124c513ffb5e5eaff0cc4f6d3aa4bda38e691b8600c9f"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:122de171a147c76ada00f76df533b54676f6e321e61bd8656ae54be326c10296"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d80a6b18a6c3b6ed25b71b05eba183f37d9bc8b16ace9e3d700997f00b74660b"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a7b6e63194c68bca8e71f81de30cfa6f58ff70393cf45aab4c20f158227d5936"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a759d33a20c72f2dfa54dae6e85e1225b8e302e8ac655773aff22e542a300985"}, + {file = "ruff-0.3.5-py3-none-win32.whl", hash = "sha256:9d8605aa990045517c911726d21293ef4baa64f87265896e491a05461cae078d"}, + {file = "ruff-0.3.5-py3-none-win_amd64.whl", hash = "sha256:dc56bb16a63c1303bd47563c60482a1512721053d93231cf7e9e1c6954395a0e"}, + {file = "ruff-0.3.5-py3-none-win_arm64.whl", hash = "sha256:faeeae9905446b975dcf6d4499dc93439b131f1443ee264055c5716dd947af55"}, + {file = "ruff-0.3.5.tar.gz", hash = "sha256:a067daaeb1dc2baf9b82a32dae67d154d95212080c80435eb052d95da647763d"}, +] + [[package]] name = "safetensors" version = "0.4.2" @@ -4378,4 +4332,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.9, <=3.12" -content-hash = "32b85b0200f6dab57cb9f4c57fb47cd28bbf41e6077f88428331b18239e0f7e1" +content-hash = "5b9a7244db14a2307b67a46d6278a720635eb08ed1a23a5fa643703b189f6053" diff --git a/pyproject.toml b/pyproject.toml index 76e72a6..caf9203 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,8 +43,30 @@ shellingham = "^1.5.4" [tool.poetry.group.dev.dependencies] -black = "^24.3.0" +ruff = "~0.3.5" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.ruff] +lint.ignore = ["C901", "E501", "E741", "F402", "F823" ] +lint.select = ["C", "E", "F", "I", "W"] +line-length = 119 +exclude = [ + "llama2", + "mistral", +] + + +[tool.ruff.lint.isort] +lines-after-imports = 2 +known-first-party = ["llmtune"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" 
+skip-magic-trailing-comma = false +line-ending = "auto" + + diff --git a/toolkit.py b/toolkit.py index 0ccb91b..2f807a5 100644 --- a/toolkit.py +++ b/toolkit.py @@ -1,20 +1,21 @@ -from os import listdir -from os.path import join, exists -import yaml import logging +from os import listdir +from os.path import exists, join -from transformers import utils as hf_utils -from pydantic import ValidationError import torch import typer +import yaml +from pydantic import ValidationError +from transformers import utils as hf_utils -from llmtune.pydantic_models.config_model import Config from llmtune.data.dataset_generator import DatasetGenerator -from llmtune.utils.save_utils import DirectoryHelper -from llmtune.utils.ablation_utils import generate_permutations from llmtune.finetune.lora import LoRAFinetune from llmtune.inference.lora import LoRAInference +from llmtune.pydantic_models.config_model import Config from llmtune.ui.rich_ui import RichUI +from llmtune.utils.ablation_utils import generate_permutations +from llmtune.utils.save_utils import DirectoryHelper + hf_utils.logging.set_verbosity_error() torch._logging.set_logs(all=logging.CRITICAL) @@ -32,7 +33,7 @@ def run_one_experiment(config: Config, config_path: str) -> None: with RichUI.during_dataset_creation("Injecting Values into Prompt", "monkey"): dataset_generator = DatasetGenerator(**config.data.model_dump()) - train_columns = dataset_generator.train_columns + _ = dataset_generator.train_columns test_column = dataset_generator.test_column dataset_path = dir_helper.save_paths.dataset @@ -66,9 +67,8 @@ def run_one_experiment(config: Config, config_path: str) -> None: results_path = dir_helper.save_paths.results results_file_path = join(dir_helper.save_paths.results, "results.csv") if not exists(results_path) or exists(results_file_path): - inference_runner = LoRAInference( - test, test_column, config, dir_helper - ).infer_all() + inference_runner = LoRAInference(test, test_column, config, dir_helper) + inference_runner.infer_all() RichUI.after_inference(results_path) else: RichUI.inference_found(results_path) @@ -90,9 +90,7 @@ def run(config_path: str = "./config.yml") -> None: with open(config_path, "r") as file: config = yaml.safe_load(file) configs = ( - generate_permutations(config, Config) - if config.get("ablation", {}).get("use_ablate", False) - else [config] + generate_permutations(config, Config) if config.get("ablation", {}).get("use_ablate", False) else [config] ) for config in configs: # validate data with pydantic