
Commit 1071105

adding lm-eval test harness (#1371)
Adds explicit testing for lm-eval, although these tests don't currently trigger in CI. We should find a faster way to run them.

Signed-off-by: Peter St. John <[email protected]>
1 parent 6d4e8bb commit 1071105


5 files changed (+130, -18 lines)


bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 13 additions & 9 deletions
@@ -30,12 +30,12 @@
 
 
 AUTO_MAP = {
-    "AutoConfig": "llama3_nv.NVLlamaConfig",
-    "AutoModel": "llama3_nv.NVLlamaModel",
-    "AutoModelForCausalLM": "llama3_nv.NVLlamaForCausalLM",
-    "AutoModelForSequenceClassification": "llama3_nv.NVLlamaForSequenceClassification",
-    "AutoModelForQuestionAnswering": "llama3_nv.NVLlamaForQuestionAnswering",
-    "AutoModelForTokenClassification": "llama3_nv.NVLlamaForTokenClassification",
+    "AutoConfig": "modeling_llama_te.NVLlamaConfig",
+    "AutoModel": "modeling_llama_te.NVLlamaModel",
+    "AutoModelForCausalLM": "modeling_llama_te.NVLlamaForCausalLM",
+    "AutoModelForSequenceClassification": "modeling_llama_te.NVLlamaForSequenceClassification",
+    "AutoModelForQuestionAnswering": "modeling_llama_te.NVLlamaForQuestionAnswering",
+    "AutoModelForTokenClassification": "modeling_llama_te.NVLlamaForTokenClassification",
 }
 
 
@@ -191,11 +191,12 @@ def forward(
 
         # This might be slower for BSHD + padding with fused attention backend. But it should be faster for the flash
         # attention backend.
+        self_attn_mask_type = "padding_causal"
         if should_pack_inputs:
             # Left-side padding is not supported in TE layers, so to make generation work with TE we dynamically convert
             # to THD-style inputs in our forward pass, and then convert back to BSHD for the output. This lets the
             # entire transformer stack run in THD mode.
-            assert attention_mask is not None, "Attention mask is required when using BSHD inputs."
+            assert attention_mask is not None, "Attention mask is required when packing BSHD inputs."
             batch_size = hidden_states.size(0)
             hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask)
             cu_seq_lens_q = cu_seq_lens_k = cu_seqlens
@@ -213,8 +214,10 @@ def forward(
             max_length_k = kwargs["max_length_k"]
 
         else:
-            assert attention_mask is not None, "Attention mask is required when using BSHD inputs."
-            attention_mask = attention_mask[:, None, None, :] < -1
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, None, None, :] < -1
+            else:
+                self_attn_mask_type = "causal"
             cu_seq_lens_q = cu_seq_lens_k = None
             max_length_q = max_length_k = hidden_states.size(1)
 
@@ -243,6 +246,7 @@ def forward(
                 hidden_states,
                 attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
                 rotary_pos_emb=te_rope_emb,
+                self_attn_mask_type=self_attn_mask_type,
                 inference_params=past_key_values,
                 cu_seqlens_q=cu_seq_lens_q,
                 cu_seqlens_kv=cu_seq_lens_k,
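The `AUTO_MAP` rename above is what lets the new test work: the fixture copies `modeling_llama_te.py` next to the saved checkpoint and writes `AUTO_MAP` into `config.json`, so the Hugging Face auto classes resolve `"modeling_llama_te.NVLlamaForCausalLM"` from that file when `trust_remote_code=True` is set (which is how lm-eval's `hf` backend loads the model). A minimal sketch of that loading path, assuming a hypothetical local checkpoint directory:

# Sketch: loading a checkpoint whose config.json carries the AUTO_MAP entries above.
# "my_checkpoint/" is a hypothetical directory produced by model.save_pretrained(...)
# with modeling_llama_te.py copied alongside it, as the new test fixture does.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

checkpoint = "my_checkpoint"

# auto_map in config.json points "AutoModelForCausalLM" at
# "modeling_llama_te.NVLlamaForCausalLM", so trust_remote_code=True makes
# from_pretrained import that class from the copied module file.
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)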

bionemo-recipes/models/llama3/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+lm-eval # For testing
 torch
 torchao!=0.14.0
 transformer_engine[pytorch]
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import shutil
+import subprocess
+from pathlib import Path
+
+import pytest
+from transformers import AutoTokenizer
+
+from modeling_llama_te import AUTO_MAP, NVLlamaConfig, NVLlamaForCausalLM
+
+
+@pytest.fixture
+def model_checkpoint(tmp_path: Path):
+    tokenizer = AutoTokenizer.from_pretrained("nvidia/Llama-3.1-8B-Instruct-FP8")
+    config = NVLlamaConfig.from_pretrained(
+        "nvidia/Llama-3.1-8B-Instruct-FP8", num_hidden_layers=2, attn_input_format="bshd"
+    )
+    model = NVLlamaForCausalLM(config)
+    model.save_pretrained(tmp_path / "checkpoint")
+
+    tokenizer = AutoTokenizer.from_pretrained("nucleotide_fast_tokenizer")
+    tokenizer.save_pretrained(tmp_path / "checkpoint")
+
+    # Patch the config
+    with open(tmp_path / "checkpoint" / "config.json", "r") as f:
+        config = json.load(f)
+
+    config["auto_map"] = AUTO_MAP
+
+    with open(tmp_path / "checkpoint" / "config.json", "w") as f:
+        json.dump(config, f, indent=2, sort_keys=True)
+
+    shutil.copy("modeling_llama_te.py", tmp_path / "checkpoint" / "modeling_llama_te.py")
+    return tmp_path / "checkpoint"
+
+
+@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping slow lm-eval test in CI.")
+def test_lm_eval(model_checkpoint: Path):
+    # Create a mock model checkpoint
+
+    cmd = [
+        "lm_eval",
+        "--model",
+        "hf",
+        "--model_args",
+        f"pretrained={model_checkpoint},tokenizer={model_checkpoint}",
+        "--trust_remote_code",
+        "--tasks",
+        "arc_easy",  # TODO(BIONEMO-3410): support other tasks that use inference, e.g. coqa
+        "--device",
+        "cuda:0",
+        "--batch_size",
+        "8",
+    ]
+
+    result = subprocess.run(
+        cmd,
+        check=False,
+        text=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        timeout=240,
+    )
+
+    if result.returncode != 0:
+        print(f"STDOUT:\n{result.stdout}")
+        print(f"STDERR:\n{result.stderr}")
+        pytest.fail(f"Command failed with exit code {result.returncode}")
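For reference, the same evaluation the subprocess above launches can also be driven from Python. This is a rough sketch, not part of the commit: it assumes lm-eval's `lm_eval.simple_evaluate` entry point and its `hf` model backend, and a hypothetical `checkpoint_dir` prepared like the `model_checkpoint` fixture above.

# Sketch (assumption): programmatic equivalent of the lm_eval CLI call above.
# simple_evaluate and its argument names are assumed from lm-eval's Python API;
# checkpoint_dir is a hypothetical path prepared like the model_checkpoint fixture.
import lm_eval

checkpoint_dir = "/tmp/checkpoint"  # hypothetical

results = lm_eval.simple_evaluate(
    model="hf",
    model_args=f"pretrained={checkpoint_dir},tokenizer={checkpoint_dir},trust_remote_code=True",
    tasks=["arc_easy"],
    batch_size=8,
    device="cuda:0",
)

# Per-task metrics (e.g. accuracy on arc_easy) live under results["results"].
print(results["results"])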

bionemo-recipes/models/llama3/tests/test_modeling_llama_te.py

Lines changed: 19 additions & 0 deletions
@@ -66,6 +66,25 @@ def test_llama_model_forward_pass(input_text, attn_input_format):
     assert len(outputs.hidden_states) == config.num_hidden_layers + 1
 
 
+def test_llama_model_forward_pass_no_attention_mask():
+    tokenizer = AutoTokenizer.from_pretrained("nvidia/Llama-3.1-8B-Instruct-FP8")
+    config = NVLlamaConfig.from_pretrained(
+        "nvidia/Llama-3.1-8B-Instruct-FP8", num_hidden_layers=2, attn_input_format="bshd"
+    )
+    model = NVLlamaForCausalLM(config)
+
+    input_text = ["Hello, world!"]
+    inputs = tokenizer(input_text, return_tensors="pt")
+    inputs = {k: v.to("cuda") for k, v in inputs.items() if k != "attention_mask"}
+    model.to("cuda")
+    with torch.no_grad():
+        outputs = model(**inputs, output_hidden_states=True)
+
+    assert outputs.logits is not None
+    assert outputs.hidden_states is not None
+    assert len(outputs.hidden_states) == config.num_hidden_layers + 1
+
+
 @pytest.mark.parametrize("attn_input_format", ["thd", "bshd"])
 def test_llama_model_backward_pass(input_text, attn_input_format):
     if attn_input_format == "thd" and torch.cuda.get_device_capability()[0] == 12:

bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Lines changed: 13 additions & 9 deletions
@@ -30,12 +30,12 @@
 
 
 AUTO_MAP = {
-    "AutoConfig": "llama3_nv.NVLlamaConfig",
-    "AutoModel": "llama3_nv.NVLlamaModel",
-    "AutoModelForCausalLM": "llama3_nv.NVLlamaForCausalLM",
-    "AutoModelForSequenceClassification": "llama3_nv.NVLlamaForSequenceClassification",
-    "AutoModelForQuestionAnswering": "llama3_nv.NVLlamaForQuestionAnswering",
-    "AutoModelForTokenClassification": "llama3_nv.NVLlamaForTokenClassification",
+    "AutoConfig": "modeling_llama_te.NVLlamaConfig",
+    "AutoModel": "modeling_llama_te.NVLlamaModel",
+    "AutoModelForCausalLM": "modeling_llama_te.NVLlamaForCausalLM",
+    "AutoModelForSequenceClassification": "modeling_llama_te.NVLlamaForSequenceClassification",
+    "AutoModelForQuestionAnswering": "modeling_llama_te.NVLlamaForQuestionAnswering",
+    "AutoModelForTokenClassification": "modeling_llama_te.NVLlamaForTokenClassification",
 }
 
 
@@ -191,11 +191,12 @@ def forward(
 
         # This might be slower for BSHD + padding with fused attention backend. But it should be faster for the flash
        # attention backend.
+        self_attn_mask_type = "padding_causal"
         if should_pack_inputs:
             # Left-side padding is not supported in TE layers, so to make generation work with TE we dynamically convert
             # to THD-style inputs in our forward pass, and then convert back to BSHD for the output. This lets the
             # entire transformer stack run in THD mode.
-            assert attention_mask is not None, "Attention mask is required when using BSHD inputs."
+            assert attention_mask is not None, "Attention mask is required when packing BSHD inputs."
             batch_size = hidden_states.size(0)
             hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask)
             cu_seq_lens_q = cu_seq_lens_k = cu_seqlens
@@ -213,8 +214,10 @@ def forward(
             max_length_k = kwargs["max_length_k"]
 
         else:
-            assert attention_mask is not None, "Attention mask is required when using BSHD inputs."
-            attention_mask = attention_mask[:, None, None, :] < -1
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, None, None, :] < -1
+            else:
+                self_attn_mask_type = "causal"
             cu_seq_lens_q = cu_seq_lens_k = None
             max_length_q = max_length_k = hidden_states.size(1)
 
@@ -243,6 +246,7 @@ def forward(
                 hidden_states,
                 attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
                 rotary_pos_emb=te_rope_emb,
+                self_attn_mask_type=self_attn_mask_type,
                 inference_params=past_key_values,
                 cu_seqlens_q=cu_seq_lens_q,
                 cu_seqlens_kv=cu_seq_lens_k,
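Pulling the three hunks above together: the mask type defaults to "padding_causal" and only falls back to "causal" when no attention mask is supplied in the non-packed BSHD path, and that choice is forwarded to the decoder layers. A tiny standalone sketch of just that selection logic (the helper name is invented for illustration, not part of the model code):

# Sketch: the mask-type selection introduced above, isolated from the forward pass.
# _select_mask_type is a hypothetical helper name used only for illustration.
def _select_mask_type(attention_mask, should_pack_inputs: bool) -> str:
    # Default used for packed (THD) inputs and for BSHD inputs with padding.
    self_attn_mask_type = "padding_causal"
    if not should_pack_inputs and attention_mask is None:
        # No padding information at all: a plain causal mask is sufficient.
        self_attn_mask_type = "causal"
    return self_attn_mask_type


assert _select_mask_type(attention_mask=None, should_pack_inputs=False) == "causal"
assert _select_mask_type(attention_mask=object(), should_pack_inputs=False) == "padding_causal"
assert _select_mask_type(attention_mask=object(), should_pack_inputs=True) == "padding_causal"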
