whisper-api/src/main.rs

206 lines
7 KiB
Rust

use crate::chrono::serde::ts_milliseconds;
use axum::extract::Multipart;
use axum::handler::Handler;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::post;
use axum::{Extension, Router};
use chrono;
use chrono::{DateTime, Utc};
use lapin::options::BasicPublishOptions;
use lapin::{BasicProperties, Channel, Connection, ConnectionProperties};
use serde::{Deserialize, Serialize};
use serde_with::BoolFromInt;
use std::env;
use std::fs::File;
use std::io::Write;
use std::option::Option;
use std::path::{Path, PathBuf};
use log::error;
/// Starts the whisper-api service.
///
/// Prints the banner, loads the [`AppConfig`] from the path given as the first command-line
/// argument, enables log4rs when `log4rs.yaml` exists in the working directory, opens a
/// RabbitMQ channel, and serves `POST /calls` on `web_server_config.listener`.
///
/// # Panics
///
/// Panics with a descriptive message if the config path argument is missing, the config or
/// `log4rs.yaml` cannot be loaded, RabbitMQ is unreachable, the listener address cannot be
/// bound, or the server exits with an error.
#[tokio::main]
async fn main() {
    let ascii = r#"
____ __ ____ __ __ __ _______..______ _______ .______ ___ .______ __
\ \ / \ / / | | | | | | / || _ \ | ____|| _ \ / \ | _ \ | |
\ \/ \/ / | |__| | | | | (----`| |_) | | |__ | |_) | / ^ \ | |_) | | |
\ / | __ | | | \ \ | ___/ | __| | / / /_\ \ | ___/ | |
\ /\ / | | | | | | .----) | | | | |____ | |\ \----. / _____ \ | | | |
\__/ \__/ |__| |__| |__| |_______/ | _| |_______|| _| `._____| /__/ \__\ | _| |__|
"#;
    println!("{ascii}");

    // Fail with a usage message instead of an index-out-of-bounds panic when the
    // config path argument is missing.
    let mut args = env::args();
    let program = args.next().unwrap_or_else(|| String::from("whisper-api"));
    let config_path = args
        .next()
        .unwrap_or_else(|| panic!("Usage: {} <config-file>", program));
    let cfg: AppConfig = confy::load_path(Path::new(&config_path)).expect("Couldn't read config");

    if Path::new("log4rs.yaml").exists() {
        log4rs::init_file("log4rs.yaml", Default::default())
            .expect("Couldn't initialise logging from log4rs.yaml");
    } else {
        println!("No log4rs.yaml file found. Logging will not be enabled");
    }

    let options = ConnectionProperties::default()
        .with_executor(tokio_executor_trait::Tokio::current())
        .with_reactor(tokio_reactor_trait::Tokio);
    // `connection` stays bound until `main` returns, keeping the AMQP connection open.
    let connection = Connection::connect(&cfg.rabbit_mq_config.connection_string, options)
        .await
        .expect("Couldn't connect to RabbitMQ");
    let channel = connection
        .create_channel()
        .await
        .expect("Couldn't open RabbitMQ channel");

    // Bind before building the router so `cfg` can be moved into the extension
    // layer instead of cloned.
    let listener = tokio::net::TcpListener::bind(&cfg.web_server_config.listener)
        .await
        .unwrap_or_else(|error| {
            panic!("Couldn't bind {}: {}", cfg.web_server_config.listener, error)
        });
    let app = Router::new()
        .route("/calls", post(upload_call))
        .layer(Extension(cfg))
        .layer(Extension(channel));
    axum::serve(listener, app).await.expect("HTTP server failed");
}
async fn upload_call(Extension(cfg): Extension<AppConfig>, Extension(channel): Extension<Channel>, mut multipart: Multipart) -> axum::response::Response {
let mut call_metadata: Option<Call> = None;
let mut call_file: Option<PathBuf> = None;
while let Some(field) = multipart.next_field().await.unwrap() {
if field.name().unwrap() == "call_json" {
let call: Call = match serde_json::from_str(&field.text().await.unwrap()) {
Ok(call) => call,
Err(error) => {
error!("{}", error);
return (StatusCode::BAD_REQUEST, "Failed to parse json").into_response();
}
};
call_metadata = Some(call);
} else if field.name().unwrap() == "call_audio" {
let file_name = field.file_name().unwrap();
let file_path = Path::new(&cfg.file_storage_path).join(&file_name);
let data = field.bytes().await.unwrap();
let mut file = match File::create(&file_path) {
Ok(file) => file,
Err(error) => {
error!("Could not create file {}: {}", &file_path.display(), error);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
match file.write_all(&data) {
Ok(_) => {
call_file = Some(file_path)
}
Err(error) => {
error!("Could not write to file {}: {}", &file_path.display(), error);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
}
}
let call_metadata: Call = match call_metadata {
Some(call_metadata) => call_metadata,
None => return (StatusCode::UNPROCESSABLE_ENTITY, "call_json is required").into_response()
};
let call_file: PathBuf = match call_file {
Some(call_file) => call_file,
None => return (StatusCode::UNPROCESSABLE_ENTITY, "call_audio is required").into_response()
};
if call_metadata.call_length < 2 {
return (StatusCode::UNPROCESSABLE_ENTITY, "Call too short").into_response()
}
let transcription_request = TranscriptionRequest {
audio_file_path: call_file.to_str().unwrap().parse().unwrap(),
call_metadata
};
channel.basic_publish("", "transcribe", BasicPublishOptions::default(), &*serde_json::to_vec(&transcription_request).unwrap(), BasicProperties::default()).await.expect("idk it broke");
StatusCode::CREATED.into_response()
}
/// Message published (as JSON) to the `transcribe` routing key for each accepted upload.
#[derive(Debug, Default, Deserialize, Serialize)]
struct TranscriptionRequest {
    /// Location of the stored audio file on disk.
    audio_file_path: String,
    /// Metadata parsed from the upload's `call_json` part.
    call_metadata: Call,
}
/// Call metadata supplied in the `call_json` multipart part.
///
/// Fields annotated with `BoolFromInt` are exchanged as `0`/`1` integers. Fields annotated with
/// `ts_milliseconds` are exchanged as millisecond timestamps.
/// NOTE(review): the field names look like trunk-recorder's call JSON — confirm against the
/// actual producer before relying on that.
#[serde_with::serde_as]
#[derive(Debug, Default, Deserialize, Serialize)]
struct Call {
    freq: u32,
    freq_error: i32,
    signal: i32,
    noise: i32,
    source_num: u32,
    recorder_num: u32,
    tdma_slot: u32,

    #[serde_as(as = "BoolFromInt")]
    phase2_tdma: bool,

    #[serde(with = "ts_milliseconds")]
    start_time: DateTime<Utc>,
    #[serde(with = "ts_milliseconds")]
    stop_time: DateTime<Utc>,

    #[serde_as(as = "BoolFromInt")]
    emergency: bool,
    priority: i32,
    #[serde_as(as = "BoolFromInt")]
    mode: bool,
    #[serde_as(as = "BoolFromInt")]
    duplex: bool,
    #[serde_as(as = "BoolFromInt")]
    encrypted: bool,

    /// Uploads with a value below 2 are rejected by the `/calls` handler.
    call_length: i32,

    talkgroup: u64,
    talkgroup_tag: String,
    talkgroup_description: String,
    talkgroup_group_tag: String,
    talkgroup_group: String,
    audio_type: String,
    short_name: String,

    /// Serialized under the camel-case key `freqList`.
    #[serde(rename = "freqList")]
    freq_list: Vec<CallFrequency>,
    /// Serialized under the camel-case key `srcList`.
    #[serde(rename = "srcList")]
    src_list: Vec<CallSource>,

    /// May be absent from the JSON, in which case it deserializes to `None`.
    patched_talkgroups: Option<Vec<u32>>,
}
/// One entry of a call's `srcList` array.
#[serde_with::serde_as]
#[derive(Debug, Deserialize, Serialize)]
struct CallSource {
    src: i64,

    /// Exchanged as a millisecond timestamp.
    #[serde(with = "ts_milliseconds")]
    time: DateTime<Utc>,

    pos: f64,

    /// Exchanged as a `0`/`1` integer.
    #[serde_as(as = "BoolFromInt")]
    emergency: bool,

    signal_system: String,
    tag: String,
}
/// One entry of a call's `freqList` array.
#[serde_with::serde_as]
#[derive(Debug, Deserialize, Serialize)]
struct CallFrequency {
    freq: f64,

    /// Exchanged as a millisecond timestamp.
    #[serde(with = "ts_milliseconds")]
    time: DateTime<Utc>,

    pos: f64,
    len: f64,
    error_count: i32,
    spike_count: i32,
}
/// Service configuration, loaded with `confy::load_path` from the path given on the command line.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct AppConfig {
    /// RabbitMQ connection settings.
    rabbit_mq_config: RabbitMqConfig,
    /// HTTP listener settings.
    web_server_config: WebServerConfig,
    /// Directory that uploaded call audio files are written into.
    file_storage_path: String,
}

/// There is no sensible default configuration, so asking for one panics.
///
/// The impl exists because `confy::load_path` requires `Default`.
/// NOTE(review): presumably confy only calls this when the config file does not exist —
/// confirm against the confy version in use.
impl Default for AppConfig {
    fn default() -> AppConfig {
        panic!("Could not find config file")
    }
}
/// RabbitMQ settings.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct RabbitMqConfig {
    /// Address handed to `lapin::Connection::connect`.
    connection_string: String,
}
/// HTTP server settings.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct WebServerConfig {
    /// Address the TCP listener binds to, in any form `tokio::net::TcpListener::bind`
    /// accepts (e.g. `0.0.0.0:3000`).
    listener: String,
}