36
36
"--model" ,
37
37
help = "F5-TTS | E2-TTS" ,
38
38
)
39
# Optional overrides: a local checkpoint and a custom vocabulary file.
# When omitted, the defaults bundled with the selected model are used.
parser.add_argument(
    "-p",
    "--ckpt_file",
    help="The Checkpoint .pt",
)
parser.add_argument(
    "-v",
    "--vocab_file",
    help="The vocab .txt",
)
39
49
parser .add_argument (
40
50
"-r" ,
41
51
"--ref_audio" ,
88
98
gen_text = codecs .open (gen_file , "r" , "utf-8" ).read ()
89
99
# Resolve runtime options: a truthy CLI flag wins over the config-file value.
# NOTE(review): `or`-style fallback means an explicitly falsy CLI value
# (e.g. remove_silence=False) can never override the config — confirm intended.
output_dir = args.output_dir or config["output_dir"]
model = args.model or config["model"]
# Empty string means "no override; use the bundled defaults downstream".
ckpt_file = args.ckpt_file or ""
vocab_file = args.vocab_file or ""
remove_silence = args.remove_silence or config["remove_silence"]

# Fixed output names inside the chosen output directory.
wave_path = Path(output_dir) / "out.wav"
spectrogram_path = Path(output_dir) / "out.png"
125
137
# fix_duration = 27 # None or float (duration in seconds)
126
138
fix_duration = None
127
139
128
- def load_model (repo_name , exp_name , model_cls , model_cfg , ckpt_step ):
129
- ckpt_path = f"ckpts/{ exp_name } /model_{ ckpt_step } .pt" # .pt | .safetensors
130
- if not Path (ckpt_path ).exists ():
131
- ckpt_path = str (cached_path (f"hf://SWivid/{ repo_name } /{ exp_name } /model_{ ckpt_step } .safetensors" ))
132
- vocab_char_map , vocab_size = get_tokenizer ("Emilia_ZH_EN" , "pinyin" )
140
+ def load_model (model_cls , model_cfg , ckpt_path ,file_vocab ):
141
+
142
+ if file_vocab == "" :
143
+ file_vocab = "Emilia_ZH_EN"
144
+ tokenizer = "pinyin"
145
+ else :
146
+ tokenizer = "custom"
147
+
148
+ print ("\n vocab : " ,vocab_file ,tokenizer )
149
+ print ("tokenizer : " ,tokenizer )
150
+ print ("model : " ,ckpt_path ,"\n " )
151
+
152
+ vocab_char_map , vocab_size = get_tokenizer (file_vocab , tokenizer )
133
153
model = CFM (
134
154
transformer = model_cls (
135
155
** model_cfg , text_num_embeds = vocab_size , mel_dim = n_mel_channels
@@ -149,14 +169,12 @@ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
149
169
150
170
return model
151
171
152
-
153
172
# load models
# Transformer hyper-parameters for the two supported architectures
# (passed as **kwargs into the model class inside load_model).
F5TTS_model_cfg = {
    "dim": 1024,
    "depth": 22,
    "heads": 16,
    "ff_mult": 2,
    "text_dim": 512,
    "conv_layers": 4,
}
E2TTS_model_cfg = {"dim": 1024, "depth": 24, "heads": 16, "ff_mult": 4}
158
177
159
-
160
178
def chunk_text (text , max_chars = 135 ):
161
179
"""
162
180
Splits the input text into chunks, each with a maximum number of characters.
@@ -184,12 +202,29 @@ def chunk_text(text, max_chars=135):
184
202
185
203
return chunks
186
204
205
+ #ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
206
+ #if not Path(ckpt_path).exists():
207
+ #ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
187
208
188
- def infer_batch (ref_audio , ref_text , gen_text_batches , model , remove_silence , cross_fade_duration = 0.15 ):
209
+ def infer_batch (ref_audio , ref_text , gen_text_batches , model ,ckpt_file , file_vocab , remove_silence , cross_fade_duration = 0.15 ):
189
210
if model == "F5-TTS" :
190
- ema_model = load_model (model , "F5TTS_Base" , DiT , F5TTS_model_cfg , 1200000 )
211
+
212
+ if ckpt_file == "" :
213
+ repo_name = "F5-TTS"
214
+ exp_name = "F5TTS_Base"
215
+ ckpt_step = 1200000
216
+ ckpt_file = str (cached_path (f"hf://SWivid/{ repo_name } /{ exp_name } /model_{ ckpt_step } .safetensors" ))
217
+
218
+ ema_model = load_model (DiT , F5TTS_model_cfg , ckpt_file ,file_vocab )
219
+
191
220
elif model == "E2-TTS" :
192
- ema_model = load_model (model , "E2TTS_Base" , UNetT , E2TTS_model_cfg , 1200000 )
221
+ if ckpt_file == "" :
222
+ repo_name = "E2-TTS"
223
+ exp_name = "E2TTS_Base"
224
+ ckpt_step = 1200000
225
+ ckpt_file = str (cached_path (f"hf://SWivid/{ repo_name } /{ exp_name } /model_{ ckpt_step } .safetensors" ))
226
+
227
+ ema_model = load_model (UNetT , E2TTS_model_cfg , ckpt_file ,file_vocab )
193
228
194
229
audio , sr = ref_audio
195
230
if audio .shape [0 ] > 1 :
@@ -325,7 +360,7 @@ def process_voice(ref_audio_orig, ref_text):
325
360
print ("Using custom reference text..." )
326
361
return ref_audio , ref_text
327
362
328
- def infer (ref_audio , ref_text , gen_text , model , remove_silence , cross_fade_duration = 0.15 ):
363
+ def infer (ref_audio , ref_text , gen_text , model ,ckpt_file , file_vocab , remove_silence , cross_fade_duration = 0.15 ):
329
364
print (gen_text )
330
365
# Add the functionality to ensure it ends with ". "
331
366
if not ref_text .endswith (". " ) and not ref_text .endswith ("。" ):
@@ -343,10 +378,10 @@ def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_durat
343
378
print (f'gen_text { i } ' , gen_text )
344
379
345
380
print (f"Generating audio using { model } in { len (gen_text_batches )} batches, loading models..." )
346
- return infer_batch ((audio , sr ), ref_text , gen_text_batches , model , remove_silence , cross_fade_duration )
381
+ return infer_batch ((audio , sr ), ref_text , gen_text_batches , model ,ckpt_file , file_vocab , remove_silence , cross_fade_duration )
347
382
348
383
349
- def process (ref_audio , ref_text , text_gen , model , remove_silence ):
384
+ def process (ref_audio , ref_text , text_gen , model ,ckpt_file , file_vocab , remove_silence ):
350
385
main_voice = {"ref_audio" :ref_audio , "ref_text" :ref_text }
351
386
if "voices" not in config :
352
387
voices = {"main" : main_voice }
@@ -371,7 +406,7 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
371
406
ref_audio = voices [voice ]['ref_audio' ]
372
407
ref_text = voices [voice ]['ref_text' ]
373
408
print (f"Voice: { voice } " )
374
- audio , spectragram = infer (ref_audio , ref_text , gen_text , model , remove_silence )
409
+ audio , spectragram = infer (ref_audio , ref_text , gen_text , model ,ckpt_file , file_vocab , remove_silence )
375
410
generated_audio_segments .append (audio )
376
411
377
412
if generated_audio_segments :
@@ -389,4 +424,5 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
389
424
aseg .export (f .name , format = "wav" )
390
425
print (f .name )
391
426
392
- process (ref_audio , ref_text , gen_text , model , remove_silence )
427
+
428
+ process (ref_audio , ref_text , gen_text , model ,ckpt_file ,vocab_file , remove_silence )
0 commit comments