Merged
Changes from 8 commits
19 changes: 19 additions & 0 deletions .github/workflows/lint_and_format.yml
@@ -0,0 +1,19 @@
+name: Ruff
+on: pull_request
+jobs:
+  lint:
+    name: Lint, Format, and Commit
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: chartboost/ruff-action@v1
+        with:
+          version: 0.3.5
+          args: check --fix --output-format=full --statistics
+      - uses: chartboost/ruff-action@v1
+        with:
+          version: 0.3.5
+          args: format
+      - uses: stefanzweifel/git-auto-commit-action@v5
+        with:
+          commit_message: "lint fixes and formatting by ruff"
7 changes: 7 additions & 0 deletions README.md
@@ -255,3 +255,10 @@ If you would like to contribute to this project, we recommend following the "for
 5. Submit a **Pull request** so that we can review your changes
 
 NOTE: Be sure to merge the latest from "upstream" before making a pull request!
+
+### Checklist Before Pull Request (Optional)
+
+1. Use `ruff check` to check for lint errors
+2. Use `ruff format` to apply formatting
+
+NOTE: Ruff linting and formatting run automatically when a PR is raised, via a GitHub Action (any fixes are applied automatically in a follow-up commit). It is still good practice to check and fix lint errors, and to apply formatting, before opening the PR.
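For reference, both checklist commands can be run locally from the repository root. A minimal sketch, assuming `ruff` is installed in the active environment (e.g. via `pip install ruff`):

```sh
# Check for lint errors and apply safe autofixes
ruff check --fix .

# Apply formatting in place
ruff format .
```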
Contributor:
might be nice to add this into Makefile after
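A minimal sketch of what such a Makefile target could look like (target names are hypothetical; assumes `ruff` is available in the environment, and recipe lines are tab-indented per Make syntax):

```make
# Hypothetical convenience targets mirroring the CI steps above
lint:
	ruff check --fix .

format:
	ruff format .
```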

6 changes: 3 additions & 3 deletions llmtune/data/dataset_generator.py
@@ -1,10 +1,10 @@
 import os
-from os.path import join, exists
+import pickle
+import re
 from functools import partial
+from os.path import exists, join
 from typing import Tuple, Union
-import pickle
 
-import re
 from datasets import Dataset
 
 from llmtune.data.ingestor import Ingestor, get_ingestor
5 changes: 2 additions & 3 deletions llmtune/data/ingestor.py
@@ -1,9 +1,8 @@
+import csv
 from abc import ABC, abstractmethod
 from functools import partial
 
 import ijson
-import csv
-from datasets import Dataset, load_dataset, concatenate_datasets
+from datasets import Dataset, concatenate_datasets, load_dataset
 
 
 def get_ingestor(data_type: str):
29 changes: 12 additions & 17 deletions llmtune/finetune/lora.py
@@ -1,31 +1,26 @@
-from os.path import join, exists
-from typing import Tuple
-
-import torch
+from os.path import join
+
+import bitsandbytes as bnb
+import torch
 from datasets import Dataset
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    prepare_model_for_kbit_training,
+)
 from transformers import (
-    AutoTokenizer,
     AutoModelForCausalLM,
-    BitsAndBytesConfig,
-    TrainingArguments,
+    AutoTokenizer,
+    BitsAndBytesConfig,
     ProgressCallback,
 )
-from peft import (
-    prepare_model_for_kbit_training,
-    get_peft_model,
-    LoraConfig,
-    TrainingArguments,
-)
 from trl import SFTTrainer
 from rich.console import Console
 
 
-from llmtune.pydantic_models.config_model import Config
-from llmtune.utils.save_utils import DirectoryHelper
 from llmtune.finetune.generics import Finetune
+from llmtune.pydantic_models.config_model import Config
 from llmtune.ui.rich_ui import RichUI
+from llmtune.utils.save_utils import DirectoryHelper
 
 
 class LoRAFinetune(Finetune):
@@ -132,7 +127,7 @@ def finetune(self, train_dataset: Dataset):
             **self._sft_args.model_dump(),
         )
 
-        trainer_stats = self._trainer.train()
+        self._trainer.train()
 
     def save_model(self) -> None:
         self._trainer.model.save_pretrained(self._weights_path)
16 changes: 7 additions & 9 deletions llmtune/inference/lora.py
@@ -1,20 +1,18 @@
+import csv
 import os
 from os.path import join
 from threading import Thread
-import csv
 
-from transformers import TextIteratorStreamer
-from rich.text import Text
+import torch
 from datasets import Dataset
-from transformers import AutoTokenizer, BitsAndBytesConfig
 from peft import AutoPeftModelForCausalLM
-import torch
-
+from rich.text import Text
+from transformers import AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
 
-from llmtune.pydantic_models.config_model import Config
-from llmtune.utils.save_utils import DirectoryHelper
 from llmtune.inference.generics import Inference
+from llmtune.pydantic_models.config_model import Config
 from llmtune.ui.rich_ui import RichUI
+from llmtune.utils.save_utils import DirectoryHelper
 
 
 # TODO: Add type hints please!
@@ -89,7 +87,7 @@ def infer_all(self):
 
             try:
                 result = self.infer_one(prompt)
-            except:
+            except Exception:
                 continue
             results.append((prompt, label, result))
 
9 changes: 4 additions & 5 deletions llmtune/pydantic_models/config_model.py
@@ -1,16 +1,15 @@
-from typing import Literal, Union, List, Dict, Optional
-from pydantic import BaseModel, FilePath, validator, Field
-
-from huggingface_hub.utils import validate_repo_id
+from typing import List, Literal, Optional, Union
 
+import torch
+from pydantic import BaseModel, Field, FilePath, validator
 
 
 # TODO: Refactor this into multiple files...
 HfModelPath = str
 
 class QaConfig(BaseModel):
     llm_tests: Optional[List[str]] = Field([], description = "list of tests that needs to be connected")
 
 
 class DataConfig(BaseModel):
     file_type: Literal["json", "csv", "huggingface"] = Field(
14 changes: 6 additions & 8 deletions llmtune/qa/generics.py
@@ -1,9 +1,10 @@
+import statistics
 from abc import ABC, abstractmethod
-from typing import Union, List, Tuple, Dict
+from typing import Dict, List, Union
 
 import pandas as pd
 
 from llmtune.ui.rich_ui import RichUI
-import statistics
 from llmtune.qa.qa_tests import *
 
+
 class LLMQaTest(ABC):
@@ -32,7 +33,7 @@ def inner_wrapper(wrapped_class):
         return inner_wrapper
 
     @classmethod
-    def create_tests_from_list(cls, test_name: str) -> List[LLMQaTest]:
+    def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]:
         return [cls.create_test(test) for test in test_names]
 
 
@@ -71,10 +72,7 @@ def test_results(self):
 
     def print_test_results(self):
         result_dictionary = self.test_results()
-        column_data = {
-            key: [value for value in result_dictionary[key]]
-            for key in result_dictionary
-        }
+        column_data = {key: list(result_dictionary[key]) for key in result_dictionary}
         mean_values = {key: statistics.mean(column_data[key]) for key in column_data}
         median_values = {
             key: statistics.median(column_data[key]) for key in column_data
16 changes: 9 additions & 7 deletions llmtune/qa/qa_tests.py
@@ -1,14 +1,16 @@
-from llmtune.qa.generics import LLMQaTest
-from typing import Union, List, Tuple, Dict
-import torch
-from transformers import DistilBertModel, DistilBertTokenizer
+from typing import List, Union
+
 import nltk
 import numpy as np
-from rouge_score import rouge_scorer
+import torch
+from nltk import pos_tag
 from nltk.corpus import stopwords
 from nltk.tokenize import word_tokenize
-from nltk import pos_tag
-from llmtune.qa.generics import TestRegistry
+from rouge_score import rouge_scorer
+from transformers import DistilBertModel, DistilBertTokenizer
+
+from llmtune.qa.generics import LLMQaTest, TestRegistry
 
 
 model_name = "distilbert-base-uncased"
 tokenizer = DistilBertTokenizer.from_pretrained(model_name)
4 changes: 2 additions & 2 deletions llmtune/ui/generics.py
@@ -61,7 +61,7 @@ def finetune_found(weights_path: str):
         pass
 
     """
-    INFERENCE 
+    INFERENCE
     """
 
     # Lifecycle functions
@@ -91,7 +91,7 @@ def inference_stream_display(text: Text):
         pass
 
     """
-    QA 
+    QA
     """
 
     # Lifecycle functions
12 changes: 6 additions & 6 deletions llmtune/ui/rich_ui.py
@@ -1,15 +1,15 @@
 from datasets import Dataset
 
 from rich.console import Console
 from rich.layout import Layout
+from rich.live import Live
 from rich.panel import Panel
 from rich.table import Table
-from rich.live import Live
 from rich.text import Text
 
 from llmtune.ui.generics import UI
 from llmtune.utils.rich_print_utils import inject_example_to_rich_layout
 
 
 console = Console()
 
 
@@ -72,7 +72,7 @@ def during_dataset_creation(message: str, spinner: str):
     @staticmethod
     def after_dataset_creation(save_dir: str, train: Dataset, test: Dataset):
         console.print(f"Dataset Saved at {save_dir}")
-        console.print(f"Post-Split data size:")
+        console.print("Post-Split data size:")
         console.print(f"Train: {len(train)}")
         console.print(f"Test: {len(test)}")
 
@@ -122,14 +122,14 @@ def during_finetune():
 
     @staticmethod
     def after_finetune():
-        console.print(f"Finetuning complete!")
+        console.print("Finetuning complete!")
 
     @staticmethod
     def finetune_found(weights_path: str):
         console.print(f"Fine-Tuned Model Found at {weights_path}... skipping training")
 
     """
-    INFERENCE 
+    INFERENCE
     """
 
     # Lifecycle functions
@@ -167,7 +167,7 @@ def inference_stream_display(text: Text):
         return LiveContext(text)
 
     """
-    QA 
+    QA
     """
 
     # Lifecycle functions
9 changes: 3 additions & 6 deletions llmtune/utils/ablation_utils.py
@@ -1,9 +1,6 @@
 import copy
 import itertools
-from typing import List, Type, Any, Dict, Optional, Union, Tuple
-from typing import get_args, get_origin, get_type_hints
-
-import yaml
+from typing import Dict, Tuple, Union, get_args, get_origin
 
 
 # TODO: organize this a little bit. It's a bit of a mess rn.
@@ -18,11 +15,11 @@ def get_types_from_dict(
     source_dict: dict, root="", type_dict={}
 ) -> Dict[str, Tuple[type, type]]:
     for key, val in source_dict.items():
-        if type(val) is not dict:
+        if not isinstance(val, dict):
             attr = f"{root}.{key}" if root else key
             tp = (
                 (type(val), None)
-                if type(val) is not list
+                if not isinstance(val, list)
                 else (type(val), type(val[0]))
             )
             type_dict[attr] = tp
4 changes: 2 additions & 2 deletions llmtune/utils/rich_print_utils.py
@@ -1,7 +1,7 @@
-from rich.panel import Panel
 from rich.layout import Layout
-from rich.text import Text
+from rich.panel import Panel
 from rich.table import Table
+from rich.text import Text
 
 
 def inject_example_to_rich_layout(layout: Layout, layout_name: str, example: dict):
11 changes: 5 additions & 6 deletions llmtune/utils/save_utils.py
@@ -4,20 +4,19 @@
 2. Check if files are present at various experiment stages
 """
 
-import shutil
+import hashlib
 import os
-from os.path import exists
-import yaml
-
 import re
-import hashlib
-from functools import cached_property
 from dataclasses import dataclass
+from functools import cached_property
+from os.path import exists
 
+import yaml
 from sqids import Sqids
 
 from llmtune.pydantic_models.config_model import Config
 
 
 NUM_MD5_DIGITS_FOR_SQIDS = 5  # TODO: maybe move consts to a dedicated folder
 
 