feat: token extractor middleware

This commit is contained in:
Ahmet Kaan Gümüş 2025-04-16 06:27:08 +03:00
parent 96199f71ef
commit 4f874d8789
9 changed files with 247 additions and 32 deletions

View file

@ -2,6 +2,7 @@ use std::sync::LazyLock;
use utils::naive_toml_parser;
mod middleware;
pub mod signal;
pub mod utils;

View file

@ -1,8 +1,11 @@
use rust_communication_server::signal::start_signalling;
use tracing::Level;
#[tokio::main]
async fn main() {
println!("Hello, world!");
tokio::spawn(start_signalling()).await.unwrap();
tracing_subscriber::fmt()
.with_max_level(Level::TRACE)
.init();
start_signalling().await;
}

83
server/src/middleware.rs Normal file
View file

@ -0,0 +1,83 @@
use std::{str::FromStr, sync::Arc};
use axum::{
extract::Request,
http::{self, HeaderMap, StatusCode},
middleware::Next,
response::IntoResponse,
};
use protocol::{SignalType, User};
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct UserAndExpectedSignal {
pub user: User,
pub expected_signal: SignalType,
}
async fn extract_user_from_authorization_header(headers: &HeaderMap) -> Option<User> {
if let Some(authorization_header) = headers.get(http::header::AUTHORIZATION) {
dbg!(authorization_header);
if let Ok(authorization_header) = authorization_header.to_str() {
if let Some((bearer, authorization_token)) = authorization_header.split_once(' ') {
dbg!(
"Info: Verify | Http Header | {} || {}",
bearer,
authorization_token
);
if bearer.to_lowercase() == "bearer" {
let user = User {
username: authorization_token.to_string(),
};
return Some(user);
}
}
}
}
None
}
pub async fn verify_then_get_user(mut request: Request, next: Next) -> impl IntoResponse {
let headers = request.headers();
dbg!(headers);
if let Some(user) = extract_user_from_authorization_header(headers).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return next.run(request).await;
}
StatusCode::FORBIDDEN.into_response()
}
pub async fn verify_then_get_user_and_expected_signal(
mut request: Request,
next: Next,
) -> impl IntoResponse {
let headers = request.headers();
if let Some(user) = extract_user_from_authorization_header(headers).await {
if let Ok(expected_signal) = headers.get("EXPECTED_SIGNAL").unwrap().to_str() {
match SignalType::from_str(expected_signal) {
Ok(expected_signal) => {
let user_and_expected_signal = UserAndExpectedSignal {
user,
expected_signal,
};
let user_and_expected_signal = Arc::new(user_and_expected_signal);
request.extensions_mut().insert(user_and_expected_signal);
next.run(request).await
}
Err(err_val) => {
eprintln!(
"Error: Verify and Get Expected Signal | Signal Type Conversion | {}",
err_val
);
StatusCode::BAD_REQUEST.into_response()
}
}
} else {
StatusCode::BAD_REQUEST.into_response()
}
} else {
StatusCode::FORBIDDEN.into_response()
}
}

View file

@ -1,28 +1,45 @@
use std::sync::{Arc, LazyLock, RwLock};
use axum::{
Json, Router,
Extension, Json, Router,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
};
use axum_macros::debug_handler;
use protocol::Signal;
use protocol::{Signal, User, UserAndSignal};
use tokio::net::TcpListener;
use tower_http::{cors::CorsLayer, trace::TraceLayer};
static SIGNALS: LazyLock<Arc<RwLock<Vec<Signal>>>> =
LazyLock::new(|| Arc::new(RwLock::new(vec![])));
use crate::middleware::{
UserAndExpectedSignal, verify_then_get_user, verify_then_get_user_and_expected_signal,
};
static USERS_AND_SIGNALS: LazyLock<RwLock<Vec<UserAndSignal>>> =
LazyLock::new(|| RwLock::new(vec![]));
pub async fn start_signalling() {
let route = route();
let listener = TcpListener::bind("0.0.0.0:4546").await.unwrap();
let route = route()
.layer(CorsLayer::permissive())
.layer(TraceLayer::new_for_http());
let listener = TcpListener::bind("192.168.1.3:4546").await.unwrap();
println!("http://192.168.1.3:4546");
axum::serve(listener, route).await.unwrap();
}
fn route() -> Router {
Router::new()
.route("/alive", get(alive))
.route("/", post(signal))
.route(
"/",
get(read_signal).route_layer(axum::middleware::from_fn(
verify_then_get_user_and_expected_signal,
)),
)
.route(
"/",
post(create_signal).route_layer(axum::middleware::from_fn(verify_then_get_user)),
)
}
async fn alive() -> impl IntoResponse {
@ -30,7 +47,33 @@ async fn alive() -> impl IntoResponse {
}
#[debug_handler]
async fn signal(Json(signal): Json<Signal>) -> impl IntoResponse {
SIGNALS.write().unwrap().push(signal);
async fn create_signal(
Extension(user): Extension<Arc<User>>,
Json(signal): Json<Signal>,
) -> impl IntoResponse {
let user = (*user).clone();
let user_and_signal = UserAndSignal::new(user, signal).await;
USERS_AND_SIGNALS.write().unwrap().push(user_and_signal);
StatusCode::OK
}
async fn read_signal(
Extension(user_and_expected_signal): Extension<Arc<UserAndExpectedSignal>>,
) -> impl IntoResponse {
let mut target_index = None;
let mut json_body = serde_json::json!("");
for (index, user_and_signal) in USERS_AND_SIGNALS.read().unwrap().iter().enumerate() {
if user_and_signal.signal.get_signal_type() == user_and_expected_signal.expected_signal {
json_body = serde_json::json!(user_and_signal);
target_index = Some(index);
}
}
match target_index {
Some(target_index) => {
USERS_AND_SIGNALS.write().unwrap().remove(target_index);
(StatusCode::OK, Json(json_body)).into_response()
}
None => StatusCode::BAD_REQUEST.into_response(),
}
}