Skip to content

Commit bb8e062

Browse files
authored
refactor: コア読み込みの条件分岐を共通化 (#1651)
1 parent c6e519f commit bb8e062

File tree

1 file changed

+34
-25
lines changed

1 file changed

+34
-25
lines changed

voicevox_engine/core/core_wrapper.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -384,45 +384,54 @@ def _load_core_version_0_12_or_later(core_path: Path) -> CDLL:
384384
try:
385385
return _load_core_dll(core_path)
386386
except OSError as e:
387-
raise RuntimeError(f"コアの読み込みに失敗しました:{e}") from e
387+
msg = f"利用可能なコアがありましたが、読み込みに失敗しました:{e}"
388+
raise RuntimeError(msg) from e
388389

389390

390391
def _load_core_version_earlier_than_0_12(core_dir: Path, use_gpu: bool) -> CDLL:
391392
"""v0.12以前のコア共有ライブラリを読み込む。"""
392393
model_type = _check_core_type(core_dir)
393394
if model_type is None:
394395
raise RuntimeError("コアが見つかりません")
396+
397+
# 最新の読み込みエラーを記録する
398+
latest_e: OSError | None = None
399+
400+
# GPU 版を読み込む
395401
if use_gpu or model_type == "onnxruntime":
396-
core_name = _get_suitable_core_name(model_type, gpu_type=GPUType.CUDA)
397-
if core_name:
398-
try:
399-
return _load_core_dll(core_dir / core_name)
400-
except OSError:
401-
pass
402-
core_name = _get_suitable_core_name(model_type, gpu_type=GPUType.DIRECT_ML)
403-
if core_name:
404-
try:
405-
return _load_core_dll(core_dir / core_name)
406-
except OSError:
407-
pass
402+
for gpu_type in [GPUType.CUDA, GPUType.DIRECT_ML]:
403+
core_name = _get_suitable_core_name(model_type, gpu_type)
404+
if core_name:
405+
try:
406+
return _load_core_dll(core_dir / core_name)
407+
except OSError as e:
408+
latest_e = e
409+
410+
# CPU 版を読み込む
411+
# NOTE: GPU 版の読み込みが全て失敗した場合は CPU 版へフォールバックする
408412
core_name = _get_suitable_core_name(model_type, gpu_type=GPUType.NONE)
409413
if core_name:
410414
try:
411415
return _load_core_dll(core_dir / core_name)
412416
except OSError as e:
413-
_e = e
414-
if model_type == "libtorch":
415-
core_name = _get_suitable_core_name(model_type, gpu_type=GPUType.CUDA)
416-
if core_name:
417-
try:
418-
return _load_core_dll(core_dir / core_name)
419-
except OSError as e:
420-
_e = e
421-
raise RuntimeError(f"コアの読み込みに失敗しました:{_e}") from _e
417+
latest_e = e
418+
419+
# libtorch CUDA 版を CPU モードで読み込む
420+
# NOTE: libtorch は CUDA 版 のみが存在するため、CUDA 版の CPU モードへフォールバックする
421+
if model_type == "libtorch":
422+
core_name = _get_suitable_core_name(model_type, gpu_type=GPUType.CUDA)
423+
if core_name:
424+
try:
425+
return _load_core_dll(core_dir / core_name)
426+
except OSError as e:
427+
latest_e = e
428+
429+
if latest_e is not None:
430+
msg = f"利用可能なコアがありましたが、読み込みに失敗しました:{latest_e}"
431+
raise RuntimeError(msg)
422432
else:
423-
raise RuntimeError(
424-
f"このコンピュータのアーキテクチャ {platform.machine()} で利用可能なコアがありません"
425-
)
433+
msg = f"このコンピュータのアーキテクチャ {platform.machine()} で利用可能なコアがありません。"
434+
raise RuntimeError(msg)
426435

427436

428437
_C_TYPE = (

0 commit comments

Comments
 (0)