|
22 | 22 | import tempfile
|
23 | 23 | import unittest
|
24 | 24 | import warnings
|
| 25 | +from pathlib import Path |
25 | 26 |
|
26 | 27 | import numpy as np
|
27 | 28 | import pytest
|
@@ -4995,6 +4996,27 @@ def test_custom_generate_requires_trust_remote_code(self):
|
4995 | 4996 | with self.assertRaises(ValueError):
|
4996 | 4997 | model.generate(**model_inputs, custom_generate="transformers-community/custom_generate_example")
|
4997 | 4998 |
|
| 4999 | + def test_custom_generate_local_directory(self): |
| 5000 | + """Tests that custom_generate works with local directories containing importable relative modules""" |
| 5001 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 5002 | + custom_generate_dir = Path(tmp_dir) / "custom_generate" |
| 5003 | + custom_generate_dir.mkdir() |
| 5004 | + with open(custom_generate_dir / "generate.py", "w") as f: |
| 5005 | + f.write("from .helper import ret_success\ndef generate(*args, **kwargs):\n return ret_success()\n") |
| 5006 | + with open(custom_generate_dir / "helper.py", "w") as f: |
| 5007 | + f.write('def ret_success():\n return "success"\n') |
| 5008 | + model = AutoModelForCausalLM.from_pretrained( |
| 5009 | + "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto" |
| 5010 | + ) |
| 5011 | + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") |
| 5012 | + model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device) |
| 5013 | + value = model.generate( |
| 5014 | + **model_inputs, |
| 5015 | + custom_generate=str(tmp_dir), |
| 5016 | + trust_remote_code=True, |
| 5017 | + ) |
| 5018 | + assert value == "success" |
| 5019 | + |
4998 | 5020 |
|
4999 | 5021 | @require_torch
|
5000 | 5022 | class TokenHealingTestCase(unittest.TestCase):
|
|
0 commit comments