Skip to content

Commit 5bfb777

Browse files
authored
Merge pull request #49 from pragmatrix/speech-detect-wav
Support recognizing from wav files in azure-transcribe, make speech g…
2 parents 93998af + 42670ee commit 5bfb777

File tree

7 files changed

+74
-28
lines changed

7 files changed

+74
-28
lines changed

Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "context-switch"
3-
version = "1.0.1"
3+
version = "1.1.0"
44
edition = "2024"
55
rust-version = "1.88"
66

@@ -73,6 +73,10 @@ openai-api-rs = { workspace = true }
7373
serde_json = { workspace = true }
7474
chrono-tz = { version = "0.10.3" }
7575

76+
77+
# For recognizing audio files in azure-transcribe.
78+
playback = { path = "services/playback" }
79+
7680
[workspace.dependencies]
7781
tracing-subscriber = { version = "0.3.19" }
7882

audio-knife/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "audio-knife"
3-
version = "1.3.1"
3+
version = "1.4.0"
44
edition = "2024"
55

66
[profile.dev]

core/src/lib.rs

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ pub fn audio_msg_channel(format: AudioFormat) -> (AudioMsgProducer, AudioMsgCons
9494
#[derive(Debug)]
9595
pub struct AudioConsumer {
9696
pub format: AudioFormat,
97-
pub receiver: mpsc::Receiver<Vec<i16>>,
97+
pub receiver: mpsc::UnboundedReceiver<Vec<i16>>,
9898
}
9999

100100
impl AudioConsumer {
@@ -114,18 +114,10 @@ impl AudioConsumer {
114114
#[derive(Debug)]
115115
pub struct AudioProducer {
116116
pub format: AudioFormat,
117-
pub sender: mpsc::Sender<Vec<i16>>,
117+
pub sender: mpsc::UnboundedSender<Vec<i16>>,
118118
}
119119

120120
impl AudioProducer {
121-
// TODO: remove this function.
122-
pub fn produce_raw(&self, samples: Vec<i16>) -> Result<()> {
123-
self.produce(AudioFrame {
124-
format: self.format,
125-
samples,
126-
})
127-
}
128-
129121
pub fn produce(&self, frame: AudioFrame) -> Result<()> {
130122
if frame.format != self.format {
131123
bail!(
@@ -134,15 +126,14 @@ impl AudioProducer {
134126
frame.format
135127
);
136128
}
137-
self.sender
138-
.try_send(frame.samples)
139-
.context("Sending samples")
129+
self.sender.send(frame.samples).context("Sending samples")?;
130+
Ok(())
140131
}
141132
}
142133

143134
/// Create an unidirectional audio channel.
144135
pub fn audio_channel(format: AudioFormat) -> (AudioProducer, AudioConsumer) {
145-
let (producer, consumer) = mpsc::channel(256);
136+
let (producer, consumer) = mpsc::unbounded_channel();
146137
(
147138
AudioProducer {
148139
format,

core/src/protocol.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@ impl AudioFormat {
3131
time::Duration::from_secs_f64(mono_sample_count as f64 / self.sample_rate as f64)
3232
}
3333

34+
// Architecture: This is used only in the examples anymore.
3435
pub fn new_channel(&self) -> (AudioProducer, AudioConsumer) {
3536
audio_channel(*self)
3637
}
3738

39+
#[deprecated(note = "Removed without replacement")]
3840
pub fn new_msg_channel(&self) -> (AudioMsgProducer, AudioMsgConsumer) {
3941
audio_msg_channel(*self)
4042
}

examples/azure-transcribe.rs

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,58 @@
1-
use std::{env, time::Duration};
1+
use std::{env, path::Path, time::Duration};
22

3-
use anyhow::{Context, Result};
3+
use anyhow::{Context, Result, bail};
44
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
55
use tokio::{
66
select,
77
sync::mpsc::{channel, unbounded_channel},
88
};
99

10-
use context_switch::{InputModality, OutputModality, services::AzureTranscribe};
10+
use context_switch::{AudioConsumer, InputModality, OutputModality, services::AzureTranscribe};
1111
use context_switch_core::{
1212
AudioFormat, AudioFrame, audio,
1313
conversation::{Conversation, Input},
1414
service::Service,
1515
};
1616

17+
const LANGUAGE: &str = "de-DE";
18+
1719
#[tokio::main]
1820
async fn main() -> Result<()> {
1921
dotenvy::dotenv_override()?;
2022
tracing_subscriber::fmt::init();
2123

24+
let mut args = env::args();
25+
match args.len() {
26+
1 => recognize_from_microphone().await?,
27+
2 => recognize_from_wav(Path::new(&args.nth(1).unwrap())).await?,
28+
_ => bail!("Invalid number of arguments, expect zero or one"),
29+
}
30+
Ok(())
31+
}
32+
33+
async fn recognize_from_wav(file: &Path) -> Result<()> {
34+
// For now we always convert to 16khz single channel (this is what we use internally for
35+
// testing).
36+
let format = AudioFormat {
37+
channels: 1,
38+
sample_rate: 16000,
39+
};
40+
41+
let frames = playback::audio_file_to_frames(file, format)?;
42+
if frames.is_empty() {
43+
bail!("No frames in the audio file")
44+
}
45+
46+
let (producer, input_consumer) = format.new_channel();
47+
48+
for frame in frames {
49+
producer.produce(frame)?;
50+
}
51+
52+
recognize(format, input_consumer).await
53+
}
54+
55+
async fn recognize_from_microphone() -> Result<()> {
2256
let host = cpal::default_host();
2357
let device = host
2458
.default_input_device()
@@ -33,7 +67,7 @@ async fn main() -> Result<()> {
3367
let sample_rate = config.sample_rate();
3468
let format = AudioFormat::new(channels, sample_rate.0);
3569

36-
let (producer, mut input_consumer) = format.new_channel();
70+
let (producer, input_consumer) = format.new_channel();
3771

3872
// Create and run the input stream
3973
let stream = device
@@ -56,19 +90,23 @@ async fn main() -> Result<()> {
5690

5791
stream.play().expect("Failed to play stream");
5892

59-
let language = "de-DE";
93+
recognize(format, input_consumer).await
94+
}
6095

96+
async fn recognize(format: AudioFormat, mut input_consumer: AudioConsumer) -> Result<()> {
6197
// TODO: clarify how to access configurations.
6298
let params = azure::transcribe::Params {
6399
host: None,
64100
region: Some(env::var("AZURE_REGION").expect("AZURE_REGION undefined")),
65101
subscription_key: env::var("AZURE_SUBSCRIPTION_KEY")
66102
.expect("AZURE_SUBSCRIPTION_KEY undefined"),
67-
language: language.into(),
103+
language: LANGUAGE.into(),
104+
speech_gate: false,
68105
};
69106

70107
let (output_producer, mut output_consumer) = unbounded_channel();
71-
let (conv_input_producer, conv_input_consumer) = channel(32);
108+
// For now this is more or less unbounded, because we push complete audio files for recognition.
109+
let (conv_input_producer, conv_input_consumer) = channel(16384);
72110

73111
let azure = AzureTranscribe;
74112
let mut conversation = azure.conversation(

services/azure/src/transcribe.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use async_trait::async_trait;
44
use azure_speech::recognizer::{self, Event};
55
use futures::StreamExt;
66
use serde::Deserialize;
7-
use tracing::error;
7+
use tracing::{error, info};
88

99
use crate::Host;
1010
use context_switch_core::{
@@ -20,6 +20,8 @@ pub struct Params {
2020
pub region: Option<String>,
2121
pub subscription_key: String,
2222
pub language: String,
23+
#[serde(default)]
24+
pub speech_gate: bool,
2325
}
2426

2527
#[derive(Debug)]
@@ -67,9 +69,18 @@ impl Service for AzureTranscribe {
6769
.into_header_for_infinite_file();
6870
stream! {
6971
yield wav_header;
70-
let mut speech_gate = make_speech_gate_processor_soft_rms(0.0025, 10., 300., 0.01);
71-
while let Some(Input::Audio{ frame }) = input.recv().await {
72-
let frame = speech_gate(&frame);
72+
let mut speech_gate =
73+
if params.speech_gate {
74+
info!("Enabling speech gate");
75+
Some(make_speech_gate_processor_soft_rms(0.0025, 10., 300., 0.01))
76+
}
77+
else {
78+
None
79+
};
80+
while let Some(Input::Audio{ mut frame }) = input.recv().await {
81+
if let Some(ref mut speech_gate) = speech_gate {
82+
frame = (speech_gate)(&frame);
83+
}
7384
yield frame.to_le_bytes();
7485
// <https://azure.microsoft.com/en-us/pricing/details/cognitive-services/speech-services/>
7586
// Speech to text hours are measured as the hours of audio _sent to the service_, billed in second increments.

services/playback/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ impl Service for Playback {
159159
}
160160

161161
/// Render the file into 100ms audio frames mono.
162-
fn audio_file_to_frames(path: &Path, format: AudioFormat) -> Result<Vec<AudioFrame>> {
162+
pub fn audio_file_to_frames(path: &Path, format: AudioFormat) -> Result<Vec<AudioFrame>> {
163163
check_supported_audio_type(&path.to_string_lossy(), None)?;
164164
let file = File::open(path).inspect_err(|e| {
165165
// We don't want to provide the resolved path to the user in an error message. Therefore we

0 commit comments

Comments
 (0)