@@ -9,24 +9,19 @@
 import librosa
 import numpy as np
 from scipy.io import wavfile
-from tqdm import tqdm
 import shutil
 import time
 
 import json
-from datasets import Dataset
 from model.utils import convert_char_to_pinyin
 import signal
 import psutil
 import platform
 import subprocess
 from datasets.arrow_writer import ArrowWriter
-from datasets import load_dataset, load_from_disk
 
 import json
 
-
-
 training_process = None
 system = platform.system()
 python_executable = sys.executable or "python"
@@ -265,16 +260,28 @@ def start_training(dataset_name="",
     finetune=True,
 ):
 
+
     global training_process
 
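+    # Fail fast if the prepared project folder or its raw.arrow file is missing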
+    path_project = os.path.join(path_data, dataset_name + "_pinyin")
+
+    if not os.path.isdir(path_project):
+        yield f"There is no project with name {dataset_name}", gr.update(interactive=True), gr.update(interactive=False)
+        return
+
+    file_raw = os.path.join(path_project, "raw.arrow")
+    if not os.path.isfile(file_raw):
+        yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
+        return
+
     # Check if a training process is already running
     if training_process is not None:
         return "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
 
     yield "start train", gr.update(interactive=False), gr.update(interactive=False)
 
     # Command to run the training script with the specified arguments
-    cmd = f"{python_executable} finetune-cli.py --exp_name {exp_name} " \
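+    # Launch through Hugging Face Accelerate so the run picks up the user's `accelerate config` device setup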
+    cmd = f"accelerate launch finetune-cli.py --exp_name {exp_name} " \
          f"--learning_rate {learning_rate} " \
          f"--batch_size_per_gpu {batch_size_per_gpu} " \
          f"--batch_size_type {batch_size_type} " \
@@ -346,6 +353,8 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
     path_project_wavs = os.path.join(path_project, "wavs")
     file_metadata = os.path.join(path_project, "metadata.csv")
 
+    if audio_files is None: return "You need to load an audio file."
+
     if os.path.isdir(path_project_wavs):
         shutil.rmtree(path_project_wavs)
 
@@ -356,16 +365,17 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
 
     if user:
         file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
+        if file_audios == []: return "No audio file was found in the dataset."
     else:
         file_audios = audio_files
-
-    print([file_audios])
+
 
     alpha = 0.5
     _max = 1.0
     slicer = Slicer(24000)
 
     num = 0
+    error_num = 0
     data = ""
     for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))):
 
@@ -381,18 +391,26 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
             if (tmp_max > 1): chunk /= tmp_max
             chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
             wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
+
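+            # Transcribe each segment; count failures and keep going instead of aborting the whole run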
+            try:
+                text = transcribe(file_segment, language)
+                text = text.lower().strip().replace('"', "")
 
-            text = transcribe(file_segment, language)
-            text = text.lower().strip().replace('"', "")
+                data += f"{name_segment}|{text}\n"
 
-            data += f"{name_segment}|{text}\n"
+                num += 1
+            except Exception:
+                error_num += 1
 
-            num += 1
-
     with open(file_metadata, "w", encoding="utf-8") as f:
         f.write(data)
-
-    return f"transcribe complete samples : {num} in path {path_project_wavs}"
+
+    if error_num != 0:
+        error_text = f"\nerror files : {error_num}"
+    else:
+        error_text = ""
+
+    return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
 
 def format_seconds_to_hms(seconds):
     hours = int(seconds / 3600)
@@ -408,6 +426,8 @@ def create_metadata(name_project,progress=gr.Progress()):
     file_raw = os.path.join(path_project, "raw.arrow")
     file_duration = os.path.join(path_project, "duration.json")
     file_vocab = os.path.join(path_project, "vocab.txt")
+
+    if not os.path.isfile(file_metadata): return f"The file {file_metadata} was not found."
 
     with open(file_metadata, "r", encoding="utf-8") as f:
         data = f.read()
@@ -419,11 +439,18 @@ def create_metadata(name_project,progress=gr.Progress()):
     count = data.split("\n")
     lenght = 0
     result = []
+    error_files = []
     for line in progress.tqdm(data.split("\n"), total=count):
         sp_line = line.split("|")
         if len(sp_line) != 2: continue
-        name_audio, text = sp_line[:2]
+        name_audio, text = sp_line[:2]
+
         file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
+
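+        # Skip metadata rows whose wav segment is missing and collect them for the final report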
+        if not os.path.isfile(file_audio):
+            error_files.append(file_audio)
+            continue
+
         duraction = get_audio_duration(file_audio)
         if duraction < 2 and duraction > 15: continue
         if len(text) < 4: continue
@@ -439,6 +466,10 @@ def create_metadata(name_project,progress=gr.Progress()):
 
         lenght += duraction
 
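+    # If nothing usable was collected, report the missing files instead of continuing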
+    if duration_list == []:
+        error_files_text = "\n".join(error_files)
+        return f"Error: No audio files found in the specified path :\n{error_files_text}"
+
     min_second = round(min(duration_list), 2)
     max_second = round(max(duration_list), 2)
 
@@ -450,9 +481,15 @@ def create_metadata(name_project,progress=gr.Progress()):
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
     file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
+    if not os.path.isfile(file_vocab_finetune): return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
     shutil.copy2(file_vocab_finetune, file_vocab)
-
-    return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n"
+
+    if error_files != []:
+        error_text = "error files\n" + "\n".join(error_files)
+    else:
+        error_text = ""
+
+    return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
 
 def check_user(value):
     return gr.update(visible=not value), gr.update(visible=value)
@@ -466,15 +503,19 @@ def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_w
         data = json.load(file)
 
     duration_list = data['duration']
+
     samples = len(duration_list)
 
-    gpu_properties = torch.cuda.get_device_properties(0)
-    total_memory = gpu_properties.total_memory / (1024 ** 3)
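+    # Use GPU VRAM on CUDA; on Apple MPS fall back to the machine's available system RAM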
+    if torch.cuda.is_available():
+        gpu_properties = torch.cuda.get_device_properties(0)
+        total_memory = gpu_properties.total_memory / (1024 ** 3)
+    elif torch.backends.mps.is_available():
+        total_memory = psutil.virtual_memory().available / (1024 ** 3)
 
     if batch_size_type == "frame":
         batch = int(total_memory * 0.5)
         batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
-        batch_size_per_gpu = int(36800 / batch)
+        batch_size_per_gpu = int(38400 / batch)
     else:
         batch_size_per_gpu = int(total_memory / 8)
         batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
@@ -509,13 +550,12 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -
         if ema_model_state_dict is not None:
             new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
             torch.save(new_checkpoint, new_checkpoint_path)
-            print(f"New checkpoint saved at: {new_checkpoint_path}")
+            return f"New checkpoint saved at: {new_checkpoint_path}"
         else:
-            print("No 'ema_model_state_dict' found in the checkpoint.")
+            return "No 'ema_model_state_dict' found in the checkpoint."
 
     except Exception as e:
-        print(f"An error occurred: {e}")
-
+        return f"An error occurred: {e}"
 
 def vocab_check(project_name):
     name_project = project_name + "_pinyin"
@@ -524,12 +564,17 @@ def vocab_check(project_name):
     file_metadata = os.path.join(path_project, "metadata.csv")
 
     file_vocab = "data/Emilia_ZH_EN_pinyin/vocab.txt"
+    if not os.path.isfile(file_vocab):
+        return f"The file {file_vocab} was not found!"
 
     with open(file_vocab, "r", encoding="utf-8") as f:
         data = f.read()
 
     vocab = data.split("\n")
 
+    if not os.path.isfile(file_metadata):
+        return f"The file {file_metadata} was not found!"
+
     with open(file_metadata, "r", encoding="utf-8") as f:
         data = f.read()
 
@@ -548,6 +593,7 @@ def vocab_check(project_name):
 
     if miss_symbols == []: info = "You can train using your language !"
     else: info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
+
     return info
 
 
@@ -652,8 +698,9 @@ def vocab_check(project_name):
     with gr.TabItem("reduse checkpoint"):
         txt_path_checkpoint = gr.Text(label="path checkpoint :")
         txt_path_checkpoint_small = gr.Text(label="path output :")
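+        # New textbox to surface the status message now returned by extract_and_save_ema_model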
+        txt_info_reduse = gr.Text(label="info", value="")
         reduse_button = gr.Button("reduse")
-        reduse_button.click(fn=extract_and_save_ema_model, inputs=[txt_path_checkpoint, txt_path_checkpoint_small])
+        reduse_button.click(fn=extract_and_save_ema_model, inputs=[txt_path_checkpoint, txt_path_checkpoint_small], outputs=[txt_info_reduse])
 
     with gr.TabItem("vocab check experiment"):
         check_button = gr.Button("check vocab")
@@ -680,10 +727,4 @@ def main(port, host, share, api):
     )
 
 if __name__ == "__main__":
-    name = "my_speak"
-
-    # create_data_project(name)
-    # transcribe_all(name)
-    # create_metadata(name)
-
     main()