
Commit 98ae830

kartikayk authored and Thomas Capelle committed
Inference (meta-pytorch#619)
1 parent 0217bfa commit 98ae830

24 files changed (+483, -907 lines)

docs/source/api_ref_utilities.rst

Lines changed: 0 additions & 10 deletions
@@ -67,16 +67,6 @@ Data
 
 .. _gen_label:
 
-Generation
-----------
-
-.. autosummary::
-    :toctree: generated/
-    :nosignatures:
-
-    generation.GenerationUtils
-    generation.generate_from_prompt
-
 
 Miscellaneous
 -------------

recipes/alpaca_generate.py

Lines changed: 0 additions & 86 deletions
This file was deleted.

recipes/configs/alpaca_generate.yaml

Lines changed: 0 additions & 22 deletions
This file was deleted.

recipes/configs/generate.yaml

Lines changed: 31 additions & 0 deletions
# Model arguments
model:
  _component_: torchtune.models.llama2.llama2_13b

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-13b-hf/
  checkpoint_files: [
    pytorch_model-00001-of-00003.bin,
    pytorch_model-00002-of-00003.bin,
    pytorch_model-00003-of-00003.bin
  ]
  output_dir: /tmp/Llama-2-13b-hf/
  model_type: LLAMA2

device: cuda
dtype: bf16

seed: 1234

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/Llama-2-13b-hf/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: "Hello, my name is"
max_new_tokens: 300
temperature: 0.8
top_k: 300

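For orientation, here is a minimal sketch of how a config like the one above gets resolved at runtime. It assumes the YAML is loaded with OmegaConf and handed to the config.instantiate helper that the new recipe below calls; the file path and variable names are illustrative, not part of this commit.

    # Sketch only (not part of this commit): resolve the YAML above into objects.
    from omegaconf import OmegaConf
    from torchtune import config

    cfg = OmegaConf.load("recipes/configs/generate.yaml")  # path added by this commit

    # Each node carrying a _component_ key is instantiated into the named torchtune
    # object; the remaining keys on that node become constructor keyword arguments.
    tokenizer = config.instantiate(cfg.tokenizer)         # llama2_tokenizer(path=...)
    checkpointer = config.instantiate(cfg.checkpointer)   # FullModelHFCheckpointer(...)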
recipes/generate.py

Lines changed: 101 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
from typing import Any, Dict

import torch
from omegaconf import DictConfig

from torch import nn

from torchtune import config, utils

logger = utils.get_logger("DEBUG")


class InferenceRecipe:
    """
    Recipe for generating tokens from a dense Transformer-based LLM.

    Currently this recipe supports single-GPU generation only. Speculative
    decoding is not supported.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(dtype=cfg.dtype)

        utils.set_seed(seed=cfg.seed)

    def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
        checkpointer = config.instantiate(checkpointer_cfg)
        checkpoint_dict = checkpointer.load_checkpoint()
        return checkpoint_dict

    def setup(self, cfg: DictConfig) -> None:
        ckpt_dict = self.load_checkpoint(cfg.checkpointer)
        self._model = self._setup_model(
            model_cfg=cfg.model,
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
        )
        self._tokenizer = config.instantiate(cfg.tokenizer)

    def _setup_model(
        self,
        model_cfg: DictConfig,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(model_cfg)

        model.load_state_dict(model_state_dict)

        # Validate that the model was loaded with the expected dtype.
        utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
        logger.info(f"Model is initialized with precision {self._dtype}.")

        # Ensure the cache is set up on the right device.
        with self._device:
            model.setup_caches(max_batch_size=1, dtype=self._dtype)

        return model

    @torch.no_grad()
    def generate(self, cfg: DictConfig):
        tokens = self._tokenizer.encode(cfg.prompt, add_bos=True, add_eos=False)
        prompt = torch.tensor(tokens, dtype=torch.int, device=self._device)

        t0 = time.perf_counter()
        generated_tokens = utils.generate(
            model=self._model,
            prompt=prompt,
            max_generated_tokens=cfg.max_new_tokens,
            temperature=cfg.temperature,
            top_k=cfg.top_k,
            eos_id=self._tokenizer.eos_id,
        )
        t = time.perf_counter() - t0

        logger.info(self._tokenizer.decode(generated_tokens))

        tokens_generated = len(generated_tokens) - prompt.size(0)
        tokens_sec = tokens_generated / t
        logger.info(
            f"Time for inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
        )
        logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


@config.parse
def main(cfg: DictConfig) -> None:
    recipe = InferenceRecipe(cfg=cfg)
    recipe.setup(cfg=cfg)
    recipe.generate(cfg=cfg)


if __name__ == "__main__":
    sys.exit(main())

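The @config.parse entry point handles command-line parsing, but the recipe can also be driven programmatically. A rough sketch under that assumption, mirroring the three calls in main() and the keys in recipes/configs/generate.yaml; the prompt override is illustrative only.

    # Illustrative only: build the config by hand and run the recipe steps directly.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("recipes/configs/generate.yaml")
    cfg.prompt = "Write a short story about a robot."  # hypothetical override

    recipe = InferenceRecipe(cfg=cfg)
    recipe.setup(cfg=cfg)     # loads the checkpoint, builds the model and tokenizer
    recipe.generate(cfg=cfg)  # logs the decoded output, tokens/sec, and peak memory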
tests/recipes/test_alpaca_generate.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

tests/recipes/utils.py

Lines changed: 0 additions & 2 deletions
@@ -48,7 +48,6 @@ def llama2_test_config(max_batch_size: Optional[int] = None) -> List[str]:
         "model.max_seq_len=2048",
         "model.norm_eps=1e-5",
         "model.num_kv_heads=8",
-        f"model.max_batch_size={max_batch_size if max_batch_size else 'null'}",
     ]


@@ -75,7 +74,6 @@ def lora_llama2_test_config(
         "model.max_seq_len=2048",
         "model.norm_eps=1e-5",
         "model.num_kv_heads=8",
-        f"model.max_batch_size={max_batch_size if max_batch_size else 'null'}",
         f"model.lora_rank={lora_rank}",
         f"model.lora_alpha={lora_alpha}",
         "model.lora_dropout=0.0",

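The override strings returned by these test helpers are dotlist-style key=value pairs. As a hedged illustration (the exact merge step the test harness uses is not shown in this diff), such a list can be materialized into a nested config with OmegaConf:

    # Hypothetical illustration: turn dotlist-style overrides into a nested config.
    from omegaconf import OmegaConf

    overrides = ["model.max_seq_len=2048", "model.norm_eps=1e-5", "model.num_kv_heads=8"]
    cfg = OmegaConf.from_dotlist(overrides)
    print(cfg.model.max_seq_len)  # 2048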