From 11f101638e989fc2161c17861f741e78ce64116f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20Kaan=20G=C3=9CM=C3=9C=C5=9E?= <96421894+Tahinli@users.noreply.github.com> Date: Wed, 29 Jan 2025 17:01:56 +0300 Subject: [PATCH] feat: :sparkles: post and comment length limit fix: :ambulance: user extraction from url --- Cargo.toml | 2 +- configs/server_config.toml | 2 + migrations/20241215002127_user_contact.up.sql | 5 + src/error.rs | 2 + src/feature/user_contact.rs | 17 +++ src/lib.rs | 10 +- src/routing.rs | 3 - src/routing/admin.rs | 12 +- src/routing/admin/user.rs | 14 ++- src/routing/comment.rs | 52 +++++--- src/routing/login.rs | 18 ++- src/routing/middleware.rs | 112 ++++++++++-------- src/routing/post.rs | 40 +++++-- 13 files changed, 185 insertions(+), 104 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e9f910e..9c5442f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ strip = "symbols" tokio = { version = "1.43.0", default-features = false,features = ["macros", "rt-multi-thread", "time"] } serde = { version = "1.0.217", default-features = false, features = ["derive"] } serde_json = { version = "1.0.135" , default-features = false} -axum = { version = "0.8.1", default-features = false, features = ["http1", "json", "tokio"]} +axum = { version = "0.8.1", default-features = false, features = ["http1", "json", "original-uri", "tokio"]} chrono = { version = "0.4.39", default-features = false, features = ["serde"] } jwt-simple = { version = "0.12.11", default-features = false, features = ["pure-rust"] } lettre = { version = "0.11.11", default-features = false, features = ["builder", "serde", "smtp-transport", "tokio1-rustls-tls"] } diff --git a/configs/server_config.toml b/configs/server_config.toml index a39ad56..d5756f2 100644 --- a/configs/server_config.toml +++ b/configs/server_config.toml @@ -4,3 +4,5 @@ otp_time_limit = 15 login_token_expiration_time_limit = 15 login_token_refresh_time_limit = 30 concurrency_limit = -1 +post_length_limit = 1048576 +comment_length_limit = 1048576 diff --git a/migrations/20241215002127_user_contact.up.sql b/migrations/20241215002127_user_contact.up.sql index a4638f8..725ffd5 100644 --- a/migrations/20241215002127_user_contact.up.sql +++ b/migrations/20241215002127_user_contact.up.sql @@ -6,3 +6,8 @@ CREATE TABLE IF NOT EXISTS "user_contact"( PRIMARY KEY (user_id, contact_id), UNIQUE (contact_id, contact_value) ); + +INSERT INTO "user_contact"(user_id, contact_id, contact_value) +VALUES (0, 0, 'builder@rust_forum.com') +ON CONFLICT(user_id, contact_id) DO UPDATE SET +"contact_value" = 'builder@rust_forum.com'; diff --git a/src/error.rs b/src/error.rs index 5268800..3342e3d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,6 +5,7 @@ pub enum ForumInputError { ForbiddenCharacter, ForbiddenString, EmptyParameter, + TooLong, } impl std::fmt::Display for ForumInputError { @@ -17,6 +18,7 @@ impl std::fmt::Display for ForumInputError { write!(f, "Forbidden String Detected") } &ForumInputError::EmptyParameter => write!(f, "Parameter is Empty"), + ForumInputError::TooLong => write!(f, "Input is Too Long"), } } } diff --git a/src/feature/user_contact.rs b/src/feature/user_contact.rs index bc24eb9..c418d35 100644 --- a/src/feature/user_contact.rs +++ b/src/feature/user_contact.rs @@ -4,6 +4,8 @@ use crate::database::user_contact; use super::user::User; +const CONTACT_DEFAULT_ID_FOR_EMAIL: i64 = 0; + #[derive(Debug, Serialize, Deserialize)] pub struct UserContact { pub user_id: i64, @@ -20,10 +22,21 @@ impl UserContact { user_contact::create(&user.user_id, contact_id, contact_value).await } + pub async fn create_for_email( + user: &User, + contact_value: &String, + ) -> Result { + user_contact::create(&user.user_id, &CONTACT_DEFAULT_ID_FOR_EMAIL, contact_value).await + } + pub async fn read(user: &User, contact_id: &i64) -> Result { user_contact::read(&user.user_id, contact_id).await } + pub async fn read_for_email(user: &User) -> Result { + user_contact::read(&user.user_id, &CONTACT_DEFAULT_ID_FOR_EMAIL).await + } + pub async fn read_for_value( contact_id: &i64, contact_value: &String, @@ -31,6 +44,10 @@ impl UserContact { user_contact::read_for_value(contact_id, contact_value).await } + pub async fn read_for_email_value(contact_value: &String) -> Result { + user_contact::read_for_value(&CONTACT_DEFAULT_ID_FOR_EMAIL, contact_value).await + } + pub async fn update( user: &User, contact_id: &i64, diff --git a/src/lib.rs b/src/lib.rs index 72f208a..2dba743 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,6 @@ pub mod utils; use std::sync::LazyLock; -use sqlx::{Pool, Postgres}; use utils::naive_toml_parser; pub static SERVER_CONFIG: LazyLock = LazyLock::new(ServerConfig::default); @@ -51,6 +50,8 @@ pub struct ServerConfig { pub login_token_expiration_time_limit: usize, pub login_token_refresh_time_limit: usize, pub concurrency_limit: usize, + pub post_length_limit: usize, + pub comment_length_limit: usize, } impl Default for ServerConfig { @@ -72,14 +73,11 @@ impl Default for ServerConfig { ), login_token_refresh_time_limit: value_or_max(server_configs.pop_front().unwrap()), concurrency_limit: value_or_semaphore_max(server_configs.pop_front().unwrap()), + post_length_limit: server_configs.pop_front().unwrap().parse().unwrap(), + comment_length_limit: server_configs.pop_front().unwrap().parse().unwrap(), } } else { panic!("Server Config File Must Include [server_config] at the First Line") } } } - -#[derive(Debug, Clone)] -pub struct AppState { - pub database_connection: Pool, -} diff --git a/src/routing.rs b/src/routing.rs index bd35efd..18190e5 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -12,7 +12,6 @@ pub mod user; pub mod user_contact; use axum::{http::StatusCode, response::IntoResponse, routing::get, Router}; -use middleware::by_authorization_token; use tower::limit::ConcurrencyLimitLayer; use tower_http::{cors::CorsLayer, trace::TraceLayer}; @@ -32,8 +31,6 @@ pub async fn route(concurrency_limit: &usize) -> Router { .nest("/contacts", contact::route()) .nest("/user_contacts", user_contact::route()) .nest("/admin", admin::route()) - // todo just for beta I think - .route_layer(axum::middleware::from_fn(by_authorization_token)) .layer(CorsLayer::permissive()) .layer(ConcurrencyLimitLayer::new(*concurrency_limit)) .layer(TraceLayer::new_for_http()) diff --git a/src/routing/admin.rs b/src/routing/admin.rs index cbe411d..0bb6d65 100644 --- a/src/routing/admin.rs +++ b/src/routing/admin.rs @@ -7,12 +7,20 @@ pub mod role; pub mod user; pub mod user_contact; -use axum::Router; +use axum::{response::IntoResponse, routing::get, Router}; -use super::middleware::builder_or_admin_by_authorization_token; +use super::middleware::{builder_or_admin_by_authorization_token, by_uri_then_insert}; + +async fn a() -> impl IntoResponse { + "HEY" +} pub fn route() -> Router { Router::new() + .route( + "/users/{user_id}", + get(a).route_layer(axum::middleware::from_fn(by_uri_then_insert)), + ) .nest("/logins", login::route()) .nest("/users", user::route()) .nest("/roles", role::route()) diff --git a/src/routing/admin/user.rs b/src/routing/admin/user.rs index 91bc758..124e56e 100644 --- a/src/routing/admin/user.rs +++ b/src/routing/admin/user.rs @@ -10,7 +10,10 @@ use axum::{ use chrono::NaiveDate; use serde::{Deserialize, Serialize}; -use crate::{feature::user::User, routing::middleware::by_uri_then_insert}; +use crate::{ + feature::{user::User, user_contact::UserContact}, + routing::middleware::by_uri_then_insert, +}; #[derive(Debug, Serialize, Deserialize)] struct CreateUser { @@ -18,6 +21,7 @@ struct CreateUser { surname: String, gender: bool, birth_date: NaiveDate, + email: String, } #[derive(Debug, Serialize, Deserialize)] @@ -78,7 +82,13 @@ async fn create(Json(create_user): Json) -> impl IntoResponse { ) .await { - Ok(user) => (StatusCode::CREATED, Json(serde_json::json!(user))), + Ok(user) => match UserContact::create_for_email(&user, &create_user.email).await { + Ok(_) => (StatusCode::CREATED, Json(serde_json::json!(user))), + Err(err_val) => ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!(err_val.to_string())), + ), + }, Err(err_val) => ( StatusCode::BAD_REQUEST, Json(serde_json::json!(err_val.to_string())), diff --git a/src/routing/comment.rs b/src/routing/comment.rs index 48c2860..e9dd458 100644 --- a/src/routing/comment.rs +++ b/src/routing/comment.rs @@ -9,7 +9,11 @@ use axum::{ }; use serde::{Deserialize, Serialize}; -use crate::feature::{comment::Comment, user::User}; +use crate::{ + error::ForumInputError, + feature::{comment::Comment, user::User}, + SERVER_CONFIG, +}; use super::middleware::by_authorization_token_then_insert; @@ -52,18 +56,25 @@ async fn create( Extension(user): Extension>, Json(create_comment): Json, ) -> impl IntoResponse { - match Comment::create( - &user.user_id, - &create_comment.post_id, - &create_comment.comment, - ) - .await - { - Ok(comment) => (StatusCode::CREATED, Json(serde_json::json!(comment))), - Err(err_val) => ( + if create_comment.comment.len() > SERVER_CONFIG.comment_length_limit { + return ( StatusCode::BAD_REQUEST, - Json(serde_json::json!(err_val.to_string())), - ), + Json(serde_json::json!(ForumInputError::TooLong)), + ); + } else { + match Comment::create( + &user.user_id, + &create_comment.post_id, + &create_comment.comment, + ) + .await + { + Ok(comment) => (StatusCode::CREATED, Json(serde_json::json!(comment))), + Err(err_val) => ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!(err_val.to_string())), + ), + } } } @@ -82,12 +93,19 @@ async fn update( Path(comment_id): Path, Json(update_comment): Json, ) -> impl IntoResponse { - match Comment::update(&comment_id, &user.user_id, &update_comment.comment).await { - Ok(comment) => (StatusCode::ACCEPTED, Json(serde_json::json!(comment))), - Err(err_val) => ( + if update_comment.comment.len() > SERVER_CONFIG.comment_length_limit { + return ( StatusCode::BAD_REQUEST, - Json(serde_json::json!(err_val.to_string())), - ), + Json(serde_json::json!(ForumInputError::TooLong)), + ); + } else { + match Comment::update(&comment_id, &user.user_id, &update_comment.comment).await { + Ok(comment) => (StatusCode::ACCEPTED, Json(serde_json::json!(comment))), + Err(err_val) => ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!(err_val.to_string())), + ), + } } } diff --git a/src/routing/login.rs b/src/routing/login.rs index aff266d..aa68aee 100644 --- a/src/routing/login.rs +++ b/src/routing/login.rs @@ -14,8 +14,6 @@ use super::middleware::{ by_authorization_token_then_insert, user_and_token_then_insert, UserAndAuthorizationToken, }; -const CONTACT_EMAIL_DEFAULT_ID: i64 = 0; - #[derive(Debug, Serialize, Deserialize)] struct CreateOneTimePassword { pub user_email: String, @@ -49,17 +47,17 @@ pub fn route() -> Router { by_authorization_token_then_insert, )), ) - .route("/count/users", get(count_all_for_user)) + .route( + "/count/users", + get(count_all_for_user).route_layer(axum::middleware::from_fn( + by_authorization_token_then_insert, + )), + ) } async fn create_one_time_password( Json(create_one_time_password): Json, ) -> impl IntoResponse { - match UserContact::read_for_value( - &CONTACT_EMAIL_DEFAULT_ID, - &create_one_time_password.user_email, - ) - .await - { + match UserContact::read_for_email_value(&create_one_time_password.user_email).await { Ok(user_contact) => match User::read(&user_contact.user_id).await { Ok(user) => { match OneTimePassword::new(&user, &create_one_time_password.user_email).await { @@ -83,7 +81,7 @@ async fn create_one_time_password( } } async fn create(Json(create_login): Json) -> impl IntoResponse { - match UserContact::read_for_value(&CONTACT_EMAIL_DEFAULT_ID, &create_login.user_email).await { + match UserContact::read_for_email_value(&create_login.user_email).await { Ok(user_contact) => match User::read(&user_contact.user_id).await { Ok(user) => { let one_time_password = diff --git a/src/routing/middleware.rs b/src/routing/middleware.rs index a4adb16..5a4eead 100644 --- a/src/routing/middleware.rs +++ b/src/routing/middleware.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::{ - extract::Request, + extract::{OriginalUri, Request}, http::{self, HeaderMap, StatusCode, Uri}, middleware::Next, response::IntoResponse, @@ -64,11 +64,11 @@ async fn user_extraction_from_uri(request_uri: &Uri) -> Result>(); for (index, uri_part) in request_uri_parts.iter().enumerate() { if *uri_part == "users" { - if let Some(user_id) = request_uri_parts.get(index) { - if let Ok(user_id) = (*user_id).parse::() { - User::read(&user_id).await.map_err(|err_val| { + if let Some(user_id) = request_uri_parts.get(index + 1) { + if let Ok(user_id) = user_id.parse::() { + return User::read(&user_id).await.map_err(|err_val| { ForumAuthError::AuthenticationFailed(err_val.to_string()) - })?; + }); } } } @@ -86,10 +86,7 @@ async fn user_from_header_and_target_user_from_uri_extraction( Ok(UserAndTargetUser { user, target_user }) } -pub async fn user_and_token_then_insert( - mut request: Request, - next: Next, -) -> Result { +pub async fn user_and_token_then_insert(mut request: Request, next: Next) -> impl IntoResponse { if let Ok(authorization_token) = authorization_token_extraction(&request.headers()).await { if let Ok(user) = user_extraction_from_authorization_token(&authorization_token).await { let user_and_token = Arc::new(UserAndAuthorizationToken { @@ -98,110 +95,115 @@ pub async fn user_and_token_then_insert( }); request.extensions_mut().insert(user_and_token); - return Ok(next.run(request).await); + return next.run(request).await; } } - Err(StatusCode::FORBIDDEN) + StatusCode::FORBIDDEN.into_response() } -pub async fn by_authorization_token( - request: Request, - next: Next, -) -> Result { +pub async fn by_authorization_token(request: Request, next: Next) -> impl IntoResponse { match user_extraction_from_header(request.headers()).await { - Ok(_) => Ok(next.run(request).await), - Err(_) => Err(StatusCode::FORBIDDEN), + Ok(_) => next.run(request).await, + Err(_) => StatusCode::FORBIDDEN.into_response(), } } pub async fn by_authorization_token_then_insert( mut request: Request, next: Next, -) -> Result { +) -> impl IntoResponse { match user_extraction_from_header(request.headers()).await { Ok(user) => { let user = Arc::new(user); request.extensions_mut().insert(user); - Ok(next.run(request).await) + next.run(request).await } - Err(_) => Err(StatusCode::FORBIDDEN), + Err(_) => StatusCode::FORBIDDEN.into_response(), } } -pub async fn by_uri_then_insert( - mut request: Request, - next: Next, -) -> Result { - if let Ok(target_user) = user_extraction_from_uri(request.uri()).await { +pub async fn by_uri_then_insert(mut request: Request, next: Next) -> impl IntoResponse { + if let Ok(target_user) = user_extraction_from_uri( + request + .extensions() + .get::() + .expect("Shouldn't panic, how we couldn't have uri"), + ) + .await + { let target_user = Arc::new(target_user); request.extensions_mut().insert(target_user); - return Ok(next.run(request).await); + return next.run(request).await; } - Err(StatusCode::BAD_REQUEST) + StatusCode::BAD_REQUEST.into_response() } -pub async fn builder_by_authorization_token( - request: Request, - next: Next, -) -> Result { +pub async fn builder_by_authorization_token(request: Request, next: Next) -> impl IntoResponse { if let Ok(user) = user_extraction_from_header(request.headers()).await { if User::is_builder(&user).await { - return Ok(next.run(request).await); + return next.run(request).await; } } - Err(StatusCode::FORBIDDEN) + StatusCode::FORBIDDEN.into_response() } pub async fn builder_by_authorization_token_then_insert( mut request: Request, next: Next, -) -> Result { +) -> impl IntoResponse { if let Ok(user) = user_extraction_from_header(request.headers()).await { if User::is_builder(&user).await { let user = Arc::new(user); request.extensions_mut().insert(user); - return Ok(next.run(request).await); + return next.run(request).await; } } - Err(StatusCode::FORBIDDEN) + StatusCode::FORBIDDEN.into_response() } pub async fn builder_or_admin_by_authorization_token( request: Request, next: Next, -) -> Result { +) -> impl IntoResponse { if let Ok(user) = user_extraction_from_header(request.headers()).await { if User::is_builder_or_admin(&user).await { - return Ok(next.run(request).await); + return next.run(request).await; } } - Err(StatusCode::FORBIDDEN) + + StatusCode::FORBIDDEN.into_response() } pub async fn builder_or_admin_by_authorization_token_then_insert( mut request: Request, next: Next, -) -> Result { +) -> impl IntoResponse { if let Ok(user) = user_extraction_from_header(request.headers()).await { if User::is_builder_or_admin(&user).await { let user = Arc::new(user); request.extensions_mut().insert(user); - return Ok(next.run(request).await); + return next.run(request).await; } } - Err(StatusCode::FORBIDDEN) + StatusCode::FORBIDDEN.into_response() } pub async fn builder_by_authorization_token_and_target_user_by_uri_then_insert_target( mut request: Request, next: Next, -) -> Result { - if let Ok(user_and_target_user) = - user_from_header_and_target_user_from_uri_extraction(request.headers(), request.uri()).await +) -> impl IntoResponse { + if let Ok(user_and_target_user) = user_from_header_and_target_user_from_uri_extraction( + request.headers(), + request + .extensions() + .get::() + .expect("Shouldn't panic, how we couldn't have uri"), + ) + .await { let user = user_and_target_user.user; let target_user = user_and_target_user.target_user; @@ -210,18 +212,24 @@ pub async fn builder_by_authorization_token_and_target_user_by_uri_then_insert_t let target_user = Arc::new(target_user); request.extensions_mut().insert(target_user); - return Ok(next.run(request).await); + return next.run(request).await; } } - Err(StatusCode::FORBIDDEN) + StatusCode::FORBIDDEN.into_response() } pub async fn builder_or_admin_by_authorization_token_and_target_user_by_uri_then_insert_target( mut request: Request, next: Next, -) -> Result { - if let Ok(user_and_target_user) = - user_from_header_and_target_user_from_uri_extraction(request.headers(), request.uri()).await +) -> impl IntoResponse { + if let Ok(user_and_target_user) = user_from_header_and_target_user_from_uri_extraction( + request.headers(), + request + .extensions() + .get::() + .expect("Shouldn't panic, how we couldn't have uri"), + ) + .await { let user = user_and_target_user.user; let target_user = user_and_target_user.target_user; @@ -230,8 +238,8 @@ pub async fn builder_or_admin_by_authorization_token_and_target_user_by_uri_then let target_user = Arc::new(target_user); request.extensions_mut().insert(target_user); - return Ok(next.run(request).await); + return next.run(request).await; } } - Err(StatusCode::FORBIDDEN) + StatusCode::FORBIDDEN.into_response() } diff --git a/src/routing/post.rs b/src/routing/post.rs index dba8345..8675eef 100644 --- a/src/routing/post.rs +++ b/src/routing/post.rs @@ -9,7 +9,11 @@ use axum::{ }; use serde::{Deserialize, Serialize}; -use crate::feature::{post::Post, user::User}; +use crate::{ + error::ForumInputError, + feature::{post::Post, user::User}, + SERVER_CONFIG, +}; use super::middleware::{by_authorization_token_then_insert, by_uri_then_insert}; @@ -53,12 +57,19 @@ async fn create( Extension(user): Extension>, Json(create_post): Json, ) -> impl IntoResponse { - match Post::create(&user.user_id, &create_post.post).await { - Ok(post) => (StatusCode::CREATED, Json(serde_json::json!(post))), - Err(err_val) => ( + if create_post.post.len() > SERVER_CONFIG.post_length_limit { + return ( StatusCode::BAD_REQUEST, - Json(serde_json::json!(err_val.to_string())), - ), + Json(serde_json::json!(ForumInputError::TooLong)), + ); + } else { + match Post::create(&user.user_id, &create_post.post).await { + Ok(post) => (StatusCode::CREATED, Json(serde_json::json!(post))), + Err(err_val) => ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!(err_val.to_string())), + ), + } } } @@ -77,12 +88,19 @@ async fn update( Path(post_id): Path, Json(update_post): Json, ) -> impl IntoResponse { - match Post::update(&post_id, &user.user_id, &update_post.post).await { - Ok(post) => (StatusCode::ACCEPTED, Json(serde_json::json!(post))), - Err(err_val) => ( + if update_post.post.len() > SERVER_CONFIG.post_length_limit { + return ( StatusCode::BAD_REQUEST, - Json(serde_json::json!(err_val.to_string())), - ), + Json(serde_json::json!(ForumInputError::TooLong)), + ); + } else { + match Post::update(&post_id, &user.user_id, &update_post.post).await { + Ok(post) => (StatusCode::ACCEPTED, Json(serde_json::json!(post))), + Err(err_val) => ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!(err_val.to_string())), + ), + } } }