Compare commits

..

20 commits

Author SHA1 Message Date
58fef04e2c fix open ai api call 2024-12-20 09:28:06 -05:00
408a23008d set file correctly in request for real this time 2024-12-18 20:48:24 -05:00
86fba439ef set file correctly in request 2024-12-18 20:44:58 -05:00
3889c80509 set language how the api wants it 2024-12-18 20:40:34 -05:00
2a3615226e set model and don't delete file if requeing 2024-12-18 20:37:12 -05:00
ae1429ff63 replace local whisper with calls to openai 2024-12-18 20:24:18 -05:00
d816309060 safely unwrap process error code 2024-12-18 19:48:03 -05:00
7f1fa73bb0 delete files after processing 2024-12-17 23:25:00 -05:00
2743c8ca3c remove meilisearch logs 2024-12-17 23:07:52 -05:00
165521f11d limit prefect to 3 2024-12-17 23:03:55 -05:00
2c03d5a9bd explicitly set tokio to multithreaded 2024-12-17 22:55:27 -05:00
d88d7473c7 logging to figure out why meilisearch isn't working 2024-12-17 22:31:11 -05:00
86e6799edf add transcripts to meilisearch 2024-12-17 21:42:58 -05:00
af3fc06655 actually pass audio file to whisperx 2024-12-17 18:25:32 -05:00
415cebf942 fix output directory parameter 2024-12-17 18:18:04 -05:00
30d7a061b6 fix model parameter 2024-12-17 18:06:30 -05:00
b34f97fa4d log output from whisperx 2024-12-17 17:37:16 -05:00
af6376e2a0 add mutex around call transcription 2024-12-17 00:32:36 -05:00
4f5c1f8ac4 add logging 2024-12-16 23:55:13 -05:00
99a67fcfa8 setup mq worker 2024-12-16 23:21:45 -05:00
4 changed files with 3876 additions and 0 deletions

3560
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

20
Cargo.toml Normal file
View file

@ -0,0 +1,20 @@
[package]
name = "whisper-worker"
version = "0.1.0"
edition = "2021"
[dependencies]
tokio = { version = "1.42.0", features = ["rt-multi-thread", "rt", "macros", "time"] }
chrono = {version = "0.4.39", features = ["serde"] }
serde = { version = "1.0.216", features = ["derive"] }
serde_json = "1.0.133"
serde_with = "3.11.0"
lapin = "2.5.0"
confy = "0.6.1"
tokio-executor-trait = "2.1.3"
tokio-reactor-trait = "1.1.0"
log = "0.4.22"
log4rs = "1.3.0"
meilisearch-sdk = "0.27.1"
uuid = { version = "1.11.0", features = ["v4"] }
async-openai = "0.26.0"

33
log4rs.yaml Normal file
View file

@ -0,0 +1,33 @@
# Scan this file for changes every 30 seconds
refresh_rate: 30 seconds
appenders:
# An appender named "stdout" that writes to stdout
stdout:
kind: console
# An appender named "requests" that writes to a file with a custom pattern encoder
requests:
kind: file
path: "log/requests.log"
encoder:
pattern: "{d} - {m}{n}"
# Set the default logging level to "debug" and attach the "stdout" appender to the root
root:
level: debug
appenders:
- stdout
loggers:
# Raise the maximum log level for events sent to the "app::backend::db" logger to "info"
app::backend::db:
level: info
# Route log events sent to the "app::requests" logger to the "requests" appender,
# and *not* the normal appenders installed at the root
app::requests:
level: info
appenders:
- requests
additive: false

263
src/main.rs Normal file
View file

@ -0,0 +1,263 @@
use serde_with::BoolFromInt;
use chrono::serde::ts_milliseconds;
use std::{env, fs};
use std::fs::File;
use std::io::Error;
use std::path::{Path, PathBuf};
use std::process::{Command, Output};
use std::sync::{Arc, Mutex};
use async_openai::Audio;
use async_openai::error::OpenAIError;
use async_openai::types::{AudioInput, AudioResponseFormat, CreateTranscriptionRequest, CreateTranscriptionRequestArgs, CreateTranscriptionResponseJson, InputSource};
use chrono::{DateTime, Utc};
use lapin::{Connection, ConnectionProperties, Consumer};
use lapin::message::{Delivery, DeliveryResult};
use lapin::options::{BasicAckOptions, BasicConsumeOptions, BasicQosOptions, BasicRejectOptions};
use lapin::types::FieldTable;
use serde::{Deserialize, Serialize};
use log::{error, info, trace, warn};
use meilisearch_sdk::client::Client;
use meilisearch_sdk::task_info::TaskInfo;
use meilisearch_sdk::tasks::Task;
use uuid::Uuid;
#[tokio::main]
async fn main() {
let ascii = r#"
__ ___ _ _ __ __ __ _
\ \ / / |_ (_) ___ | '_ \ ___ _ _ o O O\ \ / / ___ _ _ | |__ ___ _ _
\ \/\/ /| ' \ | | (_-< | .__/ / -_) | '_| o \ \/\/ / / _ \ | '_| | / / / -_) | '_|
\_/\_/ |_||_| _|_|_ /__/_ |_|__ \___| _|_|_ TS__[O] \_/\_/ \___/ _|_|_ |_\_\ \___| _|_|_
_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""| {======|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|
"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'./o--000'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'
"#;
println!("{ascii}");
let args: Vec<String> = env::args().collect();
// let cfg: AppConfig = confy::load_path(Path::new(&args[1])).expect("Couldn't read config");
let cfg: AppConfig = confy::load_path("./config.toml").expect("Couldn't read config");
match Path::new("log4rs.yaml").exists() {
true => log4rs::init_file("log4rs.yaml", Default::default()).unwrap(),
false => println!("No log4rs.yaml file found. Logging will not be enabled")
}
info!("Setting up mq consumer");
let options = ConnectionProperties::default()
.with_executor(tokio_executor_trait::Tokio::current())
.with_reactor(tokio_reactor_trait::Tokio);
let connection = Connection::connect(&cfg.rabbit_mq_config.connection_string, options).await.unwrap();
let channel = connection.create_channel().await.unwrap();
channel.basic_qos(3, BasicQosOptions::default()).await.unwrap();
let consumer = channel.basic_consume("transcribe", "whisper-worker", BasicConsumeOptions::default(), FieldTable::default()).await.unwrap();
consumer.set_delegate({
move |delivery: DeliveryResult| {
let meilisearch_client = Client::new(&cfg.meilisearch_config.connection_string, Some(&cfg.meilisearch_config.api_key)).expect("Couldn't create meilisearch client");
async move {
let delivery = match delivery {
Ok(Some(delivery)) => delivery,
Ok(None) => return,
Err(error) => {
warn!("Failed to consume queue message {}", error);
return;
}
};
let transcription_request: TranscriptionRequest = serde_json::from_slice(&delivery.data).unwrap();
let path = Path::new(&transcription_request.audio_file_path);
if !path.exists() {
warn!("File not found: {}", &transcription_request.audio_file_path);
return delivery.ack(BasicAckOptions::default()).await.unwrap();
}
let transcription_result = transcribe_call(path).await;
let transcript = match transcription_result {
Ok(response) => {
info!("Successfully transcribed {}, Output: {}", &transcription_request.audio_file_path, response.text);
response.text
}
Err(error) => {
error!("Failed to transcribe {}, {}", &transcription_request.audio_file_path, error);
return delivery.reject(BasicRejectOptions {
requeue: true
}).await.unwrap();
}
};
let meilisearch_task = match write_to_meilisearch(&meilisearch_client, transcription_request.call_metadata, transcript).await {
Ok(task) => task,
Err(error) => {
error!("Failed to send message to meilisearch, {}", error);
delete_file(path);
return delivery.ack(BasicAckOptions::default()).await.unwrap();
}
};
match wait_for_task_to_complete(&meilisearch_client, &meilisearch_task).await {
Ok(task) if task.is_failure() => {
error!("Failed to send message to meilisearch, {}", task.unwrap_failure());
}
Err(error) => {
error!("Failed to send message to meilisearch, {}", error);
}
_ => {}
}
delete_file(path);
delivery.ack(BasicAckOptions::default()).await.unwrap()
}
}
});
info!("Startup Complete!");
std::future::pending::<()>().await;
}
/// Sends the audio file at `file_path` to OpenAI's transcription endpoint
/// (`whisper-1`, English) and returns the JSON transcription response.
///
/// A fresh `async_openai::Client` with default configuration is created for
/// every call.
///
/// # Errors
///
/// Returns an [`OpenAIError`] if the request cannot be built or the API call
/// fails.
async fn transcribe_call(file_path: &Path) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
    info!("Transcribing file {}", file_path.display());
    let client = async_openai::Client::new();
    let request = CreateTranscriptionRequestArgs::default()
        .prompt("This is a public safety radio transmission. The speakers could be any of the following communicating between each other: dispatcher, law enforcement officer, fire fighter, or emergency medical services")
        .language("en")
        .model("whisper-1")
        // `AudioInput` converts from any `AsRef<Path>`, so pass the path as-is
        // instead of `to_str().unwrap()`, which panics on non-UTF-8 paths.
        .file(file_path)
        .build()?;
    client.audio().transcribe(request).await
}
/// Submits one transcribed call to the `calls` index in Meilisearch.
///
/// The document gets a freshly generated random (v4) UUID as its `id` and
/// carries the transcript together with the call metadata. The returned
/// [`TaskInfo`] identifies the asynchronous indexing task; it does not mean the
/// document has been indexed yet.
async fn write_to_meilisearch(client: &Client, call: Call, transcript: String) -> Result<TaskInfo, meilisearch_sdk::errors::Error> {
    let id = Uuid::new_v4().to_string();
    let documents = [MeilisearchCall { id, transcript, metadata: call }];
    let calls_index = client.index("calls");
    calls_index.add_documents(&documents, None).await
}
/// Polls Meilisearch until `task` is no longer pending.
///
/// The task is fetched every two seconds. The first result that is either an
/// error or a non-pending task is returned as-is. There is no upper bound on
/// the number of polls.
async fn wait_for_task_to_complete(client: &Client, task: &TaskInfo) -> Result<Task, meilisearch_sdk::errors::Error> {
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_secs(2);
    loop {
        let outcome = client.get_task(task).await;
        let still_pending = matches!(&outcome, Ok(current) if current.is_pending());
        if !still_pending {
            return outcome;
        }
        tokio::time::sleep(POLL_INTERVAL).await;
    }
}
/// Best-effort removal of the file at `file_path`.
///
/// Any error from the removal (for example the file not existing) is
/// discarded, so this function never fails or panics.
fn delete_file(file_path: &Path) {
    let _ = fs::remove_file(file_path);
}
/// Payload of one queue message; `main` decodes it from the delivery's JSON body.
#[derive(Serialize, Deserialize, Default, Debug)]
struct TranscriptionRequest {
/// Path of the audio file to transcribe.
audio_file_path: String,
/// Call metadata that is stored in Meilisearch alongside the transcript.
call_metadata: Call
}
/// Document written to the Meilisearch `calls` index by `write_to_meilisearch`.
#[derive(Serialize, Deserialize, Debug)]
struct MeilisearchCall {
/// Document id; `write_to_meilisearch` fills it with a random v4 UUID string.
id: String,
/// Transcribed text of the call audio.
transcript: String,
/// Call metadata copied from the originating `TranscriptionRequest`.
metadata: Call
}
/// Metadata describing one recorded radio call, as carried in a
/// `TranscriptionRequest` and stored in Meilisearch.
///
/// Serde notes: fields marked `BoolFromInt` are (de)serialized as integers
/// rather than JSON booleans, `ts_milliseconds` fields are integer Unix
/// timestamps in milliseconds, and `freq_list` / `src_list` use the camelCase
/// keys `freqList` / `srcList` on the wire.
///
/// NOTE(review): the field set looks like the call JSON emitted by the
/// recorder that produces the audio files — confirm units and value ranges
/// against that producer.
#[serde_with::serde_as]
#[derive(Serialize, Deserialize, Default, Debug)]
struct Call {
freq: u32,
freq_error: i32,
signal: i32,
noise: i32,
source_num: u32,
recorder_num: u32,
tdma_slot: u32,
// Integer on the wire, bool in memory.
#[serde_as(as = "BoolFromInt")]
phase2_tdma: bool,
// Millisecond Unix timestamp on the wire.
#[serde(with = "ts_milliseconds")]
start_time: DateTime<Utc>,
// Millisecond Unix timestamp on the wire.
#[serde(with = "ts_milliseconds")]
stop_time: DateTime<Utc>,
#[serde_as(as = "BoolFromInt")]
emergency: bool,
priority: i32,
#[serde_as(as = "BoolFromInt")]
mode: bool,
#[serde_as(as = "BoolFromInt")]
duplex: bool,
#[serde_as(as = "BoolFromInt")]
encrypted: bool,
call_length: i32,
talkgroup: u64,
talkgroup_tag: String,
talkgroup_description: String,
talkgroup_group_tag: String,
talkgroup_group: String,
audio_type: String,
short_name: String,
// Serialized under the key "freqList".
#[serde(rename = "freqList")]
freq_list: Vec<CallFrequency>,
// Serialized under the key "srcList".
#[serde(rename = "srcList")]
src_list: Vec<CallSource>,
// Optional: may be absent or null in the incoming JSON.
patched_talkgroups: Option<Vec<u32>>
}
/// One entry of `Call::src_list` (JSON key `srcList`).
///
/// NOTE(review): presumably one transmitting source within the call — confirm
/// field semantics against the message producer.
#[serde_with::serde_as]
#[derive(Serialize, Deserialize, Debug)]
struct CallSource {
src: i64,
// Millisecond Unix timestamp on the wire.
#[serde(with = "ts_milliseconds")]
time: DateTime<Utc>,
pos: f64,
// Integer on the wire, bool in memory.
#[serde_as(as = "BoolFromInt")]
emergency: bool,
signal_system: String,
tag: String
}
/// One entry of `Call::freq_list` (JSON key `freqList`).
///
/// NOTE(review): presumably one frequency segment of the call — confirm field
/// semantics and units against the message producer.
#[serde_with::serde_as]
#[derive(Serialize, Deserialize, Debug)]
struct CallFrequency {
freq: f64,
// Millisecond Unix timestamp on the wire.
#[serde(with = "ts_milliseconds")]
time: DateTime<Utc>,
pos: f64,
len: f64,
error_count: i32,
spike_count: i32
}
/// Application configuration; `main` loads it from `./config.toml` via `confy`.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct AppConfig {
/// RabbitMQ connection settings.
rabbit_mq_config: RabbitMqConfig,
/// Meilisearch connection settings.
meilisearch_config: MeilisearchConfig
}
/// There is no usable default configuration, so constructing one panics.
///
/// NOTE(review): `confy::load_path` requires `Default` and appears to fall
/// back to it when the config file is missing, which would turn a missing
/// `./config.toml` into this panic — confirm against the confy documentation.
impl Default for AppConfig {
fn default() -> Self {
panic!("Could not find config file")
}
}
/// RabbitMQ settings section of `AppConfig`.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct RabbitMqConfig {
/// Connection string passed to `Connection::connect` in `main`.
connection_string: String,
}
/// Meilisearch settings section of `AppConfig`.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct MeilisearchConfig {
/// Host URL passed as the first argument to `Client::new` in `main`.
connection_string: String,
/// API key passed as the second argument to `Client::new` in `main`.
api_key: String,
}