189 lines
6.4 KiB
Rust
189 lines
6.4 KiB
Rust
use serde_with::BoolFromInt;
|
|
use chrono::serde::ts_milliseconds;
|
|
use std::env;
|
|
use std::fs::File;
|
|
use std::path::Path;
|
|
use std::process::Command;
|
|
use std::sync::{Arc, Mutex};
|
|
use chrono::{DateTime, Utc};
|
|
use lapin::{Connection, ConnectionProperties, Consumer};
|
|
use lapin::message::{Delivery, DeliveryResult};
|
|
use lapin::options::{BasicAckOptions, BasicConsumeOptions};
|
|
use lapin::types::FieldTable;
|
|
use serde::{Deserialize, Serialize};
|
|
use log::{info, warn};
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
let ascii = r#"
|
|
__ ___ _ _ __ __ __ _
|
|
\ \ / / |_ (_) ___ | '_ \ ___ _ _ o O O\ \ / / ___ _ _ | |__ ___ _ _
|
|
\ \/\/ /| ' \ | | (_-< | .__/ / -_) | '_| o \ \/\/ / / _ \ | '_| | / / / -_) | '_|
|
|
\_/\_/ |_||_| _|_|_ /__/_ |_|__ \___| _|_|_ TS__[O] \_/\_/ \___/ _|_|_ |_\_\ \___| _|_|_
|
|
_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""| {======|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|_|"""""|
|
|
"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'./o--000'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'"`-0-0-'
|
|
"#;
|
|
println!("{ascii}");
|
|
|
|
let args: Vec<String> = env::args().collect();
|
|
// let cfg: AppConfig = confy::load_path(Path::new(&args[1])).expect("Couldn't read config");
|
|
let cfg: AppConfig = confy::load_path("./config.toml").expect("Couldn't read config");
|
|
match Path::new("log4rs.yaml").exists() {
|
|
true => log4rs::init_file("log4rs.yaml", Default::default()).unwrap(),
|
|
false => println!("No log4rs.yaml file found. Logging will not be enabled")
|
|
}
|
|
|
|
info!("Setting up mq consumer");
|
|
let options = ConnectionProperties::default()
|
|
.with_executor(tokio_executor_trait::Tokio::current())
|
|
.with_reactor(tokio_reactor_trait::Tokio);
|
|
let connection = Connection::connect(&cfg.rabbit_mq_config.connection_string, options).await.unwrap();
|
|
let channel = connection.create_channel().await.unwrap();
|
|
let consumer = channel.basic_consume("transcribe", "whisper-worker", BasicConsumeOptions::default(), FieldTable::default()).await.unwrap();
|
|
|
|
let processing_lock = Arc::new(Mutex::new(()));
|
|
consumer.set_delegate({
|
|
let processing_lock = Arc::clone(&processing_lock);
|
|
move |delivery: DeliveryResult| {
|
|
let processing_lock = Arc::clone(&processing_lock);
|
|
async move {
|
|
info!("Consuming mq message");
|
|
let delivery = match delivery {
|
|
Ok(Some(delivery)) => delivery,
|
|
Ok(None) => return,
|
|
Err(error) => {
|
|
warn!("Failed to consume queue message {}", error);
|
|
return;
|
|
}
|
|
};
|
|
|
|
let transcription_request: TranscriptionRequest = serde_json::from_slice(&delivery.data).unwrap();
|
|
let path = Path::new(&transcription_request.audio_file_path);
|
|
if !path.exists() {
|
|
warn!("File not found: {}", &transcription_request.audio_file_path);
|
|
return delivery.ack(BasicAckOptions::default()).await.unwrap();
|
|
}
|
|
|
|
|
|
let result = {
|
|
info!("Waiting for lock");
|
|
let _lock = processing_lock.lock();
|
|
info!("Acquired lock!");
|
|
transcribe_call(path);
|
|
};
|
|
info!("Lock released");
|
|
|
|
info!("Call transcription done");
|
|
delivery.ack(BasicAckOptions::default()).await.unwrap()
|
|
}
|
|
}
|
|
});
|
|
|
|
info!("Waiting for messages");
|
|
std::future::pending::<()>().await;
|
|
}
|
|
|
|
fn transcribe_call(file_path: &Path) {
|
|
info!("Transcribing file {}", file_path.display());
|
|
let output_directory = file_path.parent().unwrap();
|
|
let output = Command::new("whisperx")
|
|
.args(["--language", "en"])
|
|
.args(["--model", "large-v3"])
|
|
.args(["--batch_size", "4"])
|
|
.args(["--compute_type", "int8"])
|
|
.args(["--output_format", "txt"])
|
|
.args(["--output_dir", output_directory.parent().unwrap().to_str().unwrap()])
|
|
.arg("file")
|
|
.output().expect("TODO: panic message");
|
|
info!("Transcription done, {}", &output.status);
|
|
info!("Std out: {}", String::from_utf8_lossy(&output.stdout));
|
|
info!("Std err: {}", String::from_utf8_lossy(&output.stderr));
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Default, Debug)]
|
|
struct TranscriptionRequest {
|
|
audio_file_path: String,
|
|
call_metadata: Call
|
|
}
|
|
|
|
#[serde_with::serde_as]
|
|
#[derive(Serialize, Deserialize, Default, Debug)]
|
|
struct Call {
|
|
freq: u32,
|
|
freq_error: i32,
|
|
signal: i32,
|
|
noise: i32,
|
|
source_num: u32,
|
|
recorder_num: u32,
|
|
tdma_slot: u32,
|
|
#[serde_as(as = "BoolFromInt")]
|
|
phase2_tdma: bool,
|
|
#[serde(with = "ts_milliseconds")]
|
|
start_time: DateTime<Utc>,
|
|
#[serde(with = "ts_milliseconds")]
|
|
stop_time: DateTime<Utc>,
|
|
#[serde_as(as = "BoolFromInt")]
|
|
emergency: bool,
|
|
priority: i32,
|
|
#[serde_as(as = "BoolFromInt")]
|
|
mode: bool,
|
|
#[serde_as(as = "BoolFromInt")]
|
|
duplex: bool,
|
|
#[serde_as(as = "BoolFromInt")]
|
|
encrypted: bool,
|
|
call_length: i32,
|
|
talkgroup: u64,
|
|
talkgroup_tag: String,
|
|
talkgroup_description: String,
|
|
talkgroup_group_tag: String,
|
|
talkgroup_group: String,
|
|
audio_type: String,
|
|
short_name: String,
|
|
#[serde(rename = "freqList")]
|
|
freq_list: Vec<CallFrequency>,
|
|
#[serde(rename = "srcList")]
|
|
src_list: Vec<CallSource>,
|
|
patched_talkgroups: Option<Vec<u32>>
|
|
}
|
|
|
|
#[serde_with::serde_as]
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
struct CallSource {
|
|
src: i64,
|
|
#[serde(with = "ts_milliseconds")]
|
|
time: DateTime<Utc>,
|
|
pos: f64,
|
|
#[serde_as(as = "BoolFromInt")]
|
|
emergency: bool,
|
|
signal_system: String,
|
|
tag: String
|
|
}
|
|
|
|
#[serde_with::serde_as]
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
struct CallFrequency {
|
|
freq: f64,
|
|
#[serde(with = "ts_milliseconds")]
|
|
time: DateTime<Utc>,
|
|
pos: f64,
|
|
len: f64,
|
|
error_count: i32,
|
|
spike_count: i32
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
struct AppConfig {
|
|
rabbit_mq_config: RabbitMqConfig,
|
|
file_storage_path: String
|
|
}
|
|
|
|
impl Default for AppConfig {
    /// Never produces a value: asking for a default configuration always panics.
    ///
    /// # Panics
    ///
    /// Always, with the message `Could not find config file`.
    fn default() -> AppConfig {
        panic!("Could not find config file")
    }
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
struct RabbitMqConfig {
|
|
connection_string: String,
|
|
}
|