refactor: ♻️ new permission strategy part 2

This commit is contained in:
Ahmet Kaan GÜMÜŞ 2025-01-19 23:47:09 +03:00
parent 56aa04e32a
commit bcfcd2c6f0
15 changed files with 393 additions and 73 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@
debug/ debug/
target/ target/
Cargo.lock Cargo.lock
cspell.json

View file

@ -1,10 +1,10 @@
-- Add up migration script here -- Add up migration script here
CREATE TABLE IF NOT EXISTS "user"( CREATE TABLE IF NOT EXISTS "user"(
id BIGSERIAL PRIMARY KEY NOT NULL UNIQUE, user_id BIGSERIAL PRIMARY KEY NOT NULL UNIQUE,
name VARCHAR(256) NOT NULL, name VARCHAR(256) NOT NULL,
surname VARCHAR(256) NOT NULL, surname VARCHAR(256) NOT NULL,
gender boolean NOT NULL, gender boolean NOT NULL,
birth_date DATE NOT NULL, birth_date DATE NOT NULL,
role_id BIGSERIAL NOT NULL REFERENCES "role"(id), role_id BIGSERIAL NOT NULL REFERENCES "role"(id),
creation_time TIMESTAMPTZ NOT NULL DEFAULT NOW() creation_time TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );

View file

@ -1,6 +1,6 @@
-- Add up migration script here -- Add up migration script here
CREATE TABLE IF NOT EXISTS "post"( CREATE TABLE IF NOT EXISTS "post"(
creation_time TIMESTAMPTZ PRIMARY KEY UNIQUE NOT NULL DEFAULT NOW(), creation_time TIMESTAMPTZ PRIMARY KEY UNIQUE NOT NULL DEFAULT NOW(),
user_id BIGSERIAL NOT NULL REFERENCES "user"(id), user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
post VARCHAR(8192) NOT NULL UNIQUE post VARCHAR(8192) NOT NULL UNIQUE
); );

View file

@ -2,6 +2,6 @@
CREATE TABLE IF NOT EXISTS "comment"( CREATE TABLE IF NOT EXISTS "comment"(
creation_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(), creation_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(),
post_creation_time TIMESTAMPTZ NOT NULL REFERENCES "post"(creation_time), post_creation_time TIMESTAMPTZ NOT NULL REFERENCES "post"(creation_time),
user_id BIGSERIAL NOT NULL REFERENCES "user"(id), user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
comment VARCHAR(8192) NOT NULL comment VARCHAR(8192) NOT NULL
); );

View file

@ -2,6 +2,6 @@
CREATE TABLE IF NOT EXISTS "post_interaction"( CREATE TABLE IF NOT EXISTS "post_interaction"(
interaction_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(), interaction_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(),
post_creation_time TIMESTAMPTZ NOT NULL REFERENCES "post"(creation_time), post_creation_time TIMESTAMPTZ NOT NULL REFERENCES "post"(creation_time),
user_id BIGSERIAL NOT NULL REFERENCES "user"(id), user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
interaction_id BIGSERIAL NOT NULL REFERENCES "interaction"(id) interaction_id BIGSERIAL NOT NULL REFERENCES "interaction"(id)
); );

View file

@ -2,6 +2,6 @@
CREATE TABLE IF NOT EXISTS "comment_interaction"( CREATE TABLE IF NOT EXISTS "comment_interaction"(
interaction_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(), interaction_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(),
comment_creation_time TIMESTAMPTZ NOT NULL REFERENCES "comment"(creation_time), comment_creation_time TIMESTAMPTZ NOT NULL REFERENCES "comment"(creation_time),
user_id BIGSERIAL NOT NULL REFERENCES "user"(id), user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
interaction_id BIGSERIAL NOT NULL REFERENCES "interaction"(id) interaction_id BIGSERIAL NOT NULL REFERENCES "interaction"(id)
); );

View file

@ -1,6 +1,6 @@
-- Add up migration script here -- Add up migration script here
CREATE TABLE IF NOT EXISTS "user_contact"( CREATE TABLE IF NOT EXISTS "user_contact"(
user_id BIGSERIAL NOT NULL REFERENCES "user"(id), user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
contact_id BIGSERIAL NOT NULL REFERENCES "contact"(id), contact_id BIGSERIAL NOT NULL REFERENCES "contact"(id),
PRIMARY KEY (user_id, contact_id) PRIMARY KEY (user_id, contact_id)
); );

View file

@ -1,6 +1,6 @@
-- Add up migration script here -- Add up migration script here
CREATE TABLE IF NOT EXISTS "login" ( CREATE TABLE IF NOT EXISTS "login" (
user_id BIGSERIAL NOT NULL REFERENCES "user" (id), user_id BIGSERIAL NOT NULL REFERENCES "user" (user_id),
token VARCHAR(1024) NOT NULL, token VARCHAR(1024) NOT NULL,
token_creation_time TIMESTAMPTZ NOT NULL DEFAULT NOW (), token_creation_time TIMESTAMPTZ NOT NULL DEFAULT NOW (),
PRIMARY KEY (user_id, token) PRIMARY KEY (user_id, token)

View file

@ -13,8 +13,8 @@ pub async fn create(
sqlx::query_as!( sqlx::query_as!(
User, User,
r#" r#"
INSERT INTO "user"(name, surname, gender, birth_date, role_id) INSERT INTO "user"(name, surname, gender, birth_date, role_id)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
RETURNING * RETURNING *
"#, "#,
name, name,
@ -31,7 +31,7 @@ pub async fn read(id: &i64, database_connection: &Pool<Postgres>) -> Result<User
sqlx::query_as!( sqlx::query_as!(
User, User,
r#" r#"
SELECT * FROM "user" WHERE "id" = $1 SELECT * FROM "user" WHERE "user_id" = $1
"#, "#,
id id
) )
@ -50,7 +50,7 @@ pub async fn update(
) -> Result<User, sqlx::Error> { ) -> Result<User, sqlx::Error> {
sqlx::query_as!(User, sqlx::query_as!(User,
r#" r#"
UPDATE "user" SET "name" = $2, "surname" = $3, "gender" = $4, "birth_date" = $5, "role_id" = $6 WHERE "id" = $1 UPDATE "user" SET "name" = $2, "surname" = $3, "gender" = $4, "birth_date" = $5, "role_id" = $6 WHERE "user_id" = $1
RETURNING * RETURNING *
"#, id, name, surname, gender, birth_date, role_id).fetch_one(database_connection).await "#, id, name, surname, gender, birth_date, role_id).fetch_one(database_connection).await
} }
@ -59,7 +59,7 @@ pub async fn delete(id: &i64, database_connection: &Pool<Postgres>) -> Result<Us
sqlx::query_as!( sqlx::query_as!(
User, User,
r#" r#"
DELETE FROM "user" WHERE "id" = $1 DELETE FROM "user" WHERE "user_id" = $1
RETURNING * RETURNING *
"#, "#,
id id
@ -157,13 +157,13 @@ pub async fn read_all_for_gender(
pub async fn read_all_id(database_connection: &Pool<Postgres>) -> Result<Vec<i64>, sqlx::Error> { pub async fn read_all_id(database_connection: &Pool<Postgres>) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!( Ok(sqlx::query!(
r#" r#"
SELECT "id" FROM "user" SELECT "user_id" FROM "user"
"#, "#,
) )
.fetch_all(database_connection) .fetch_all(database_connection)
.await? .await?
.iter() .iter()
.map(|record| record.id) .map(|record| record.user_id)
.collect::<Vec<i64>>()) .collect::<Vec<i64>>())
} }
@ -173,14 +173,14 @@ pub async fn read_all_id_for_name(
) -> Result<Vec<i64>, sqlx::Error> { ) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!( Ok(sqlx::query!(
r#" r#"
SELECT "id" FROM "user" WHERE "name" = $1 SELECT "user_id" FROM "user" WHERE "name" = $1
"#, "#,
name name
) )
.fetch_all(database_connection) .fetch_all(database_connection)
.await? .await?
.iter() .iter()
.map(|record| record.id) .map(|record| record.user_id)
.collect::<Vec<i64>>()) .collect::<Vec<i64>>())
} }
@ -190,14 +190,14 @@ pub async fn read_all_id_for_surname(
) -> Result<Vec<i64>, sqlx::Error> { ) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!( Ok(sqlx::query!(
r#" r#"
SELECT "id" FROM "user" WHERE "surname" = $1 SELECT "user_id" FROM "user" WHERE "surname" = $1
"#, "#,
surname surname
) )
.fetch_all(database_connection) .fetch_all(database_connection)
.await? .await?
.iter() .iter()
.map(|record| record.id) .map(|record| record.user_id)
.collect::<Vec<i64>>()) .collect::<Vec<i64>>())
} }
@ -207,14 +207,14 @@ pub async fn read_all_id_for_birth_date(
) -> Result<Vec<i64>, sqlx::Error> { ) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!( Ok(sqlx::query!(
r#" r#"
SELECT "id" FROM "user" WHERE "birth_date" = $1 SELECT "user_id" FROM "user" WHERE "birth_date" = $1
"#, "#,
birth_date birth_date
) )
.fetch_all(database_connection) .fetch_all(database_connection)
.await? .await?
.iter() .iter()
.map(|record| record.id) .map(|record| record.user_id)
.collect::<Vec<i64>>()) .collect::<Vec<i64>>())
} }
@ -224,14 +224,14 @@ pub async fn read_all_id_for_role(
) -> Result<Vec<i64>, sqlx::Error> { ) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!( Ok(sqlx::query!(
r#" r#"
SELECT "id" FROM "user" WHERE "role_id" = $1 SELECT "user_id" FROM "user" WHERE "role_id" = $1
"#, "#,
role_id role_id
) )
.fetch_all(database_connection) .fetch_all(database_connection)
.await? .await?
.iter() .iter()
.map(|record| record.id) .map(|record| record.user_id)
.collect::<Vec<i64>>()) .collect::<Vec<i64>>())
} }
@ -241,21 +241,21 @@ pub async fn read_all_id_for_gender(
) -> Result<Vec<i64>, sqlx::Error> { ) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!( Ok(sqlx::query!(
r#" r#"
SELECT "id" FROM "user" WHERE "gender" = $1 SELECT "user_id" FROM "user" WHERE "gender" = $1
"#, "#,
gender gender
) )
.fetch_all(database_connection) .fetch_all(database_connection)
.await? .await?
.iter() .iter()
.map(|record| record.id) .map(|record| record.user_id)
.collect::<Vec<i64>>()) .collect::<Vec<i64>>())
} }
pub async fn count_all(database_connection: &Pool<Postgres>) -> Result<u64, sqlx::Error> { pub async fn count_all(database_connection: &Pool<Postgres>) -> Result<u64, sqlx::Error> {
sqlx::query!( sqlx::query!(
r#" r#"
SELECT COUNT(id) FROM "user" SELECT COUNT(user_id) FROM "user"
"#, "#,
) )
.fetch_one(database_connection) .fetch_one(database_connection)
@ -272,7 +272,7 @@ pub async fn count_all_for_name(
) -> Result<u64, sqlx::Error> { ) -> Result<u64, sqlx::Error> {
sqlx::query!( sqlx::query!(
r#" r#"
SELECT COUNT(id) FROM "user" WHERE "name" = $1 SELECT COUNT(user_id) FROM "user" WHERE "name" = $1
"#, "#,
name name
) )
@ -290,7 +290,7 @@ pub async fn count_all_for_surname(
) -> Result<u64, sqlx::Error> { ) -> Result<u64, sqlx::Error> {
sqlx::query!( sqlx::query!(
r#" r#"
SELECT COUNT(id) FROM "user" WHERE "surname" = $1 SELECT COUNT(user_id) FROM "user" WHERE "surname" = $1
"#, "#,
surname surname
) )
@ -308,7 +308,7 @@ pub async fn count_all_for_birth_date(
) -> Result<u64, sqlx::Error> { ) -> Result<u64, sqlx::Error> {
sqlx::query!( sqlx::query!(
r#" r#"
SELECT COUNT(id) FROM "user" WHERE "birth_date" = $1 SELECT COUNT(user_id) FROM "user" WHERE "birth_date" = $1
"#, "#,
birth_date birth_date
) )
@ -326,7 +326,7 @@ pub async fn count_all_for_role(
) -> Result<u64, sqlx::Error> { ) -> Result<u64, sqlx::Error> {
sqlx::query!( sqlx::query!(
r#" r#"
SELECT COUNT(id) FROM "user" WHERE "role_id" = $1 SELECT COUNT(user_id) FROM "user" WHERE "role_id" = $1
"#, "#,
role_id role_id
) )
@ -344,7 +344,7 @@ pub async fn count_all_for_gender(
) -> Result<u64, sqlx::Error> { ) -> Result<u64, sqlx::Error> {
sqlx::query!( sqlx::query!(
r#" r#"
SELECT COUNT(id) FROM "user" WHERE "gender" = $1 SELECT COUNT(user_id) FROM "user" WHERE "gender" = $1
"#, "#,
gender gender
) )

View file

@ -51,3 +51,22 @@ impl std::error::Error for ForumMailError {
self.source() self.source()
} }
} }
#[derive(Debug, Serialize, Deserialize)]
pub enum ForumAuthError {
TokenRefreshTimeOver,
}
impl std::fmt::Display for ForumAuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ForumAuthError::TokenRefreshTimeOver => write!(f, "Token Refresh Time is Over"),
}
}
}
impl std::error::Error for ForumAuthError {
fn cause(&self) -> Option<&dyn std::error::Error> {
self.source()
}
}

View file

@ -7,6 +7,8 @@ use crate::{
ONE_TIME_PASSWORDS, ONE_TIME_PASSWORDS,
}; };
use super::user::User;
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct OneTimePassword { pub struct OneTimePassword {
pub user_id: i64, pub user_id: i64,
@ -18,19 +20,15 @@ impl OneTimePassword {
RwLock::new(vec![]) RwLock::new(vec![])
} }
pub async fn new( pub async fn new(user: &User, user_email: &String) -> Result<(), ForumMailError> {
user_id: &i64,
user_name: &String,
user_email: &String,
) -> Result<(), ForumMailError> {
let one_time_password = "123".to_owned(); let one_time_password = "123".to_owned();
let new_self = Self { let new_self = Self {
user_id: *user_id, user_id: user.user_id,
one_time_password, one_time_password,
}; };
let mail_template = let mail_template =
MailTemplate::OneTimePassword(MailFieldsOneTimePassword::new(user_name, &new_self)); MailTemplate::OneTimePassword(MailFieldsOneTimePassword::new(&user.name, &new_self));
mail_template.send_mail(user_email).await?; mail_template.send_mail(user_email).await?;

View file

@ -2,18 +2,20 @@ use std::sync::LazyLock;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use jwt_simple::{ use jwt_simple::{
claims::Claims, claims::{Claims, JWTClaims},
common::VerificationOptions, common::VerificationOptions,
prelude::{HS256Key, MACLike}, prelude::{HS256Key, MACLike},
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{Pool, Postgres}; use sqlx::{Pool, Postgres};
use crate::{database::login, SERVER_CONFIG}; use crate::{database::login, error::ForumAuthError, SERVER_CONFIG};
use super::user::User;
static TOKEN_META: LazyLock<TokenMeta> = LazyLock::new(TokenMeta::init); static TOKEN_META: LazyLock<TokenMeta> = LazyLock::new(TokenMeta::init);
struct TokenMeta { pub struct TokenMeta {
token_key: HS256Key, token_key: HS256Key,
token_verification_options: Option<VerificationOptions>, token_verification_options: Option<VerificationOptions>,
} }
@ -30,27 +32,28 @@ impl TokenMeta {
} }
} }
async fn create_token() -> Option<String> { async fn create_token(user_id: &i64) -> Option<String> {
let key = &TOKEN_META.token_key; let key = &TOKEN_META.token_key;
let claims = Claims::create(jwt_simple::prelude::Duration::from_mins(
SERVER_CONFIG.login_token_expiration_time_limit as u64, let claims = Claims::with_custom_claims(
)); *user_id,
jwt_simple::prelude::Duration::from_mins(
SERVER_CONFIG.login_token_expiration_time_limit as u64,
),
);
let token = key.authenticate(claims).unwrap(); let token = key.authenticate(claims).unwrap();
match TokenMeta::verify_token(&token).await { match TokenMeta::verify_token(&token).await {
true => Some(token), Ok(_) => Some(token),
false => None, Err(_) => None,
} }
} }
async fn verify_token(token: &String) -> bool { pub async fn verify_token(token: &String) -> Result<JWTClaims<i64>, jwt_simple::Error> {
let token_meta = &TOKEN_META; let token_meta = &TOKEN_META;
token_meta token_meta
.token_key .token_key
.verify_token::<jwt_simple::prelude::NoCustomClaims>( .verify_token::<i64>(token, token_meta.token_verification_options.clone())
token,
token_meta.token_verification_options.clone(),
)
.is_ok()
} }
} }
@ -66,7 +69,9 @@ impl Login {
user_id: &i64, user_id: &i64,
database_connection: &Pool<Postgres>, database_connection: &Pool<Postgres>,
) -> Result<Login, sqlx::Error> { ) -> Result<Login, sqlx::Error> {
let token = TokenMeta::create_token() User::read(user_id, database_connection).await?;
let token = TokenMeta::create_token(user_id)
.await .await
.expect("Should not panic if it isn't configured wrong"); .expect("Should not panic if it isn't configured wrong");
login::create(user_id, &token, database_connection).await login::create(user_id, &token, database_connection).await
@ -77,6 +82,8 @@ impl Login {
token: &String, token: &String,
database_connection: &Pool<Postgres>, database_connection: &Pool<Postgres>,
) -> Result<Login, sqlx::Error> { ) -> Result<Login, sqlx::Error> {
User::read(user_id, database_connection).await?;
login::read(user_id, token, database_connection).await login::read(user_id, token, database_connection).await
} }
@ -84,21 +91,21 @@ impl Login {
user_id: &i64, user_id: &i64,
token: &String, token: &String,
database_connection: &Pool<Postgres>, database_connection: &Pool<Postgres>,
) -> Result<Login, sqlx::Error> { ) -> Result<Login, Box<dyn std::error::Error>> {
let login = Login::read(user_id, token, database_connection).await?; let login = Login::read(user_id, token, database_connection).await?;
match TokenMeta::verify_token(token).await { match TokenMeta::verify_token(token).await {
true => Ok(login), Ok(_) => Ok(login),
false => { Err(_) => {
if DateTime::<Utc>::default() if DateTime::<Utc>::default()
.signed_duration_since(&login.token_creation_time) .signed_duration_since(&login.token_creation_time)
.num_minutes() .num_minutes()
<= SERVER_CONFIG.login_token_refresh_time_limit as i64 <= SERVER_CONFIG.login_token_refresh_time_limit as i64
{ {
Login::delete(user_id, token, database_connection).await?; Login::delete(user_id, token, database_connection).await?;
Login::create(user_id, database_connection).await let login = Login::create(user_id, database_connection).await?;
} else {
Ok(login) Ok(login)
} else {
Err(Box::new(ForumAuthError::TokenRefreshTimeOver))
} }
} }
} }

View file

@ -13,7 +13,7 @@ pub struct Contact {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct User { pub struct User {
pub id: i64, pub user_id: i64,
pub name: String, pub name: String,
pub surname: String, pub surname: String,
pub gender: bool, pub gender: bool,
@ -33,12 +33,15 @@ impl User {
user::create(name, surname, gender, birth_date, database_connection).await user::create(name, surname, gender, birth_date, database_connection).await
} }
pub async fn read(id: &i64, database_connection: &Pool<Postgres>) -> Result<User, sqlx::Error> { pub async fn read(
user::read(id, database_connection).await user_id: &i64,
database_connection: &Pool<Postgres>,
) -> Result<User, sqlx::Error> {
user::read(user_id, database_connection).await
} }
pub async fn update( pub async fn update(
id: &i64, user_id: &i64,
name: &String, name: &String,
surname: &String, surname: &String,
gender: &bool, gender: &bool,
@ -47,7 +50,7 @@ impl User {
database_connection: &Pool<Postgres>, database_connection: &Pool<Postgres>,
) -> Result<User, sqlx::Error> { ) -> Result<User, sqlx::Error> {
user::update( user::update(
id, user_id,
name, name,
surname, surname,
gender, gender,
@ -59,10 +62,10 @@ impl User {
} }
pub async fn delete( pub async fn delete(
id: &i64, user_id: &i64,
database_connection: &Pool<Postgres>, database_connection: &Pool<Postgres>,
) -> Result<User, sqlx::Error> { ) -> Result<User, sqlx::Error> {
user::delete(id, database_connection).await user::delete(user_id, database_connection).await
} }
pub async fn read_all(database_connection: &Pool<Postgres>) -> Result<Vec<User>, sqlx::Error> { pub async fn read_all(database_connection: &Pool<Postgres>) -> Result<Vec<User>, sqlx::Error> {
@ -183,4 +186,62 @@ impl User {
) -> Result<u64, sqlx::Error> { ) -> Result<u64, sqlx::Error> {
user::count_all_for_gender(gender, database_connection).await user::count_all_for_gender(gender, database_connection).await
} }
pub async fn is_builder(user: &User) -> bool {
if user.role_id == 0 {
true
} else {
false
}
}
pub async fn is_admin(user: &User) -> bool {
if user.role_id == 1 {
true
} else {
false
}
}
pub async fn is_banned(user: &User) -> bool {
if user.role_id == -1 {
true
} else {
false
}
}
pub async fn is_builder_or_admin(user: &User) -> bool {
if user.role_id == 0 || user.role_id == 1 {
true
} else {
false
}
}
pub async fn is_self(user: &User, target_user: &User) -> bool {
if user.user_id == target_user.user_id {
true
} else {
false
}
}
pub async fn is_higher(user: &User, target_user: &User) -> bool {
if user.user_id >= 0 {
if user.user_id < target_user.user_id {
return true;
}
}
false
}
pub async fn is_higher_or_self(user: &User, target_user: &User) -> bool {
if User::is_self(user, target_user).await {
true
} else {
User::is_higher(user, target_user).await
}
}
} }

View file

@ -3,6 +3,7 @@ pub mod comment_interaction;
pub mod contact; pub mod contact;
pub mod interaction; pub mod interaction;
pub mod login; pub mod login;
pub mod middleware;
pub mod post; pub mod post;
pub mod post_interaction; pub mod post_interaction;
pub mod role; pub mod role;
@ -18,6 +19,10 @@ use crate::{database, AppState};
pub async fn route(concurrency_limit: &usize, State(app_state): State<AppState>) -> Router { pub async fn route(concurrency_limit: &usize, State(app_state): State<AppState>) -> Router {
Router::new() Router::new()
.route("/", get(alive)) .route("/", get(alive))
.route_layer(axum::middleware::from_fn_with_state(
app_state.clone(),
middleware::pass,
))
.nest( .nest(
"/roles", "/roles",
role::route(axum::extract::State(app_state.clone())), role::route(axum::extract::State(app_state.clone())),
@ -59,7 +64,7 @@ pub async fn route(concurrency_limit: &usize, State(app_state): State<AppState>)
.with_state(app_state) .with_state(app_state)
} }
async fn alive(State(app_state): State<AppState>) -> impl IntoResponse { pub async fn alive(State(app_state): State<AppState>) -> impl IntoResponse {
match database::is_alive(&app_state.database_connection).await { match database::is_alive(&app_state.database_connection).await {
true => StatusCode::OK, true => StatusCode::OK,
false => StatusCode::SERVICE_UNAVAILABLE, false => StatusCode::SERVICE_UNAVAILABLE,

229
src/routing/middleware.rs Normal file
View file

@ -0,0 +1,229 @@
use std::sync::Arc;
use axum::{
body::{to_bytes, Body},
extract::{Request, State},
http::{self, StatusCode},
middleware::Next,
response::IntoResponse,
};
use sqlx::{Pool, Postgres};
use crate::{
feature::{login::TokenMeta, user::User},
AppState,
};
#[derive(Debug)]
struct UserAndRequest {
user: User,
request: Request,
}
#[derive(Debug)]
struct UserAndTargetUserAndRequest {
user: User,
target_user: User,
request: Request,
}
async fn user_extraction(
request: Request,
database_connection: &Pool<Postgres>,
) -> Option<UserAndRequest> {
if let Some(authorization_header) = request.headers().get(http::header::AUTHORIZATION) {
if let Ok(authorization_header) = authorization_header.to_str() {
if let Some((bearer, authorization_header)) = authorization_header.split_once(' ') {
if bearer == "bearer" {
if let Ok(claims) =
TokenMeta::verify_token(&authorization_header.to_string()).await
{
return Some(UserAndRequest {
user: User::read(&claims.custom, database_connection).await.ok()?,
request,
});
}
}
}
}
}
None
}
async fn target_user_extraction_from_json(
json: &serde_json::Value,
database_connection: &Pool<Postgres>,
) -> Option<User> {
if let Some(target_user_id) = json.get("user_id") {
if target_user_id.is_i64() {
if let Some(target_user_id) = target_user_id.as_i64() {
return User::read(&target_user_id, database_connection).await.ok();
}
}
}
None
}
async fn user_and_target_user_extraction(
request: Request,
database_connection: &Pool<Postgres>,
) -> Option<UserAndTargetUserAndRequest> {
let user_and_request = user_extraction(request, database_connection).await?;
let user = user_and_request.user;
let request = user_and_request.request;
let (parts, body) = request.into_parts();
let bytes = to_bytes(body, usize::MAX).await.ok()?;
let json: serde_json::Value = serde_json::from_slice(&bytes).ok()?;
let body = Body::from(json.to_string());
let request = Request::from_parts(parts, body);
Some(UserAndTargetUserAndRequest {
user,
target_user: target_user_extraction_from_json(&json, database_connection).await?,
request,
})
}
pub async fn pass(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
match user_extraction(request, &app_state.database_connection).await {
Some(user_and_request) => {
let user = Arc::new(user_and_request.user);
let mut request = user_and_request.request;
request.extensions_mut().insert(user);
Ok(next.run(request).await)
}
None => Err(StatusCode::FORBIDDEN),
}
}
pub async fn pass_builder(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_request) = user_extraction(request, &app_state.database_connection).await {
let user = user_and_request.user;
let mut request = user_and_request.request;
if User::is_builder(&user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}
pub async fn pass_admin(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_request) = user_extraction(request, &app_state.database_connection).await {
let user = user_and_request.user;
let mut request = user_and_request.request;
if User::is_admin(&user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}
pub async fn pass_builder_or_admin(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_request) = user_extraction(request, &app_state.database_connection).await {
let user = user_and_request.user;
let mut request = user_and_request.request;
if User::is_builder_or_admin(&user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}
pub async fn pass_self(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_target_user_and_request) =
user_and_target_user_extraction(request, &app_state.database_connection).await
{
let user = user_and_target_user_and_request.user;
let target_user = user_and_target_user_and_request.target_user;
let mut request = user_and_target_user_and_request.request;
if User::is_self(&user, &target_user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}
pub async fn pass_higher(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_target_user_and_request) =
user_and_target_user_extraction(request, &app_state.database_connection).await
{
let user = user_and_target_user_and_request.user;
let target_user = user_and_target_user_and_request.target_user;
let mut request = user_and_target_user_and_request.request;
if User::is_higher(&user, &target_user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}
pub async fn pass_higher_or_self(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_target_user_and_request) =
user_and_target_user_extraction(request, &app_state.database_connection).await
{
let user = user_and_target_user_and_request.user;
let target_user = user_and_target_user_and_request.target_user;
let mut request = user_and_target_user_and_request.request;
if User::is_higher_or_self(&user, &target_user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}