@@ -39,12 +39,12 @@ def mocked_import(name, *args, **kwargs):
     @pytest.fixture
     def expected_vision_acc(self):
         return {
-            "Science": 0,
+            "Science": 0.2,
             "Biology": 0,
-            "Chemistry": 0,
+            "Chemistry": 0.3333,
             "Geography": 0,
             "Math": 0,
-            "Physics": 0,
+            "Physics": 0.6667,
         }

     @pytest.mark.parametrize(
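The fixture values above are rounded to four decimal places, while the hunk around line 251 below compares them with math.isclose. A minimal sketch of how such fixture values are typically checked against per-task accuracies parsed from the eval log; the helper name, the parsed dict, and the abs_tol value are illustrative and not part of this diff:

```python
import math

def check_vision_accuracies(parsed_acc: dict, expected_acc: dict, abs_tol: float = 5e-4) -> None:
    """Compare per-task accuracies parsed from the eval log against the fixture.

    An explicit absolute tolerance is used here because values such as 0.3333
    and 0.6667 are rounded to four decimal places, which a tight relative
    tolerance would reject for unrounded results.
    """
    for task_name, expected in expected_acc.items():
        assert math.isclose(parsed_acc[task_name], expected, abs_tol=abs_tol), (
            f"{task_name}: got {parsed_acc[task_name]}, expected {expected}"
        )

# Example with the fixture values from this diff:
check_vision_accuracies(
    {"Science": 0.2, "Biology": 0.0, "Chemistry": 1 / 3, "Geography": 0.0, "Math": 0.0, "Physics": 2 / 3},
    {"Science": 0.2, "Biology": 0, "Chemistry": 0.3333, "Geography": 0, "Math": 0, "Physics": 0.6667},
)
```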
@@ -212,6 +212,7 @@ def test_eval_recipe_errors_with_qat_quantizer(self, monkeypatch, tmpdir):
             runpy.run_path(TUNE_PATH, run_name="__main__")

     @pytest.mark.integration_test
+    @gpu_test(gpu_count=1)
     def test_meta_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc):
         ckpt = "llama3_2_vision_meta"
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
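The newly added @gpu_test(gpu_count=1) marker pairs with the device=cuda override below: the test should only run when a GPU is actually available. The decorator comes from the repository's test utilities; a minimal sketch of such a helper, assuming it wraps pytest.mark.skipif on the visible CUDA device count:

```python
import pytest
import torch

def gpu_test(gpu_count: int = 1):
    """Skip the decorated test unless at least `gpu_count` CUDA devices are visible."""
    return pytest.mark.skipif(
        torch.cuda.device_count() < gpu_count,
        reason=f"Not enough GPUs to run test (requires {gpu_count})",
    )
```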
@@ -230,9 +231,9 @@ def test_meta_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc
             checkpointer.model_type=LLAMA3_VISION \
             tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
             tokenizer.prompt_template=null \
-            limit=1 \
+            limit=3 \
             dtype=bf16 \
-            device=cpu \
+            device=cuda \
         """.split()

         model_config = llama3_2_vision_test_config()
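For context on the overrides above: the test assembles the CLI as one triple-quoted string, so the trailing backslashes are ordinary line continuations inside the string literal and .split() turns the remaining whitespace-separated tokens into an argv-style list. A small illustration with placeholder tokens (the actual command prefix is abbreviated in this diff):

```python
# Illustrative only: shows how a triple-quoted override string collapses into
# a flat token list. The tokens below are placeholders, not the full command
# used by the test.
cmd = """
some_recipe \
    tokenizer.prompt_template=null \
    limit=3 \
    dtype=bf16 \
    device=cuda \
""".split()

print(cmd)
# ['some_recipe', 'tokenizer.prompt_template=null', 'limit=3', 'dtype=bf16', 'device=cuda']
```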
@@ -251,6 +252,7 @@ def test_meta_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc
             assert math.isclose(float(accuracy), expected_vision_acc[task_name])

     @pytest.mark.integration_test
+    @gpu_test(gpu_count=1)
     def test_hf_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc):
         ckpt = "llama3_2_vision_hf"
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
@@ -272,9 +274,9 @@ def test_hf_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc):
             checkpointer.model_type=LLAMA3_VISION \
             tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
             tokenizer.prompt_template=null \
-            limit=1 \
+            limit=3 \
             dtype=bf16 \
-            device=cpu \
+            device=cuda \
         """.split()

         model_config = llama3_2_vision_test_config()