Skip to content

Commit 8dc6cea

Browse files
authored
Fix #2839: initialize punc_res before conditional to prevent UnboundLocalError (#2840)
* Fix #2839: initialize punc_res before conditional to prevent UnboundLocalError When punc_model is None or not provided, punc_res was only assigned inside the `if self.punc_model is not None:` block. Downstream code paths (punc_segment speaker diarization and sentence_timestamp) accessed punc_res unconditionally, causing UnboundLocalError. Changes: - Initialize punc_res = None before the conditional block - Add punc_res is None guard in punc_segment path with error log - Add punc_res is None guard in sentence_timestamp path with warning log - Add unit tests covering punc_model=None, empty punc_model, and normal flow Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com> * fix: update docstring to only mention punc_model=None case The tests only cover the None case, not the empty string case. Update the module docstring to match actual test coverage. Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com> --------- Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
1 parent 8f3314c commit 8dc6cea

2 files changed

Lines changed: 146 additions & 5 deletions

File tree

funasr/auto/auto_model.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,18 @@
2929
from funasr.train_utils.load_pretrained_model import load_pretrained_model
3030
from funasr.utils import export_utils
3131
from funasr.utils import misc
32+
33+
3234
def is_npu_available():
3335
"""检查NPU是否可用。"""
3436
try:
3537
import torch_npu
38+
3639
return torch_npu.npu.is_available()
3740
except ImportError:
3841
return False
3942

43+
4044
def _resolve_ncpu(config, fallback=4):
4145
"""Return a positive integer representing CPU threads from config."""
4246
value = config.get("ncpu", fallback)
@@ -46,6 +50,7 @@ def _resolve_ncpu(config, fallback=4):
4650
value = fallback
4751
return max(value, 1)
4852

53+
4954
try:
5055
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
5156
from funasr.models.campplus.cluster_backend import ClusterBackend
@@ -202,11 +207,13 @@ def build_model(**kwargs):
202207
set_all_random_seed(kwargs.get("seed", 0))
203208

204209
device = kwargs.get("device", "cuda")
205-
if ((device =="cuda" and not torch.cuda.is_available())
210+
if (
211+
(device == "cuda" and not torch.cuda.is_available())
206212
or (device == "xpu" and not torch.xpu.is_available())
207213
or (device == "mps" and not torch.backends.mps.is_available())
208214
or (device == "npu" and not is_npu_available())
209-
or kwargs.get("ngpu", 1) == 0):
215+
or kwargs.get("ngpu", 1) == 0
216+
):
210217
device = "cpu"
211218
kwargs["batch_size"] = 1
212219
kwargs["device"] = device
@@ -573,8 +580,12 @@ def inference_with_vad(self, input, input_len=None, **cfg):
573580
result[k] = []
574581
for t in restored_data[j][k]:
575582
if isinstance(t, dict):
576-
t["start_time"] = (float(t["start_time"]) * 1000 + int(vadsegments[j][0])) / 1000
577-
t["end_time"] = (float(t["end_time"]) * 1000 + int(vadsegments[j][0])) / 1000
583+
t["start_time"] = (
584+
float(t["start_time"]) * 1000 + int(vadsegments[j][0])
585+
) / 1000
586+
t["end_time"] = (
587+
float(t["end_time"]) * 1000 + int(vadsegments[j][0])
588+
) / 1000
578589
else:
579590
t[0] = int(t[0]) + int(vadsegments[j][0])
580591
t[1] = int(t[1]) + int(vadsegments[j][0])
@@ -600,6 +611,7 @@ def inference_with_vad(self, input, input_len=None, **cfg):
600611
return_raw_text = kwargs.get("return_raw_text", False)
601612
# step.3 compute punc model
602613
raw_text = None
614+
punc_res = None
603615
if self.punc_model is not None:
604616
deep_update(self.punc_kwargs, cfg)
605617
punc_res = self.inference(
@@ -645,7 +657,12 @@ def inference_with_vad(self, input, input_len=None, **cfg):
645657
and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
646658
can predict timestamp, and speaker diarization relies on timestamps."
647659
)
648-
if kwargs.get("en_post_proc", False):
660+
if punc_res is None:
661+
logging.error(
662+
"Missing punc_model, which is required for punc_segment speaker diarization."
663+
)
664+
sentence_list = []
665+
elif kwargs.get("en_post_proc", False):
649666
sentence_list = timestamp_sentence_en(
650667
punc_res[0]["punc_array"],
651668
result["timestamp"],
@@ -664,6 +681,11 @@ def inference_with_vad(self, input, input_len=None, **cfg):
664681
elif kwargs.get("sentence_timestamp", False):
665682
if not len(result["text"].strip()):
666683
sentence_list = []
684+
elif punc_res is None:
685+
logging.warning(
686+
"punc_model is required for sentence_timestamp, skipping sentence segmentation."
687+
)
688+
sentence_list = []
667689
else:
668690
if kwargs.get("en_post_proc", False):
669691
sentence_list = timestamp_sentence_en(

tests/test_punc_model_none.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""Tests for issue #2839: punc_model=None should not cause UnboundLocalError."""
2+
3+
import unittest
4+
from unittest.mock import MagicMock, patch
5+
import numpy as np
6+
7+
8+
class TestPuncModelNone(unittest.TestCase):
9+
"""Test that inference_with_vad works when punc_model is None."""
10+
11+
def _make_auto_model(self, punc_model=None, spk_model=None, spk_mode=None):
12+
"""Create a minimal AutoModel instance with mocked dependencies."""
13+
from funasr.auto.auto_model import AutoModel
14+
15+
am = AutoModel.__new__(AutoModel)
16+
am.model = MagicMock()
17+
am.vad_model = MagicMock()
18+
am.punc_model = punc_model
19+
am.punc_kwargs = {}
20+
am.spk_model = spk_model
21+
am.cb_model = None
22+
am.spk_mode = spk_mode
23+
am.vad_kwargs = {}
24+
am.kwargs = {
25+
"batch_size_s": 300,
26+
"batch_size_threshold_s": 60,
27+
"device": "cpu",
28+
"disable_pbar": True,
29+
"frontend": MagicMock(fs=16000),
30+
"fs": 16000,
31+
}
32+
am._reset_runtime_configs = MagicMock()
33+
return am
34+
35+
def _setup_mocks(self, am, mock_slice, mock_load, mock_prep):
36+
"""Configure standard mocks for a single-segment VAD + ASR flow."""
37+
# VAD returns one segment [0, 16000ms]
38+
vad_result = [{"key": "test_utt", "value": [[0, 16000]]}]
39+
# ASR returns text with timestamps
40+
asr_result = [{"text": "hello world", "timestamp": [[0, 500], [500, 1000]]}]
41+
42+
call_count = [0]
43+
results_seq = [vad_result, asr_result]
44+
45+
def mock_inference(data, input_len=None, model=None, kwargs=None, **cfg):
46+
idx = call_count[0]
47+
call_count[0] += 1
48+
if idx < len(results_seq):
49+
return results_seq[idx]
50+
return [{"text": ""}]
51+
52+
am.inference = MagicMock(side_effect=mock_inference)
53+
mock_prep.return_value = (["test_utt"], [np.zeros(16000, dtype=np.float32)])
54+
mock_load.return_value = np.zeros(16000, dtype=np.float32)
55+
mock_slice.return_value = ([np.zeros(16000, dtype=np.float32)], [16000])
56+
57+
@patch("funasr.auto.auto_model.slice_padding_audio_samples")
58+
@patch("funasr.auto.auto_model.load_audio_text_image_video")
59+
@patch("funasr.auto.auto_model.prepare_data_iterator")
60+
def test_punc_model_none_basic(self, mock_prep, mock_load, mock_slice):
61+
"""Basic inference with punc_model=None should not raise UnboundLocalError."""
62+
am = self._make_auto_model(punc_model=None)
63+
self._setup_mocks(am, mock_slice, mock_load, mock_prep)
64+
65+
results = am.inference_with_vad("dummy_input")
66+
67+
self.assertEqual(len(results), 1)
68+
self.assertEqual(results[0]["text"], "hello world")
69+
self.assertEqual(results[0]["key"], "test_utt")
70+
71+
@patch("funasr.auto.auto_model.slice_padding_audio_samples")
72+
@patch("funasr.auto.auto_model.load_audio_text_image_video")
73+
@patch("funasr.auto.auto_model.prepare_data_iterator")
74+
def test_sentence_timestamp_with_punc_model_none(self, mock_prep, mock_load, mock_slice):
75+
"""sentence_timestamp=True with punc_model=None should not crash."""
76+
am = self._make_auto_model(punc_model=None)
77+
self._setup_mocks(am, mock_slice, mock_load, mock_prep)
78+
79+
# This path previously caused UnboundLocalError on punc_res
80+
results = am.inference_with_vad("dummy_input", sentence_timestamp=True)
81+
82+
self.assertEqual(len(results), 1)
83+
# sentence_info should be empty list since punc_res is unavailable
84+
self.assertEqual(results[0].get("sentence_info"), [])
85+
86+
@patch("funasr.auto.auto_model.slice_padding_audio_samples")
87+
@patch("funasr.auto.auto_model.load_audio_text_image_video")
88+
@patch("funasr.auto.auto_model.prepare_data_iterator")
89+
def test_punc_model_with_value_still_works(self, mock_prep, mock_load, mock_slice):
90+
"""When punc_model is provided, punc_res should still be used normally."""
91+
punc_mock = MagicMock()
92+
am = self._make_auto_model(punc_model=punc_mock)
93+
94+
vad_result = [{"key": "test_utt", "value": [[0, 16000]]}]
95+
asr_result = [{"text": "hello world", "timestamp": [[0, 500], [500, 1000]]}]
96+
punc_result = [{"text": "Hello, world.", "punc_array": [1, 2]}]
97+
98+
call_count = [0]
99+
results_seq = [vad_result, asr_result, punc_result]
100+
101+
def mock_inference(data, input_len=None, model=None, kwargs=None, **cfg):
102+
idx = call_count[0]
103+
call_count[0] += 1
104+
return results_seq[idx]
105+
106+
am.inference = MagicMock(side_effect=mock_inference)
107+
mock_prep.return_value = (["test_utt"], [np.zeros(16000, dtype=np.float32)])
108+
mock_load.return_value = np.zeros(16000, dtype=np.float32)
109+
mock_slice.return_value = ([np.zeros(16000, dtype=np.float32)], [16000])
110+
111+
results = am.inference_with_vad("dummy_input")
112+
113+
self.assertEqual(len(results), 1)
114+
# Text should be updated with punctuated version
115+
self.assertEqual(results[0]["text"], "Hello, world.")
116+
117+
118+
if __name__ == "__main__":
119+
unittest.main()

0 commit comments

Comments
 (0)