Compare commits
5 commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2b326c765d | |||
| bf1ee8efe8 | |||
| de9d4a4cab | |||
| 8a9ce50409 | |||
| 43b9edf213 |
4 changed files with 3256 additions and 0 deletions
2999
Cargo.lock
generated
Normal file
2999
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
18
Cargo.toml
Normal file
18
Cargo.toml
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
[package]
|
||||||
|
name = "whisper-api"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
axum = { version = "0.7.9", features = ["multipart"]}
|
||||||
|
tokio = { version = "1.42.0", features = ["rt-multi-thread"] }
|
||||||
|
chrono = {version = "0.4.39", features = ["serde"] }
|
||||||
|
serde = { version = "1.0.216", features = ["derive"] }
|
||||||
|
serde_json = "1.0.133"
|
||||||
|
serde_with = "3.11.0"
|
||||||
|
lapin = "2.5.0"
|
||||||
|
log = "0.4.22"
|
||||||
|
log4rs = "1.3.0"
|
||||||
|
confy = "0.6.1"
|
||||||
|
tokio-executor-trait = "2.1.3"
|
||||||
|
tokio-reactor-trait = "1.1.0"
|
||||||
33
log4rs.yaml
Normal file
33
log4rs.yaml
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Scan this file for changes every 30 seconds
|
||||||
|
refresh_rate: 30 seconds
|
||||||
|
|
||||||
|
appenders:
|
||||||
|
# An appender named "stdout" that writes to stdout
|
||||||
|
stdout:
|
||||||
|
kind: console
|
||||||
|
|
||||||
|
# An appender named "requests" that writes to a file with a custom pattern encoder
|
||||||
|
requests:
|
||||||
|
kind: file
|
||||||
|
path: "log/requests.log"
|
||||||
|
encoder:
|
||||||
|
pattern: "{d} - {m}{n}"
|
||||||
|
|
||||||
|
# Set the default logging level to "warn" and attach the "stdout" appender to the root
|
||||||
|
root:
|
||||||
|
level: debug
|
||||||
|
appenders:
|
||||||
|
- stdout
|
||||||
|
|
||||||
|
loggers:
|
||||||
|
# Raise the maximum log level for events sent to the "app::backend::db" logger to "info"
|
||||||
|
app::backend::db:
|
||||||
|
level: info
|
||||||
|
|
||||||
|
# Route log events sent to the "app::requests" logger to the "requests" appender,
|
||||||
|
# and *not* the normal appenders installed at the root
|
||||||
|
app::requests:
|
||||||
|
level: info
|
||||||
|
appenders:
|
||||||
|
- requests
|
||||||
|
additive: false
|
||||||
206
src/main.rs
Normal file
206
src/main.rs
Normal file
|
|
@ -0,0 +1,206 @@
|
||||||
|
use crate::chrono::serde::ts_milliseconds;
|
||||||
|
use axum::extract::Multipart;
|
||||||
|
use axum::handler::Handler;
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use axum::routing::post;
|
||||||
|
use axum::{Extension, Router};
|
||||||
|
use chrono;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use lapin::options::BasicPublishOptions;
|
||||||
|
use lapin::{BasicProperties, Channel, Connection, ConnectionProperties};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_with::BoolFromInt;
|
||||||
|
use std::env;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::Write;
|
||||||
|
use std::option::Option;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use log::error;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
let ascii = r#"
|
||||||
|
____ __ ____ __ __ __ _______..______ _______ .______ ___ .______ __
|
||||||
|
\ \ / \ / / | | | | | | / || _ \ | ____|| _ \ / \ | _ \ | |
|
||||||
|
\ \/ \/ / | |__| | | | | (----`| |_) | | |__ | |_) | / ^ \ | |_) | | |
|
||||||
|
\ / | __ | | | \ \ | ___/ | __| | / / /_\ \ | ___/ | |
|
||||||
|
\ /\ / | | | | | | .----) | | | | |____ | |\ \----. / _____ \ | | | |
|
||||||
|
\__/ \__/ |__| |__| |__| |_______/ | _| |_______|| _| `._____| /__/ \__\ | _| |__|
|
||||||
|
|
||||||
|
"#;
|
||||||
|
println!("{ascii}");
|
||||||
|
let args: Vec<String> = env::args().collect();
|
||||||
|
let cfg: AppConfig = confy::load_path(Path::new(&args[1])).expect("Couldn't read config");
|
||||||
|
match Path::new("log4rs.yaml").exists() {
|
||||||
|
true => log4rs::init_file("log4rs.yaml", Default::default()).unwrap(),
|
||||||
|
false => println!("No log4rs.yaml file found. Logging will not be enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
let options = ConnectionProperties::default()
|
||||||
|
.with_executor(tokio_executor_trait::Tokio::current())
|
||||||
|
.with_reactor(tokio_reactor_trait::Tokio);
|
||||||
|
let connection = Connection::connect(&cfg.rabbit_mq_config.connection_string, options).await.unwrap();
|
||||||
|
let channel = connection.create_channel().await.unwrap();
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/calls", post(upload_call))
|
||||||
|
.layer(Extension(cfg.clone()))
|
||||||
|
.layer(Extension(channel));
|
||||||
|
|
||||||
|
let listener = tokio::net::TcpListener::bind(&cfg.web_server_config.listener).await.unwrap();
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn upload_call(Extension(cfg): Extension<AppConfig>, Extension(channel): Extension<Channel>, mut multipart: Multipart) -> axum::response::Response {
|
||||||
|
let mut call_metadata: Option<Call> = None;
|
||||||
|
let mut call_file: Option<PathBuf> = None;
|
||||||
|
|
||||||
|
while let Some(field) = multipart.next_field().await.unwrap() {
|
||||||
|
if field.name().unwrap() == "call_json" {
|
||||||
|
let call: Call = match serde_json::from_str(&field.text().await.unwrap()) {
|
||||||
|
Ok(call) => call,
|
||||||
|
Err(error) => {
|
||||||
|
error!("{}", error);
|
||||||
|
return (StatusCode::BAD_REQUEST, "Failed to parse json").into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
call_metadata = Some(call);
|
||||||
|
} else if field.name().unwrap() == "call_audio" {
|
||||||
|
let file_name = field.file_name().unwrap();
|
||||||
|
let file_path = Path::new(&cfg.file_storage_path).join(&file_name);
|
||||||
|
let data = field.bytes().await.unwrap();
|
||||||
|
let mut file = match File::create(&file_path) {
|
||||||
|
Ok(file) => file,
|
||||||
|
Err(error) => {
|
||||||
|
error!("Could not create file {}: {}", &file_path.display(), error);
|
||||||
|
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match file.write_all(&data) {
|
||||||
|
Ok(_) => {
|
||||||
|
call_file = Some(file_path)
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
error!("Could not write to file {}: {}", &file_path.display(), error);
|
||||||
|
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let call_metadata: Call = match call_metadata {
|
||||||
|
Some(call_metadata) => call_metadata,
|
||||||
|
None => return (StatusCode::UNPROCESSABLE_ENTITY, "call_json is required").into_response()
|
||||||
|
};
|
||||||
|
let call_file: PathBuf = match call_file {
|
||||||
|
Some(call_file) => call_file,
|
||||||
|
None => return (StatusCode::UNPROCESSABLE_ENTITY, "call_audio is required").into_response()
|
||||||
|
};
|
||||||
|
if call_metadata.call_length < 2 {
|
||||||
|
return (StatusCode::UNPROCESSABLE_ENTITY, "Call too short").into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
let transcription_request = TranscriptionRequest {
|
||||||
|
audio_file_path: call_file.to_str().unwrap().parse().unwrap(),
|
||||||
|
call_metadata
|
||||||
|
};
|
||||||
|
channel.basic_publish("", "transcribe", BasicPublishOptions::default(), &*serde_json::to_vec(&transcription_request).unwrap(), BasicProperties::default()).await.expect("idk it broke");
|
||||||
|
|
||||||
|
StatusCode::CREATED.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Default, Debug)]
|
||||||
|
struct TranscriptionRequest {
|
||||||
|
audio_file_path: String,
|
||||||
|
call_metadata: Call
|
||||||
|
}
|
||||||
|
|
||||||
|
#[serde_with::serde_as]
|
||||||
|
#[derive(Serialize, Deserialize, Default, Debug)]
|
||||||
|
struct Call {
|
||||||
|
freq: u32,
|
||||||
|
freq_error: i32,
|
||||||
|
signal: i32,
|
||||||
|
noise: i32,
|
||||||
|
source_num: u32,
|
||||||
|
recorder_num: u32,
|
||||||
|
tdma_slot: u32,
|
||||||
|
#[serde_as(as = "BoolFromInt")]
|
||||||
|
phase2_tdma: bool,
|
||||||
|
#[serde(with = "ts_milliseconds")]
|
||||||
|
start_time: DateTime<Utc>,
|
||||||
|
#[serde(with = "ts_milliseconds")]
|
||||||
|
stop_time: DateTime<Utc>,
|
||||||
|
#[serde_as(as = "BoolFromInt")]
|
||||||
|
emergency: bool,
|
||||||
|
priority: i32,
|
||||||
|
#[serde_as(as = "BoolFromInt")]
|
||||||
|
mode: bool,
|
||||||
|
#[serde_as(as = "BoolFromInt")]
|
||||||
|
duplex: bool,
|
||||||
|
#[serde_as(as = "BoolFromInt")]
|
||||||
|
encrypted: bool,
|
||||||
|
call_length: i32,
|
||||||
|
talkgroup: u64,
|
||||||
|
talkgroup_tag: String,
|
||||||
|
talkgroup_description: String,
|
||||||
|
talkgroup_group_tag: String,
|
||||||
|
talkgroup_group: String,
|
||||||
|
audio_type: String,
|
||||||
|
short_name: String,
|
||||||
|
#[serde(rename = "freqList")]
|
||||||
|
freq_list: Vec<CallFrequency>,
|
||||||
|
#[serde(rename = "srcList")]
|
||||||
|
src_list: Vec<CallSource>,
|
||||||
|
patched_talkgroups: Option<Vec<u32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
#[serde_with::serde_as]
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
struct CallSource {
|
||||||
|
src: i64,
|
||||||
|
#[serde(with = "ts_milliseconds")]
|
||||||
|
time: DateTime<Utc>,
|
||||||
|
pos: f64,
|
||||||
|
#[serde_as(as = "BoolFromInt")]
|
||||||
|
emergency: bool,
|
||||||
|
signal_system: String,
|
||||||
|
tag: String
|
||||||
|
}
|
||||||
|
|
||||||
|
#[serde_with::serde_as]
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
struct CallFrequency {
|
||||||
|
freq: f64,
|
||||||
|
#[serde(with = "ts_milliseconds")]
|
||||||
|
time: DateTime<Utc>,
|
||||||
|
pos: f64,
|
||||||
|
len: f64,
|
||||||
|
error_count: i32,
|
||||||
|
spike_count: i32
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
struct AppConfig {
|
||||||
|
rabbit_mq_config: RabbitMqConfig,
|
||||||
|
web_server_config: WebServerConfig,
|
||||||
|
file_storage_path: String
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AppConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
panic!("Could not find config file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
struct RabbitMqConfig {
|
||||||
|
connection_string: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
struct WebServerConfig {
|
||||||
|
listener: String
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue