Commit 182b0f0

Merge pull request #149 from lpscr/main
Fix error caused by a missing parameter in finetune-cli.py
2 parents 852fb32 + 66062f9 commit 182b0f0

File tree

2 files changed (+84, -38 lines)

finetune-cli.py

Lines changed: 11 additions & 6 deletions
@@ -28,6 +28,7 @@ def parse_args():
     parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
     parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
     parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
+    parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
 
     return parser.parse_args()
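Note: argparse's `type=bool` is a known pitfall — `bool("False")` is `True`, so any non-empty value passed on the command line enables finetuning. A hedged alternative (not part of this commit) is an explicit string-to-bool converter:

    # Hypothetical replacement, not in this commit: parse the flag text
    # explicitly so "--finetune False" actually disables finetuning.
    parser.add_argument('--finetune',
                        type=lambda v: v.lower() in ('true', '1', 'yes'),
                        default=True, help='Use Finetune')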
@@ -42,17 +43,21 @@ def main():
         wandb_resume_id = None
         model_cls = DiT
         model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-        ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+        if args.finetune:
+            ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
     elif args.exp_name == "E2TTS_Base":
         wandb_resume_id = None
         model_cls = UNetT
         model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-        ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+        if args.finetune:
+            ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+
+    if args.finetune:
+        path_ckpt = os.path.join("ckpts",args.dataset_name)
+        if os.path.isdir(path_ckpt)==False:
+            os.makedirs(path_ckpt,exist_ok=True)
+        shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
 
-    path_ckpt = os.path.join("ckpts",args.dataset_name)
-    if os.path.isdir(path_ckpt)==False:
-        os.makedirs(path_ckpt,exist_ok=True)
-    shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
     checkpoint_path=os.path.join("ckpts",args.dataset_name)
 
     # Use the dataset_name provided in the command line
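The effect of the change: previously the pretrained checkpoint was downloaded and copied into ckpts/<dataset_name> unconditionally; now both the download and the copy happen only when --finetune is set, so `ckpt_path` is never needed when training from scratch. A minimal sketch of the gated flow (the function name is ours; the hf:// URL and ckpts/ layout are taken from the diff above):

    # Minimal sketch of the gated checkpoint-seeding flow, not a drop-in file.
    import os
    import shutil
    from cached_path import cached_path

    def seed_project_checkpoint(dataset_name: str, finetune: bool) -> None:
        if not finetune:
            return  # training from scratch: no pretrained weights to copy
        ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
        path_ckpt = os.path.join("ckpts", dataset_name)
        os.makedirs(path_ckpt, exist_ok=True)  # makes the isdir() pre-check unnecessary
        shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))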

finetune_gradio.py

Lines changed: 73 additions & 32 deletions
@@ -9,24 +9,19 @@
 import librosa
 import numpy as np
 from scipy.io import wavfile
-from tqdm import tqdm
 import shutil
 import time
 
 import json
-from datasets import Dataset
 from model.utils import convert_char_to_pinyin
 import signal
 import psutil
 import platform
 import subprocess
 from datasets.arrow_writer import ArrowWriter
-from datasets import load_dataset, load_from_disk
 
 import json
 
-
-
 training_process = None
 system = platform.system()
 python_executable = sys.executable or "python"
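Housekeeping: the unused tqdm and datasets imports are removed here. One leftover the cleanup misses: `import json` still appears twice in the surviving context lines. It is harmless, since Python caches imported modules, but the second occurrence could go too.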
@@ -265,16 +260,28 @@ def start_training(dataset_name="",
                    finetune=True,
                    ):
 
+
     global training_process
 
+    path_project = os.path.join(path_data, dataset_name + "_pinyin")
+
+    if os.path.isdir(path_project)==False:
+        yield f"There is not project with name {dataset_name}",gr.update(interactive=True),gr.update(interactive=False)
+        return
+
+    file_raw = os.path.join(path_project,"raw.arrow")
+    if os.path.isfile(file_raw)==False:
+        yield f"There is no file {file_raw}",gr.update(interactive=True),gr.update(interactive=False)
+        return
+
     # Check if a training process is already running
     if training_process is not None:
         return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
 
     yield "start train",gr.update(interactive=False),gr.update(interactive=False)
 
     # Command to run the training script with the specified arguments
-    cmd = f"{python_executable} finetune-cli.py --exp_name {exp_name} " \
+    cmd = f"accelerate launch finetune-cli.py --exp_name {exp_name} " \
          f"--learning_rate {learning_rate} " \
          f"--batch_size_per_gpu {batch_size_per_gpu} " \
          f"--batch_size_type {batch_size_type} " \
@@ -346,6 +353,8 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
     path_project_wavs = os.path.join(path_project,"wavs")
     file_metadata = os.path.join(path_project,"metadata.csv")
 
+    if audio_files is None:return "You need to load an audio file."
+
     if os.path.isdir(path_project_wavs):
         shutil.rmtree(path_project_wavs)
 
@@ -356,16 +365,17 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
 
     if user:
         file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
+        if file_audios==[]:return "No audio file was found in the dataset."
     else:
         file_audios = audio_files
-
-    print([file_audios])
+
 
     alpha = 0.5
     _max = 1.0
     slicer = Slicer(24000)
 
     num = 0
+    error_num = 0
     data=""
     for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):

@@ -381,18 +391,26 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
         if(tmp_max>1):chunk/=tmp_max
         chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
         wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
+
+        try:
+            text=transcribe(file_segment,language)
+            text = text.lower().strip().replace('"',"")
 
-        text=transcribe(file_segment,language)
-        text = text.lower().strip().replace('"',"")
+            data+= f"{name_segment}|{text}\n"
 
-        data+= f"{name_segment}|{text}\n"
+            num+=1
+        except:
+            error_num +=1
 
-        num+=1
-
     with open(file_metadata,"w",encoding="utf-8") as f:
         f.write(data)
-
-    return f"transcribe complete samples : {num} in path {path_project_wavs}"
+
+    if error_num!=[]:
+        error_text=f"\nerror files : {error_num}"
+    else:
+        error_text=""
+
+    return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
 
 def format_seconds_to_hms(seconds):
     hours = int(seconds / 3600)
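Two small caveats in the new error handling: `error_num` is an int, so `error_num != []` is always true in Python 3 and the "error files" line is appended even when nothing failed; and the bare `except:` also swallows KeyboardInterrupt. A hedged fix sketch, not part of this commit:

    # Catch Exception rather than everything, so Ctrl-C still interrupts
    # the loop, and compare the counter to 0 rather than to [].
    try:
        text = transcribe(file_segment, language)
        data += f"{name_segment}|{text}\n"
        num += 1
    except Exception:
        error_num += 1

    error_text = f"\nerror files : {error_num}" if error_num > 0 else ""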
@@ -408,6 +426,8 @@ def create_metadata(name_project,progress=gr.Progress()):
     file_raw = os.path.join(path_project,"raw.arrow")
     file_duration = os.path.join(path_project,"duration.json")
     file_vocab = os.path.join(path_project,"vocab.txt")
+
+    if os.path.isfile(file_metadata)==False: return "The file was not found in " + file_metadata
 
     with open(file_metadata,"r",encoding="utf-8") as f:
         data=f.read()
@@ -419,11 +439,18 @@ def create_metadata(name_project,progress=gr.Progress()):
     count=data.split("\n")
     lenght=0
     result=[]
+    error_files=[]
     for line in progress.tqdm(data.split("\n"),total=count):
         sp_line=line.split("|")
         if len(sp_line)!=2:continue
-        name_audio,text = sp_line[:2]
+        name_audio,text = sp_line[:2]
+
         file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
+
+        if os.path.isfile(file_audio)==False:
+            error_files.append(file_audio)
+            continue
+
         duraction = get_audio_duration(file_audio)
         if duraction<2 and duraction>15:continue
         if len(text)<4:continue
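One filter in this hunk can never fire: `if duraction<2 and duraction>15:continue` requires a clip to be both shorter than 2 s and longer than 15 s. Assuming the intent is to keep clips inside the 2-15 s window, the hedged fix is:

    # Skip clips outside the 2-15 second window (the committed line used 'and').
    if duraction < 2 or duraction > 15:
        continue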
@@ -439,6 +466,10 @@ def create_metadata(name_project,progress=gr.Progress()):
 
         lenght+=duraction
 
+    if duration_list==[]:
+        error_files_text="\n".join(error_files)
+        return f"Error: No audio files found in the specified path : \n{error_files_text}"
+
     min_second = round(min(duration_list),2)
     max_second = round(max(duration_list),2)

@@ -450,9 +481,15 @@ def create_metadata(name_project,progress=gr.Progress()):
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
     file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
+    if os.path.isfile(file_vocab_finetune==False):return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
     shutil.copy2(file_vocab_finetune, file_vocab)
-
-    return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n"
+
+    if error_files!=[]:
+        error_text="error files\n" + "\n".join(error_files)
+    else:
+        error_text=""
+
+    return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
 
 def check_user(value):
     return gr.update(visible=not value),gr.update(visible=value)
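Note the misplaced parenthesis in the new vocab guard: `os.path.isfile(file_vocab_finetune==False)` evaluates the comparison first and calls `isfile(False)`, so the guard never triggers and a missing vocab file still crashes at `shutil.copy2`. The presumably intended form:

    # Test the path, then negate the result.
    if not os.path.isfile(file_vocab_finetune):
        return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"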
@@ -466,15 +503,19 @@ def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_w
     data = json.load(file)
 
     duration_list = data['duration']
+
     samples = len(duration_list)
 
-    gpu_properties = torch.cuda.get_device_properties(0)
-    total_memory = gpu_properties.total_memory / (1024 ** 3)
+    if torch.cuda.is_available():
+        gpu_properties = torch.cuda.get_device_properties(0)
+        total_memory = gpu_properties.total_memory / (1024 ** 3)
+    elif torch.backends.mps.is_available():
+        total_memory = psutil.virtual_memory().available / (1024 ** 3)
 
     if batch_size_type=="frame":
         batch = int(total_memory * 0.5)
         batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
-        batch_size_per_gpu = int(36800 / batch )
+        batch_size_per_gpu = int(38400 / batch )
     else:
         batch_size_per_gpu = int(total_memory / 8)
         batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
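The new branch covers CUDA and Apple MPS, but on a machine with neither, `total_memory` is never assigned and the very next line raises NameError. A hedged sketch with a plain-CPU fallback (not in this commit):

    # Always bind total_memory, falling back to host RAM on CPU-only machines.
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    else:
        # covers both MPS machines and plain CPU hosts
        total_memory = psutil.virtual_memory().available / (1024 ** 3)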
@@ -509,13 +550,12 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -
         if ema_model_state_dict is not None:
             new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
             torch.save(new_checkpoint, new_checkpoint_path)
-            print(f"New checkpoint saved at: {new_checkpoint_path}")
+            return f"New checkpoint saved at: {new_checkpoint_path}"
         else:
-            print("No 'ema_model_state_dict' found in the checkpoint.")
+            return "No 'ema_model_state_dict' found in the checkpoint."
 
     except Exception as e:
-        print(f"An error occurred: {e}")
-
+        return f"An error occurred: {e}"
 
 def vocab_check(project_name):
     name_project = project_name + "_pinyin"
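Returning strings instead of printing lets the Gradio UI display the outcome (wired up in the "reduse checkpoint" tab below). For reference, a self-contained sketch of the reduction idea, assuming the training checkpoint stores an 'ema_model_state_dict' entry among other state (the function name here is ours):

    # Keep only the EMA weights, dropping optimizer state and other bookkeeping.
    import torch

    def extract_ema(src: str, dst: str) -> str:
        checkpoint = torch.load(src, map_location="cpu")
        ema = checkpoint.get('ema_model_state_dict')
        if ema is None:
            return "No 'ema_model_state_dict' found in the checkpoint."
        torch.save({'ema_model_state_dict': ema}, dst)
        return f"New checkpoint saved at: {dst}"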
@@ -524,12 +564,17 @@ def vocab_check(project_name):
     file_metadata = os.path.join(path_project, "metadata.csv")
 
     file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
+    if os.path.isfile(file_vocab)==False:
+        return f"the file {file_vocab} not found !"
 
     with open(file_vocab,"r",encoding="utf-8") as f:
         data=f.read()
 
     vocab = data.split("\n")
 
+    if os.path.isfile(file_metadata)==False:
+        return f"the file {file_metadata} not found !"
+
     with open(file_metadata,"r",encoding="utf-8") as f:
         data=f.read()
 
@@ -548,6 +593,7 @@ def vocab_check(project_name):
 
     if miss_symbols==[]:info ="You can train using your language !"
     else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
+
     return info
 
 
@@ -652,8 +698,9 @@ def vocab_check(project_name):
         with gr.TabItem("reduse checkpoint"):
             txt_path_checkpoint = gr.Text(label="path checkpoint :")
             txt_path_checkpoint_small = gr.Text(label="path output :")
+            txt_info_reduse = gr.Text(label="info",value="")
             reduse_button = gr.Button("reduse")
-            reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small])
+            reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small],outputs=[txt_info_reduse])
 
         with gr.TabItem("vocab check experiment"):
             check_button = gr.Button("check vocab")
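This is the standard Gradio pattern: a click handler's return value is rendered by the component listed in `outputs=`. A minimal, hypothetical illustration (component names and the lambda are ours, not from this commit):

    import gradio as gr

    with gr.Blocks() as demo:
        path_in = gr.Text(label="path checkpoint :")
        info = gr.Text(label="info")
        btn = gr.Button("reduse")
        # the string returned by fn is displayed in the 'info' textbox
        btn.click(fn=lambda p: f"would reduce: {p}", inputs=[path_in], outputs=[info])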
@@ -680,10 +727,4 @@ def main(port, host, share, api):
     )
 
 if __name__ == "__main__":
-    name="my_speak"
-
-    #create_data_project(name)
-    #transcribe_all(name)
-    #create_metadata(name)
-
     main()
