From 65c6201c4666af6f800016ba6030ede05921dd74 Mon Sep 17 00:00:00 2001 From: vivek Date: Tue, 3 Jun 2025 16:49:08 +0530 Subject: [PATCH 1/2] feat: add CLI support for configuring Whisper transcription parameters --- diarize.py | 57 +++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/diarize.py b/diarize.py index dbaa021..574938b 100644 --- a/diarize.py +++ b/diarize.py @@ -89,16 +89,28 @@ help="if you have a GPU use 'cuda', otherwise 'cpu'", ) +parser.add_argument("--temperature", type=float, default=0.1, help="Temperature for sampling") +parser.add_argument("--initial-prompt", type=str, default="", help="Initial prompt for context") +parser.add_argument("--hotwords", type=str, default="", help="Hotwords as a single string") +parser.add_argument("--repetition-penalty", type=float, default=1.1, help="Penalty for repeated tokens") +parser.add_argument("--best-of", type=int, default=10, help="Number of candidates when sampling") +parser.add_argument("--beam-size", type=int, default=10, help="Beam size for beam search") +parser.add_argument("--patience", type=float, default=1.2, help="Beam search patience") +parser.add_argument("--no-repeat-ngram-size", type=int, default=3, help="Prevent repeating ngrams of this size") +parser.add_argument("--chunk-length", type=int, default=30, help="Length of audio chunks in seconds") +parser.add_argument("--length-penalty", type=float, default=1.0, help="Penalty for shorter/longer sequences") +parser.add_argument("--condition-on-previous-text", action="store_true", help="Condition decoding on previous text") +parser.add_argument("--multilingual", action="store_true", help="Enable multilingual mode") +parser.add_argument("--vad-filter", action="store_true", help="Enable voice activity detection filter") +parser.add_argument("--without-timestamps", action="store_true", help="Do not include timestamps in output") + args = parser.parse_args() language = 
process_language_arg(args.language, args.model_name) if args.stemming: - # Isolate vocals from the rest of the audio - return_code = os.system( f'python -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o temp_outputs --device "{args.device}"' ) - if return_code != 0: logging.warning( "Source splitting failed, using original audio file. " @@ -115,9 +127,6 @@ else: vocal_target = args.audio - -# Transcribe the audio file - whisper_model = faster_whisper.WhisperModel( args.model_name, device=args.device, compute_type=mtypes[args.device] ) @@ -129,20 +138,32 @@ else [-1] ) +transcribe_kwargs = dict( + audio=audio_waveform, + language=language, + suppress_tokens=suppress_tokens, + log_progress=True, + multilingual=args.multilingual, + temperature=args.temperature, + initial_prompt=args.initial_prompt, + repetition_penalty=args.repetition_penalty, + best_of=args.best_of, + vad_filter=args.vad_filter, + without_timestamps=args.without_timestamps, + beam_size=args.beam_size, + patience=args.patience, + no_repeat_ngram_size=args.no_repeat_ngram_size, + chunk_length=args.chunk_length, + length_penalty=args.length_penalty, + condition_on_previous_text=args.condition_on_previous_text, + hotwords=args.hotwords, +) + if args.batch_size > 0: - transcript_segments, info = whisper_pipeline.transcribe( - audio_waveform, - language, - suppress_tokens=suppress_tokens, - batch_size=args.batch_size, - ) + transcribe_kwargs["batch_size"] = args.batch_size + transcript_segments, info = whisper_pipeline.transcribe(**transcribe_kwargs) else: - transcript_segments, info = whisper_model.transcribe( - audio_waveform, - language, - suppress_tokens=suppress_tokens, - vad_filter=True, - ) + transcript_segments, info = whisper_model.transcribe(**transcribe_kwargs) full_transcript = "".join(segment.text for segment in transcript_segments) From e94857fac6a8ce8cd51cf0161c623f4ba67dacb0 Mon Sep 17 00:00:00 2001 From: vivek Date: Tue, 3 Jun 2025 17:10:25 +0530 Subject: [PATCH 2/2] 
Restore comments and blank lines removed in the previous patch --- diarize.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/diarize.py b/diarize.py index 574938b..194ee53 100644 --- a/diarize.py +++ b/diarize.py @@ -108,9 +108,12 @@ language = process_language_arg(args.language, args.model_name) if args.stemming: + # Isolate vocals from the rest of the audio + return_code = os.system( f'python -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o temp_outputs --device "{args.device}"' ) + if return_code != 0: logging.warning( "Source splitting failed, using original audio file. " @@ -127,6 +130,8 @@ else: vocal_target = args.audio +# Transcribe the audio file + whisper_model = faster_whisper.WhisperModel( args.model_name, device=args.device, compute_type=mtypes[args.device] )