Merged
35 commits
eb4f7f4  Implement support for AutoGPTQ for loading GPTQ quantized models. (LaaZa, Apr 30, 2023)
bb1b1b1  Disable cpu support for now. Something in the inference path assumes … (LaaZa, Apr 30, 2023)
0fd4857  Implement offloading and splitting between multiple devices. (LaaZa, Apr 30, 2023)
b173274  Check for quantize_config.json and set wbits and groupsize according … (LaaZa, May 1, 2023)
4f156e4  More robust quantize config parsing and setting for act-order or desc… (LaaZa, May 1, 2023)
728a207  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 2, 2023)
ba4f082  Add support for Triton LLaMA speed optimisations. And re-enable cpu s… (LaaZa, May 2, 2023)
f5651bc  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 2, 2023)
57699be  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 3, 2023)
cd8df64  Moved import to reliably load when autogptq is enabled from the inter… (LaaZa, May 3, 2023)
fe9f236  Ignore pytorch .bins if no well named models found. (LaaZa, May 3, 2023)
af1503b  Add compatibility mode for models using old quantization. (LaaZa, May 4, 2023)
4bd06db  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 4, 2023)
22dd3bd  Switched messages to logging. (LaaZa, May 4, 2023)
bed4653  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 4, 2023)
b815222  Support --checkpoint (LaaZa, May 4, 2023)
faf5db3  Class type detection for tokenizer (oobabooga, May 4, 2023)
e032b40  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 5, 2023)
08464cd  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 7, 2023)
5f4b817  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 9, 2023)
e65ff96  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 9, 2023)
a9b5703  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 11, 2023)
b176477  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 12, 2023)
57a4808  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 12, 2023)
df1e100  Apply cpu memory fix to AutoGPTQ_loader (LaaZa, May 12, 2023)
d212d60  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 13, 2023)
78e56de  Update for AutoGPTQ (LaaZa, May 14, 2023)
4ee76e6  Merge remote-tracking branch 'origin/AutoGPTQ' into AutoGPTQ (LaaZa, May 14, 2023)
96c2a6a  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 14, 2023)
7378332  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 15, 2023)
74d68a4  Support .pt models (LaaZa, May 15, 2023)
b44fd4c  Merge remote-tracking branch 'origin/AutoGPTQ' into AutoGPTQ (LaaZa, May 15, 2023)
a3f6ec9  Merge branch 'oobabooga:main' into AutoGPTQ (LaaZa, May 16, 2023)
d26ed82  Fix the model search with the .pt change (LaaZa, May 16, 2023)
f0ef6e5  Merge branch 'main' into LaaZa-AutoGPTQ (oobabooga, Jun 2, 2023)
modules/AutoGPTQ_loader.py (40 changes: 25 additions & 15 deletions)
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-from auto_gptq import AutoGPTQForCausalLM
+from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 
 import modules.shared as shared
 from modules.logging_colors import logger
@@ -10,35 +10,45 @@
 def load_quantized(model_name):
     path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
     pt_path = None
-    use_safetensors = False
 
     # Find the model checkpoint
-    for ext in ['.safetensors', '.pt', '.bin']:
-        found = list(path_to_model.glob(f"*{ext}"))
-        if len(found) > 0:
-            if len(found) > 1:
-                logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
-
-            pt_path = found[-1]
-            if ext == '.safetensors':
-                use_safetensors = True
-
-            break
+    if shared.args.checkpoint:
+        pt_path = Path(shared.args.checkpoint)
+    else:
+        for ext in ['.safetensors', '.pt', '.bin']:
+            found = list(path_to_model.glob(f"*{ext}"))
+            if len(found) > 0:
+                if len(found) > 1:
+                    logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
+
+                pt_path = found[-1]
+                break
 
     if pt_path is None:
         logger.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
         return
 
+    use_safetensors = pt_path.suffix == '.safetensors'
+    if not (path_to_model / "quantize_config.json").exists():
+        quantize_config = BaseQuantizeConfig(
+            bits=bits if (bits := shared.args.wbits) > 0 else 4,
+            group_size=gs if (gs := shared.args.groupsize) > 0 else -1,
+            desc_act=shared.args.desc_act
+        )
+    else:
+        quantize_config = None
+
     # Define the params for AutoGPTQForCausalLM.from_quantized
     params = {
         'model_basename': pt_path.stem,
         'device': "cuda:0" if not shared.args.cpu else "cpu",
         'use_triton': shared.args.triton,
         'use_safetensors': use_safetensors,
         'trust_remote_code': shared.args.trust_remote_code,
-        'max_memory': get_max_memory_dict()
+        'max_memory': get_max_memory_dict(),
+        'quantize_config': quantize_config
     }
 
-    logger.warning(f"The AutoGPTQ params are: {params}")
+    logger.info(f"The AutoGPTQ params are: {params}")
     model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)
     return model
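
For readers who want to see what the new loading path boils down to, here is a minimal standalone sketch. It is not the PR's own code path: it hard-codes values that the real loader pulls from the CLI flags in modules/shared.py, models/my-gptq-model is a hypothetical model folder, and it assumes a single .safetensors checkpoint.

from pathlib import Path

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

# Hypothetical local model folder; the webui resolves this from its model directory settings.
path_to_model = Path("models/my-gptq-model")

# Mirror the loader's fallback when the model ships no quantize_config.json:
# 4-bit weights, no group size, desc_act off (what --wbits/--groupsize/--desc_act control).
if (path_to_model / "quantize_config.json").exists():
    quantize_config = None
else:
    quantize_config = BaseQuantizeConfig(bits=4, group_size=-1, desc_act=False)

# Assume a single .safetensors checkpoint; the loader also accepts .pt and .bin.
pt_path = next(path_to_model.glob("*.safetensors"))

model = AutoGPTQForCausalLM.from_quantized(
    path_to_model,
    model_basename=pt_path.stem,
    device="cuda:0",
    use_triton=False,
    use_safetensors=True,
    quantize_config=quantize_config
)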
modules/shared.py (1 change: 1 addition & 0 deletions)
@@ -142,6 +142,7 @@ def str2bool(v):
 # AutoGPTQ
 parser.add_argument('--autogptq', action='store_true', help='Use AutoGPTQ for loading quantized models instead of the internal GPTQ loader.')
 parser.add_argument('--triton', action='store_true', help='Use triton.')
+parser.add_argument('--desc_act', action='store_true', help='For models that don\'t have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.')
 
 # FlexGen
 parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
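
As a usage note, the new flag is meant to be combined with --autogptq when a model ships without a quantize_config.json. A hypothetical invocation, assuming the webui's usual server.py entry point and the pre-existing --model, --wbits and --groupsize options that the loader reads through shared.args (my-gptq-model is a placeholder folder name):

python server.py --autogptq --wbits 4 --groupsize 128 --desc_act --model my-gptq-model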