Replace local whisperx transcription with calls to the OpenAI transcription API

This commit is contained in:
Christopher Moyer 2024-12-18 20:24:18 -05:00
parent d816309060
commit ae1429ff63
3 changed files with 224 additions and 55 deletions

200
Cargo.lock generated
View file

@ -154,6 +154,15 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "async-convert"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d416feee97712e43152cd42874de162b8f9b77295b1c85e5d92725cc8310bae"
dependencies = [
"async-trait",
]
[[package]]
name = "async-executor"
version = "1.13.1"
@ -252,6 +261,32 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "async-openai"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d226540b4ecf884b0fb4370008631ccbd9605cf82f98bb504ec8f2cee0810e8"
dependencies = [
"async-convert",
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-reactor-trait"
version = "1.1.0"
@ -293,6 +328,20 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
@ -506,6 +555,16 @@ dependencies = [
"libc",
]
[[package]]
name = "core-foundation"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
@ -637,6 +696,37 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "derive_builder"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
dependencies = [
"derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.90",
]
[[package]]
name = "derive_builder_macro"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
"syn 2.0.90",
]
[[package]]
name = "des"
version = "0.8.1"
@ -753,6 +843,17 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "executor-trait"
version = "2.1.2"
@ -908,6 +1009,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
@ -1090,6 +1197,7 @@ dependencies = [
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
@ -1547,6 +1655,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -2010,11 +2128,13 @@ dependencies = [
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
@ -2034,6 +2154,22 @@ dependencies = [
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.8"
@ -2119,7 +2255,7 @@ checksum = "2a980454b497c439c274f2feae2523ed8138bbd3d323684e1435fec62f800481"
dependencies = [
"log",
"rustls",
"rustls-native-certs",
"rustls-native-certs 0.7.3",
"rustls-pki-types",
"rustls-webpki",
]
@ -2134,7 +2270,19 @@ dependencies = [
"rustls-pemfile",
"rustls-pki-types",
"schannel",
"security-framework",
"security-framework 2.11.1",
]
[[package]]
name = "rustls-native-certs"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3"
dependencies = [
"openssl-probe",
"rustls-pki-types",
"schannel",
"security-framework 3.1.0",
]
[[package]]
@ -2207,6 +2355,16 @@ dependencies = [
"sha2",
]
[[package]]
name = "secrecy"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9bd1c54ea06cfd2f6b63219704de0b9b4f72dcc2b8fdef820be6cd799780e91e"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
@ -2214,7 +2372,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.6.0",
"core-foundation",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81d3f8c9bfcc3cbb6b0179eb57042d75b1582bdc65c3cb95f3fa999509c03cbc"
dependencies = [
"bitflags 2.6.0",
"core-foundation 0.10.0",
"core-foundation-sys",
"libc",
"security-framework-sys",
@ -2222,9 +2393,9 @@ dependencies = [
[[package]]
name = "security-framework-sys"
version = "2.12.1"
version = "2.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2"
checksum = "1863fd3768cd83c56a7f60faa4dc0d403f1b6df0a38c3c25f44b7894e45370d5"
dependencies = [
"core-foundation-sys",
"libc",
@ -2752,9 +2923,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
dependencies = [
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.90",
]
[[package]]
name = "tracing-core"
version = "0.1.33"
@ -2785,6 +2968,12 @@ version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "unicase"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df"
[[package]]
name = "unicode-ident"
version = "1.0.14"
@ -2991,6 +3180,7 @@ dependencies = [
name = "whisper-worker"
version = "0.1.0"
dependencies = [
"async-openai",
"chrono",
"confy",
"lapin",

View file

@ -17,3 +17,4 @@ log = "0.4.22"
log4rs = "1.3.0"
meilisearch-sdk = "0.27.1"
uuid = { version = "1.11.0", features = ["v4"] }
async-openai = "0.26.0"

View file

@ -6,6 +6,9 @@ use std::io::Error;
use std::path::{Path, PathBuf};
use std::process::{Command, Output};
use std::sync::{Arc, Mutex};
use async_openai::Audio;
use async_openai::error::OpenAIError;
use async_openai::types::{AudioInput, AudioResponseFormat, CreateTranscriptionRequest, CreateTranscriptionRequestArgs, CreateTranscriptionResponseJson, InputSource};
use chrono::{DateTime, Utc};
use lapin::{Connection, ConnectionProperties, Consumer};
use lapin::message::{Delivery, DeliveryResult};
@ -47,12 +50,9 @@ _|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""| {======|_|"""""|_|"""""
channel.basic_qos(3, BasicQosOptions::default()).await.unwrap();
let consumer = channel.basic_consume("transcribe", "whisper-worker", BasicConsumeOptions::default(), FieldTable::default()).await.unwrap();
let processing_lock = Arc::new(Mutex::new(()));
consumer.set_delegate({
let processing_lock = Arc::clone(&processing_lock);
move |delivery: DeliveryResult| {
let meilisearch_client = Client::new(&cfg.meilisearch_config.connection_string, Some(&cfg.meilisearch_config.api_key)).expect("Couldn't create meilisearch client");
let processing_lock = Arc::clone(&processing_lock);
async move {
let delivery = match delivery {
Ok(Some(delivery)) => delivery,
@ -70,46 +70,20 @@ _|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""| {======|_|"""""|_|"""""
return delivery.ack(BasicAckOptions::default()).await.unwrap();
}
let transcription_result = transcribe_call(path).await;
let transcription_result = {
let _lock = processing_lock.lock();
transcribe_call(path)
};
match transcription_result {
Ok(result) if result.status.success() => {
info!("Successfully transcribed {}", &transcription_request.audio_file_path);
if !result.stdout.is_empty() {
trace!("Stdout: {}", String::from_utf8_lossy(&result.stdout));
}
if !result.stderr.is_empty() {
trace!("Stderr: {}", String::from_utf8_lossy(&result.stderr));
}
}
Ok(result) => {
error!("Failed to transcribe {}, Exit code {}", &transcription_request.audio_file_path, result.status.code().unwrap_or_else(|| -99));
error!("Stdout: {}", String::from_utf8_lossy(&result.stdout));
error!("Stderr: {}", String::from_utf8_lossy(&result.stderr));
delete_file(path);
return delivery.reject(BasicRejectOptions::default()).await.unwrap();
let transcript = match transcription_result {
Ok(response) => {
info!("Successfully transcribed {}, Output: {}", &transcription_request.audio_file_path, response.text);
response.text
}
Err(error) => {
error!("Failed to transcribe {}, {}", &transcription_request.audio_file_path, error);
delete_file(path);
return delivery.ack(BasicAckOptions::default()).await.unwrap();
}
}
let transcript_path = path.with_extension("txt");
let transcript = match fs::read_to_string(&transcript_path) {
Ok(transcript) => transcript,
Err(_) => {
error!("Failed to read transcript file {}", &transcript_path.display());
delete_file(path);
return delivery.ack(BasicAckOptions::default()).await.unwrap();
return delivery.reject(BasicRejectOptions {
requeue: true
}).await.unwrap();
}
};
@ -119,7 +93,6 @@ _|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""| {======|_|"""""|_|"""""
error!("Failed to send message to meilisearch, {}", error);
delete_file(path);
delete_file(transcript_path.as_path());
return delivery.ack(BasicAckOptions::default()).await.unwrap();
}
};
@ -134,7 +107,6 @@ _|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""| {======|_|"""""|_|"""""
}
delete_file(path);
delete_file(transcript_path.as_path());
delivery.ack(BasicAckOptions::default()).await.unwrap()
}
}
@ -144,18 +116,24 @@ _|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""| {======|_|"""""|_|"""""
std::future::pending::<()>().await;
}
// Diff hunk for `transcribe_call`: the removed and the added versions of the
// function are interleaved below without +/- markers.
//
// Removed version: synchronous; runs the external `whisperx` command on
// `file_path` and returns the finished process `Output` (or the spawn error).
fn transcribe_call(file_path: &Path) -> Result<Output, Error> {
// Added version: async; builds a transcription request for `file_path` and
// awaits `client.audio().transcribe(...)`, returning the response value or an
// `OpenAIError`.
async fn transcribe_call(file_path: &Path) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
info!("Transcribing file {}", file_path.display());
// Removed version only: whisperx is told to write its `txt` output into the
// directory that contains the audio file. Both `unwrap()` calls below panic if
// the path has no parent or is not valid UTF-8.
let output_directory = file_path.parent().unwrap();
Command::new("whisperx")
.args(["--language", "en"])
.args(["--model", "large-v3"])
.args(["--batch_size", "4"])
.args(["--compute_type", "int8"])
.args(["--output_format", "txt"])
.args(["--output_dir", output_directory.to_str().unwrap()])
.arg(file_path)
.output()
// Added version from here on. A new client is constructed on every call and no
// configuration is passed in.
// NOTE(review): presumably credentials come from the environment — confirm.
let client = async_openai::Client::new();
// The audio is handed to the library as a filesystem path, not as bytes read here.
let audio_input = AudioInput {
source: InputSource::Path {
path: file_path.to_path_buf(),
}
};
// NOTE(review): no `.model(...)` is set on this builder, whereas the removed
// whisperx call named a model explicitly — confirm the API accepts a request
// without one.
// NOTE(review): `language` is given as "English" while the removed call used
// the code "en" — confirm which form the API expects.
// NOTE(review): the request asks for `AudioResponseFormat::Text` but the
// function returns `CreateTranscriptionResponseJson` — confirm `transcribe`
// can decode a plain-text response body into that type.
let request = CreateTranscriptionRequestArgs::default()
.prompt("This is a public safety radio transmission. The speakers could be any of the following communicating between each other: dispatcher, law enforcement officer, fire fighter, or emergency medical services")
.language("English")
.response_format(AudioResponseFormat::Text)
.file(audio_input)
.build()?;
// The awaited result is returned as-is; the caller decides how to handle errors.
client.audio().transcribe(request).await
}
async fn write_to_meilisearch(client: &Client, call: Call, transcript: String) -> Result<TaskInfo, meilisearch_sdk::errors::Error> {