
Commit 84cb6e5

Merge pull request #173 from lpscr/main
add new args in inference-cli.py to pass model and vocab
2 parents: 925ce4b + 60f1b31

File tree: 2 files changed (+54, −15 lines)


README.md

Lines changed: 3 additions & 0 deletions

````diff
@@ -86,6 +86,9 @@ Currently support 30s for a single generation, which is the **TOTAL** length of
 
 Either you can specify everything in `inference-cli.toml` or override with flags. Leave `--ref_text ""` will have ASR model transcribe the reference audio automatically (use extra GPU memory). If encounter network error, consider use local ckpt, just set `ckpt_path` in `inference-cli.py`
 
+for change model use --ckpt_file to specify the model you want to load,
+for change vocab.txt use --vocab_file to provide your vocab.txt file.
+
 ```bash
 python inference-cli.py \
 --model "F5-TTS" \
````

inference-cli.py

Lines changed: 51 additions & 15 deletions
```diff
@@ -36,6 +36,16 @@
     "--model",
     help="F5-TTS | E2-TTS",
 )
+parser.add_argument(
+    "-p",
+    "--ckpt_file",
+    help="The Checkpoint .pt",
+)
+parser.add_argument(
+    "-v",
+    "--vocab_file",
+    help="The vocab .txt",
+)
 parser.add_argument(
     "-r",
     "--ref_audio",
```
```diff
@@ -88,6 +98,8 @@
     gen_text = codecs.open(gen_file, "r", "utf-8").read()
 output_dir = args.output_dir if args.output_dir else config["output_dir"]
 model = args.model if args.model else config["model"]
+ckpt_file = args.ckpt_file if args.ckpt_file else ""
+vocab_file = args.vocab_file if args.vocab_file else ""
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
@@ -125,11 +137,19 @@
 # fix_duration = 27 # None or float (duration in seconds)
 fix_duration = None
 
-def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
-    if not Path(ckpt_path).exists():
-        ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
+def load_model(model_cls, model_cfg, ckpt_path, file_vocab):
+
+    if file_vocab == "":
+        file_vocab = "Emilia_ZH_EN"
+        tokenizer = "pinyin"
+    else:
+        tokenizer = "custom"
+
+    print("\nvocab : ", vocab_file, tokenizer)
+    print("tokenizer : ", tokenizer)
+    print("model : ", ckpt_path, "\n")
+
+    vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
     model = CFM(
         transformer=model_cls(
             **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
```
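The rewritten load_model no longer hardcodes the checkpoint lookup; it receives a ready ckpt_path and derives the tokenizer from whether a vocab file was supplied: the empty string keeps the bundled Emilia_ZH_EN pinyin tokenizer, any other value switches get_tokenizer to "custom". (Note that the first print reads the module-level vocab_file rather than the file_vocab parameter; it only works because the CLI defines that global.) A minimal sketch of the selection rule in isolation; resolve_tokenizer is our name, not the repo's:

```python
# Isolated restatement of the vocab/tokenizer branch in the new load_model;
# resolve_tokenizer is an illustrative helper, not part of the commit.
def resolve_tokenizer(file_vocab: str):
    if file_vocab == "":
        return "Emilia_ZH_EN", "pinyin"  # repo default dataset + tokenizer
    return file_vocab, "custom"          # user-supplied vocab.txt

assert resolve_tokenizer("") == ("Emilia_ZH_EN", "pinyin")
assert resolve_tokenizer("data/vocab.txt") == ("data/vocab.txt", "custom")
```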
```diff
@@ -149,14 +169,12 @@ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
 
     return model
 
-
 # load models
 F5TTS_model_cfg = dict(
     dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
-
 def chunk_text(text, max_chars=135):
     """
     Splits the input text into chunks, each with a maximum number of characters.
@@ -184,12 +202,29 @@ def chunk_text(text, max_chars=135):
 
     return chunks
 
+# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
+# if not Path(ckpt_path).exists():
+#     ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
 
-def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration=0.15):
     if model == "F5-TTS":
-        ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
+
+        if ckpt_file == "":
+            repo_name = "F5-TTS"
+            exp_name = "F5TTS_Base"
+            ckpt_step = 1200000
+            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+        ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file, file_vocab)
+
     elif model == "E2-TTS":
-        ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
+        if ckpt_file == "":
+            repo_name = "E2-TTS"
+            exp_name = "E2TTS_Base"
+            ckpt_step = 1200000
+            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+        ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file, file_vocab)
 
     audio, sr = ref_audio
     if audio.shape[0] > 1:
```
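Both model branches now share the same fallback pattern: an explicit --ckpt_file wins, otherwise the released 1200000-step .safetensors checkpoint is downloaded (and cached) from the Hugging Face Hub via cached_path. A condensed sketch of that pattern, assuming the cached-path package is installed; resolve_checkpoint is an illustrative name, not part of the commit:

```python
# Condensed version of the checkpoint fallback used in infer_batch above;
# resolve_checkpoint is illustrative and requires the `cached-path` package.
from cached_path import cached_path

def resolve_checkpoint(ckpt_file: str, repo_name: str, exp_name: str,
                       ckpt_step: int = 1200000) -> str:
    if ckpt_file:  # user passed --ckpt_file; trust the local path
        return ckpt_file
    # fall back to the released checkpoint on the Hugging Face Hub
    return str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))

# e.g. resolve_checkpoint("", "F5-TTS", "F5TTS_Base") -> cached local file path
```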
```diff
@@ -325,7 +360,7 @@ def process_voice(ref_audio_orig, ref_text):
         print("Using custom reference text...")
     return ref_audio, ref_text
 
-def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
+def infer(ref_audio, ref_text, gen_text, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration=0.15):
     print(gen_text)
     # Add the functionality to ensure it ends with ". "
     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
@@ -343,10 +378,10 @@ def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_durat
         print(f'gen_text {i}', gen_text)
 
     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration)
 
 
-def process(ref_audio, ref_text, text_gen, model, remove_silence):
+def process(ref_audio, ref_text, text_gen, model, ckpt_file, file_vocab, remove_silence):
     main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -371,7 +406,7 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
         ref_audio = voices[voice]['ref_audio']
         ref_text = voices[voice]['ref_text']
         print(f"Voice: {voice}")
-        audio, spectragram = infer(ref_audio, ref_text, gen_text, model, remove_silence)
+        audio, spectragram = infer(ref_audio, ref_text, gen_text, model, ckpt_file, file_vocab, remove_silence)
         generated_audio_segments.append(audio)
 
     if generated_audio_segments:
@@ -389,4 +424,5 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
         aseg.export(f.name, format="wav")
         print(f.name)
 
-process(ref_audio, ref_text, gen_text, model, remove_silence)
+
+process(ref_audio, ref_text, gen_text, model, ckpt_file, vocab_file, remove_silence)
```
