refactor: ♻️ new permission strategy part 2
This commit is contained in:
parent
56aa04e32a
commit
bcfcd2c6f0
15 changed files with 393 additions and 73 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -2,3 +2,4 @@
|
|||
debug/
|
||||
target/
|
||||
Cargo.lock
|
||||
cspell.json
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
-- Add up migration script here
|
||||
CREATE TABLE IF NOT EXISTS "user"(
|
||||
id BIGSERIAL PRIMARY KEY NOT NULL UNIQUE,
|
||||
user_id BIGSERIAL PRIMARY KEY NOT NULL UNIQUE,
|
||||
name VARCHAR(256) NOT NULL,
|
||||
surname VARCHAR(256) NOT NULL,
|
||||
gender boolean NOT NULL,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
-- Add up migration script here
|
||||
CREATE TABLE IF NOT EXISTS "post"(
|
||||
creation_time TIMESTAMPTZ PRIMARY KEY UNIQUE NOT NULL DEFAULT NOW(),
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user"(id),
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
|
||||
post VARCHAR(8192) NOT NULL UNIQUE
|
||||
);
|
|
@ -2,6 +2,6 @@
|
|||
CREATE TABLE IF NOT EXISTS "comment"(
|
||||
creation_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(),
|
||||
post_creation_time TIMESTAMPTZ NOT NULL REFERENCES "post"(creation_time),
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user"(id),
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
|
||||
comment VARCHAR(8192) NOT NULL
|
||||
);
|
|
@ -2,6 +2,6 @@
|
|||
CREATE TABLE IF NOT EXISTS "post_interaction"(
|
||||
interaction_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(),
|
||||
post_creation_time TIMESTAMPTZ NOT NULL REFERENCES "post"(creation_time),
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user"(id),
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
|
||||
interaction_id BIGSERIAL NOT NULL REFERENCES "interaction"(id)
|
||||
);
|
|
@ -2,6 +2,6 @@
|
|||
CREATE TABLE IF NOT EXISTS "comment_interaction"(
|
||||
interaction_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(),
|
||||
comment_creation_time TIMESTAMPTZ NOT NULL REFERENCES "comment"(creation_time),
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user"(id),
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
|
||||
interaction_id BIGSERIAL NOT NULL REFERENCES "interaction"(id)
|
||||
);
|
|
@ -1,6 +1,6 @@
|
|||
-- Add up migration script here
|
||||
CREATE TABLE IF NOT EXISTS "user_contact"(
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user"(id),
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
|
||||
contact_id BIGSERIAL NOT NULL REFERENCES "contact"(id),
|
||||
PRIMARY KEY (user_id, contact_id)
|
||||
);
|
|
@ -1,6 +1,6 @@
|
|||
-- Add up migration script here
|
||||
CREATE TABLE IF NOT EXISTS "login" (
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user" (id),
|
||||
user_id BIGSERIAL NOT NULL REFERENCES "user" (user_id),
|
||||
token VARCHAR(1024) NOT NULL,
|
||||
token_creation_time TIMESTAMPTZ NOT NULL DEFAULT NOW (),
|
||||
PRIMARY KEY (user_id, token)
|
||||
|
|
|
@ -31,7 +31,7 @@ pub async fn read(id: &i64, database_connection: &Pool<Postgres>) -> Result<User
|
|||
sqlx::query_as!(
|
||||
User,
|
||||
r#"
|
||||
SELECT * FROM "user" WHERE "id" = $1
|
||||
SELECT * FROM "user" WHERE "user_id" = $1
|
||||
"#,
|
||||
id
|
||||
)
|
||||
|
@ -50,7 +50,7 @@ pub async fn update(
|
|||
) -> Result<User, sqlx::Error> {
|
||||
sqlx::query_as!(User,
|
||||
r#"
|
||||
UPDATE "user" SET "name" = $2, "surname" = $3, "gender" = $4, "birth_date" = $5, "role_id" = $6 WHERE "id" = $1
|
||||
UPDATE "user" SET "name" = $2, "surname" = $3, "gender" = $4, "birth_date" = $5, "role_id" = $6 WHERE "user_id" = $1
|
||||
RETURNING *
|
||||
"#, id, name, surname, gender, birth_date, role_id).fetch_one(database_connection).await
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ pub async fn delete(id: &i64, database_connection: &Pool<Postgres>) -> Result<Us
|
|||
sqlx::query_as!(
|
||||
User,
|
||||
r#"
|
||||
DELETE FROM "user" WHERE "id" = $1
|
||||
DELETE FROM "user" WHERE "user_id" = $1
|
||||
RETURNING *
|
||||
"#,
|
||||
id
|
||||
|
@ -157,13 +157,13 @@ pub async fn read_all_for_gender(
|
|||
pub async fn read_all_id(database_connection: &Pool<Postgres>) -> Result<Vec<i64>, sqlx::Error> {
|
||||
Ok(sqlx::query!(
|
||||
r#"
|
||||
SELECT "id" FROM "user"
|
||||
SELECT "user_id" FROM "user"
|
||||
"#,
|
||||
)
|
||||
.fetch_all(database_connection)
|
||||
.await?
|
||||
.iter()
|
||||
.map(|record| record.id)
|
||||
.map(|record| record.user_id)
|
||||
.collect::<Vec<i64>>())
|
||||
}
|
||||
|
||||
|
@ -173,14 +173,14 @@ pub async fn read_all_id_for_name(
|
|||
) -> Result<Vec<i64>, sqlx::Error> {
|
||||
Ok(sqlx::query!(
|
||||
r#"
|
||||
SELECT "id" FROM "user" WHERE "name" = $1
|
||||
SELECT "user_id" FROM "user" WHERE "name" = $1
|
||||
"#,
|
||||
name
|
||||
)
|
||||
.fetch_all(database_connection)
|
||||
.await?
|
||||
.iter()
|
||||
.map(|record| record.id)
|
||||
.map(|record| record.user_id)
|
||||
.collect::<Vec<i64>>())
|
||||
}
|
||||
|
||||
|
@ -190,14 +190,14 @@ pub async fn read_all_id_for_surname(
|
|||
) -> Result<Vec<i64>, sqlx::Error> {
|
||||
Ok(sqlx::query!(
|
||||
r#"
|
||||
SELECT "id" FROM "user" WHERE "surname" = $1
|
||||
SELECT "user_id" FROM "user" WHERE "surname" = $1
|
||||
"#,
|
||||
surname
|
||||
)
|
||||
.fetch_all(database_connection)
|
||||
.await?
|
||||
.iter()
|
||||
.map(|record| record.id)
|
||||
.map(|record| record.user_id)
|
||||
.collect::<Vec<i64>>())
|
||||
}
|
||||
|
||||
|
@ -207,14 +207,14 @@ pub async fn read_all_id_for_birth_date(
|
|||
) -> Result<Vec<i64>, sqlx::Error> {
|
||||
Ok(sqlx::query!(
|
||||
r#"
|
||||
SELECT "id" FROM "user" WHERE "birth_date" = $1
|
||||
SELECT "user_id" FROM "user" WHERE "birth_date" = $1
|
||||
"#,
|
||||
birth_date
|
||||
)
|
||||
.fetch_all(database_connection)
|
||||
.await?
|
||||
.iter()
|
||||
.map(|record| record.id)
|
||||
.map(|record| record.user_id)
|
||||
.collect::<Vec<i64>>())
|
||||
}
|
||||
|
||||
|
@ -224,14 +224,14 @@ pub async fn read_all_id_for_role(
|
|||
) -> Result<Vec<i64>, sqlx::Error> {
|
||||
Ok(sqlx::query!(
|
||||
r#"
|
||||
SELECT "id" FROM "user" WHERE "role_id" = $1
|
||||
SELECT "user_id" FROM "user" WHERE "role_id" = $1
|
||||
"#,
|
||||
role_id
|
||||
)
|
||||
.fetch_all(database_connection)
|
||||
.await?
|
||||
.iter()
|
||||
.map(|record| record.id)
|
||||
.map(|record| record.user_id)
|
||||
.collect::<Vec<i64>>())
|
||||
}
|
||||
|
||||
|
@ -241,21 +241,21 @@ pub async fn read_all_id_for_gender(
|
|||
) -> Result<Vec<i64>, sqlx::Error> {
|
||||
Ok(sqlx::query!(
|
||||
r#"
|
||||
SELECT "id" FROM "user" WHERE "gender" = $1
|
||||
SELECT "user_id" FROM "user" WHERE "gender" = $1
|
||||
"#,
|
||||
gender
|
||||
)
|
||||
.fetch_all(database_connection)
|
||||
.await?
|
||||
.iter()
|
||||
.map(|record| record.id)
|
||||
.map(|record| record.user_id)
|
||||
.collect::<Vec<i64>>())
|
||||
}
|
||||
|
||||
pub async fn count_all(database_connection: &Pool<Postgres>) -> Result<u64, sqlx::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
SELECT COUNT(id) FROM "user"
|
||||
SELECT COUNT(user_id) FROM "user"
|
||||
"#,
|
||||
)
|
||||
.fetch_one(database_connection)
|
||||
|
@ -272,7 +272,7 @@ pub async fn count_all_for_name(
|
|||
) -> Result<u64, sqlx::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
SELECT COUNT(id) FROM "user" WHERE "name" = $1
|
||||
SELECT COUNT(user_id) FROM "user" WHERE "name" = $1
|
||||
"#,
|
||||
name
|
||||
)
|
||||
|
@ -290,7 +290,7 @@ pub async fn count_all_for_surname(
|
|||
) -> Result<u64, sqlx::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
SELECT COUNT(id) FROM "user" WHERE "surname" = $1
|
||||
SELECT COUNT(user_id) FROM "user" WHERE "surname" = $1
|
||||
"#,
|
||||
surname
|
||||
)
|
||||
|
@ -308,7 +308,7 @@ pub async fn count_all_for_birth_date(
|
|||
) -> Result<u64, sqlx::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
SELECT COUNT(id) FROM "user" WHERE "birth_date" = $1
|
||||
SELECT COUNT(user_id) FROM "user" WHERE "birth_date" = $1
|
||||
"#,
|
||||
birth_date
|
||||
)
|
||||
|
@ -326,7 +326,7 @@ pub async fn count_all_for_role(
|
|||
) -> Result<u64, sqlx::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
SELECT COUNT(id) FROM "user" WHERE "role_id" = $1
|
||||
SELECT COUNT(user_id) FROM "user" WHERE "role_id" = $1
|
||||
"#,
|
||||
role_id
|
||||
)
|
||||
|
@ -344,7 +344,7 @@ pub async fn count_all_for_gender(
|
|||
) -> Result<u64, sqlx::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
SELECT COUNT(id) FROM "user" WHERE "gender" = $1
|
||||
SELECT COUNT(user_id) FROM "user" WHERE "gender" = $1
|
||||
"#,
|
||||
gender
|
||||
)
|
||||
|
|
19
src/error.rs
19
src/error.rs
|
@ -51,3 +51,22 @@ impl std::error::Error for ForumMailError {
|
|||
self.source()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum ForumAuthError {
|
||||
TokenRefreshTimeOver,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ForumAuthError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ForumAuthError::TokenRefreshTimeOver => write!(f, "Token Refresh Time is Over"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ForumAuthError {
|
||||
fn cause(&self) -> Option<&dyn std::error::Error> {
|
||||
self.source()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@ use crate::{
|
|||
ONE_TIME_PASSWORDS,
|
||||
};
|
||||
|
||||
use super::user::User;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct OneTimePassword {
|
||||
pub user_id: i64,
|
||||
|
@ -18,19 +20,15 @@ impl OneTimePassword {
|
|||
RwLock::new(vec![])
|
||||
}
|
||||
|
||||
pub async fn new(
|
||||
user_id: &i64,
|
||||
user_name: &String,
|
||||
user_email: &String,
|
||||
) -> Result<(), ForumMailError> {
|
||||
pub async fn new(user: &User, user_email: &String) -> Result<(), ForumMailError> {
|
||||
let one_time_password = "123".to_owned();
|
||||
let new_self = Self {
|
||||
user_id: *user_id,
|
||||
user_id: user.user_id,
|
||||
one_time_password,
|
||||
};
|
||||
|
||||
let mail_template =
|
||||
MailTemplate::OneTimePassword(MailFieldsOneTimePassword::new(user_name, &new_self));
|
||||
MailTemplate::OneTimePassword(MailFieldsOneTimePassword::new(&user.name, &new_self));
|
||||
|
||||
mail_template.send_mail(user_email).await?;
|
||||
|
||||
|
|
|
@ -2,18 +2,20 @@ use std::sync::LazyLock;
|
|||
|
||||
use chrono::{DateTime, Utc};
|
||||
use jwt_simple::{
|
||||
claims::Claims,
|
||||
claims::{Claims, JWTClaims},
|
||||
common::VerificationOptions,
|
||||
prelude::{HS256Key, MACLike},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Pool, Postgres};
|
||||
|
||||
use crate::{database::login, SERVER_CONFIG};
|
||||
use crate::{database::login, error::ForumAuthError, SERVER_CONFIG};
|
||||
|
||||
use super::user::User;
|
||||
|
||||
static TOKEN_META: LazyLock<TokenMeta> = LazyLock::new(TokenMeta::init);
|
||||
|
||||
struct TokenMeta {
|
||||
pub struct TokenMeta {
|
||||
token_key: HS256Key,
|
||||
token_verification_options: Option<VerificationOptions>,
|
||||
}
|
||||
|
@ -30,27 +32,28 @@ impl TokenMeta {
|
|||
}
|
||||
}
|
||||
|
||||
async fn create_token() -> Option<String> {
|
||||
async fn create_token(user_id: &i64) -> Option<String> {
|
||||
let key = &TOKEN_META.token_key;
|
||||
let claims = Claims::create(jwt_simple::prelude::Duration::from_mins(
|
||||
|
||||
let claims = Claims::with_custom_claims(
|
||||
*user_id,
|
||||
jwt_simple::prelude::Duration::from_mins(
|
||||
SERVER_CONFIG.login_token_expiration_time_limit as u64,
|
||||
));
|
||||
),
|
||||
);
|
||||
|
||||
let token = key.authenticate(claims).unwrap();
|
||||
match TokenMeta::verify_token(&token).await {
|
||||
true => Some(token),
|
||||
false => None,
|
||||
Ok(_) => Some(token),
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn verify_token(token: &String) -> bool {
|
||||
pub async fn verify_token(token: &String) -> Result<JWTClaims<i64>, jwt_simple::Error> {
|
||||
let token_meta = &TOKEN_META;
|
||||
token_meta
|
||||
.token_key
|
||||
.verify_token::<jwt_simple::prelude::NoCustomClaims>(
|
||||
token,
|
||||
token_meta.token_verification_options.clone(),
|
||||
)
|
||||
.is_ok()
|
||||
.verify_token::<i64>(token, token_meta.token_verification_options.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -66,7 +69,9 @@ impl Login {
|
|||
user_id: &i64,
|
||||
database_connection: &Pool<Postgres>,
|
||||
) -> Result<Login, sqlx::Error> {
|
||||
let token = TokenMeta::create_token()
|
||||
User::read(user_id, database_connection).await?;
|
||||
|
||||
let token = TokenMeta::create_token(user_id)
|
||||
.await
|
||||
.expect("Should not panic if it isn't configured wrong");
|
||||
login::create(user_id, &token, database_connection).await
|
||||
|
@ -77,6 +82,8 @@ impl Login {
|
|||
token: &String,
|
||||
database_connection: &Pool<Postgres>,
|
||||
) -> Result<Login, sqlx::Error> {
|
||||
User::read(user_id, database_connection).await?;
|
||||
|
||||
login::read(user_id, token, database_connection).await
|
||||
}
|
||||
|
||||
|
@ -84,21 +91,21 @@ impl Login {
|
|||
user_id: &i64,
|
||||
token: &String,
|
||||
database_connection: &Pool<Postgres>,
|
||||
) -> Result<Login, sqlx::Error> {
|
||||
) -> Result<Login, Box<dyn std::error::Error>> {
|
||||
let login = Login::read(user_id, token, database_connection).await?;
|
||||
|
||||
match TokenMeta::verify_token(token).await {
|
||||
true => Ok(login),
|
||||
false => {
|
||||
Ok(_) => Ok(login),
|
||||
Err(_) => {
|
||||
if DateTime::<Utc>::default()
|
||||
.signed_duration_since(&login.token_creation_time)
|
||||
.num_minutes()
|
||||
<= SERVER_CONFIG.login_token_refresh_time_limit as i64
|
||||
{
|
||||
Login::delete(user_id, token, database_connection).await?;
|
||||
Login::create(user_id, database_connection).await
|
||||
} else {
|
||||
let login = Login::create(user_id, database_connection).await?;
|
||||
Ok(login)
|
||||
} else {
|
||||
Err(Box::new(ForumAuthError::TokenRefreshTimeOver))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ pub struct Contact {
|
|||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
pub id: i64,
|
||||
pub user_id: i64,
|
||||
pub name: String,
|
||||
pub surname: String,
|
||||
pub gender: bool,
|
||||
|
@ -33,12 +33,15 @@ impl User {
|
|||
user::create(name, surname, gender, birth_date, database_connection).await
|
||||
}
|
||||
|
||||
pub async fn read(id: &i64, database_connection: &Pool<Postgres>) -> Result<User, sqlx::Error> {
|
||||
user::read(id, database_connection).await
|
||||
pub async fn read(
|
||||
user_id: &i64,
|
||||
database_connection: &Pool<Postgres>,
|
||||
) -> Result<User, sqlx::Error> {
|
||||
user::read(user_id, database_connection).await
|
||||
}
|
||||
|
||||
pub async fn update(
|
||||
id: &i64,
|
||||
user_id: &i64,
|
||||
name: &String,
|
||||
surname: &String,
|
||||
gender: &bool,
|
||||
|
@ -47,7 +50,7 @@ impl User {
|
|||
database_connection: &Pool<Postgres>,
|
||||
) -> Result<User, sqlx::Error> {
|
||||
user::update(
|
||||
id,
|
||||
user_id,
|
||||
name,
|
||||
surname,
|
||||
gender,
|
||||
|
@ -59,10 +62,10 @@ impl User {
|
|||
}
|
||||
|
||||
pub async fn delete(
|
||||
id: &i64,
|
||||
user_id: &i64,
|
||||
database_connection: &Pool<Postgres>,
|
||||
) -> Result<User, sqlx::Error> {
|
||||
user::delete(id, database_connection).await
|
||||
user::delete(user_id, database_connection).await
|
||||
}
|
||||
|
||||
pub async fn read_all(database_connection: &Pool<Postgres>) -> Result<Vec<User>, sqlx::Error> {
|
||||
|
@ -183,4 +186,62 @@ impl User {
|
|||
) -> Result<u64, sqlx::Error> {
|
||||
user::count_all_for_gender(gender, database_connection).await
|
||||
}
|
||||
|
||||
pub async fn is_builder(user: &User) -> bool {
|
||||
if user.role_id == 0 {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_admin(user: &User) -> bool {
|
||||
if user.role_id == 1 {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_banned(user: &User) -> bool {
|
||||
if user.role_id == -1 {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_builder_or_admin(user: &User) -> bool {
|
||||
if user.role_id == 0 || user.role_id == 1 {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_self(user: &User, target_user: &User) -> bool {
|
||||
if user.user_id == target_user.user_id {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_higher(user: &User, target_user: &User) -> bool {
|
||||
if user.user_id >= 0 {
|
||||
if user.user_id < target_user.user_id {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub async fn is_higher_or_self(user: &User, target_user: &User) -> bool {
|
||||
if User::is_self(user, target_user).await {
|
||||
true
|
||||
} else {
|
||||
User::is_higher(user, target_user).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ pub mod comment_interaction;
|
|||
pub mod contact;
|
||||
pub mod interaction;
|
||||
pub mod login;
|
||||
pub mod middleware;
|
||||
pub mod post;
|
||||
pub mod post_interaction;
|
||||
pub mod role;
|
||||
|
@ -18,6 +19,10 @@ use crate::{database, AppState};
|
|||
pub async fn route(concurrency_limit: &usize, State(app_state): State<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/", get(alive))
|
||||
.route_layer(axum::middleware::from_fn_with_state(
|
||||
app_state.clone(),
|
||||
middleware::pass,
|
||||
))
|
||||
.nest(
|
||||
"/roles",
|
||||
role::route(axum::extract::State(app_state.clone())),
|
||||
|
@ -59,7 +64,7 @@ pub async fn route(concurrency_limit: &usize, State(app_state): State<AppState>)
|
|||
.with_state(app_state)
|
||||
}
|
||||
|
||||
async fn alive(State(app_state): State<AppState>) -> impl IntoResponse {
|
||||
pub async fn alive(State(app_state): State<AppState>) -> impl IntoResponse {
|
||||
match database::is_alive(&app_state.database_connection).await {
|
||||
true => StatusCode::OK,
|
||||
false => StatusCode::SERVICE_UNAVAILABLE,
|
||||
|
|
229
src/routing/middleware.rs
Normal file
229
src/routing/middleware.rs
Normal file
|
@ -0,0 +1,229 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
body::{to_bytes, Body},
|
||||
extract::{Request, State},
|
||||
http::{self, StatusCode},
|
||||
middleware::Next,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use sqlx::{Pool, Postgres};
|
||||
|
||||
use crate::{
|
||||
feature::{login::TokenMeta, user::User},
|
||||
AppState,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UserAndRequest {
|
||||
user: User,
|
||||
request: Request,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UserAndTargetUserAndRequest {
|
||||
user: User,
|
||||
target_user: User,
|
||||
request: Request,
|
||||
}
|
||||
|
||||
async fn user_extraction(
|
||||
request: Request,
|
||||
database_connection: &Pool<Postgres>,
|
||||
) -> Option<UserAndRequest> {
|
||||
if let Some(authorization_header) = request.headers().get(http::header::AUTHORIZATION) {
|
||||
if let Ok(authorization_header) = authorization_header.to_str() {
|
||||
if let Some((bearer, authorization_header)) = authorization_header.split_once(' ') {
|
||||
if bearer == "bearer" {
|
||||
if let Ok(claims) =
|
||||
TokenMeta::verify_token(&authorization_header.to_string()).await
|
||||
{
|
||||
return Some(UserAndRequest {
|
||||
user: User::read(&claims.custom, database_connection).await.ok()?,
|
||||
request,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
async fn target_user_extraction_from_json(
|
||||
json: &serde_json::Value,
|
||||
database_connection: &Pool<Postgres>,
|
||||
) -> Option<User> {
|
||||
if let Some(target_user_id) = json.get("user_id") {
|
||||
if target_user_id.is_i64() {
|
||||
if let Some(target_user_id) = target_user_id.as_i64() {
|
||||
return User::read(&target_user_id, database_connection).await.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
async fn user_and_target_user_extraction(
|
||||
request: Request,
|
||||
database_connection: &Pool<Postgres>,
|
||||
) -> Option<UserAndTargetUserAndRequest> {
|
||||
let user_and_request = user_extraction(request, database_connection).await?;
|
||||
let user = user_and_request.user;
|
||||
let request = user_and_request.request;
|
||||
|
||||
let (parts, body) = request.into_parts();
|
||||
let bytes = to_bytes(body, usize::MAX).await.ok()?;
|
||||
let json: serde_json::Value = serde_json::from_slice(&bytes).ok()?;
|
||||
|
||||
let body = Body::from(json.to_string());
|
||||
let request = Request::from_parts(parts, body);
|
||||
|
||||
Some(UserAndTargetUserAndRequest {
|
||||
user,
|
||||
target_user: target_user_extraction_from_json(&json, database_connection).await?,
|
||||
request,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn pass(
|
||||
State(app_state): State<AppState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
match user_extraction(request, &app_state.database_connection).await {
|
||||
Some(user_and_request) => {
|
||||
let user = Arc::new(user_and_request.user);
|
||||
let mut request = user_and_request.request;
|
||||
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
None => Err(StatusCode::FORBIDDEN),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn pass_builder(
|
||||
State(app_state): State<AppState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
if let Some(user_and_request) = user_extraction(request, &app_state.database_connection).await {
|
||||
let user = user_and_request.user;
|
||||
let mut request = user_and_request.request;
|
||||
|
||||
if User::is_builder(&user).await {
|
||||
let user = Arc::new(user);
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
}
|
||||
Err(StatusCode::FORBIDDEN)
|
||||
}
|
||||
|
||||
pub async fn pass_admin(
|
||||
State(app_state): State<AppState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
if let Some(user_and_request) = user_extraction(request, &app_state.database_connection).await {
|
||||
let user = user_and_request.user;
|
||||
let mut request = user_and_request.request;
|
||||
|
||||
if User::is_admin(&user).await {
|
||||
let user = Arc::new(user);
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
}
|
||||
Err(StatusCode::FORBIDDEN)
|
||||
}
|
||||
|
||||
pub async fn pass_builder_or_admin(
|
||||
State(app_state): State<AppState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
if let Some(user_and_request) = user_extraction(request, &app_state.database_connection).await {
|
||||
let user = user_and_request.user;
|
||||
let mut request = user_and_request.request;
|
||||
|
||||
if User::is_builder_or_admin(&user).await {
|
||||
let user = Arc::new(user);
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
}
|
||||
Err(StatusCode::FORBIDDEN)
|
||||
}
|
||||
|
||||
pub async fn pass_self(
|
||||
State(app_state): State<AppState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
if let Some(user_and_target_user_and_request) =
|
||||
user_and_target_user_extraction(request, &app_state.database_connection).await
|
||||
{
|
||||
let user = user_and_target_user_and_request.user;
|
||||
let target_user = user_and_target_user_and_request.target_user;
|
||||
let mut request = user_and_target_user_and_request.request;
|
||||
|
||||
if User::is_self(&user, &target_user).await {
|
||||
let user = Arc::new(user);
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
}
|
||||
Err(StatusCode::FORBIDDEN)
|
||||
}
|
||||
|
||||
pub async fn pass_higher(
|
||||
State(app_state): State<AppState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
if let Some(user_and_target_user_and_request) =
|
||||
user_and_target_user_extraction(request, &app_state.database_connection).await
|
||||
{
|
||||
let user = user_and_target_user_and_request.user;
|
||||
let target_user = user_and_target_user_and_request.target_user;
|
||||
let mut request = user_and_target_user_and_request.request;
|
||||
|
||||
if User::is_higher(&user, &target_user).await {
|
||||
let user = Arc::new(user);
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
}
|
||||
Err(StatusCode::FORBIDDEN)
|
||||
}
|
||||
|
||||
pub async fn pass_higher_or_self(
|
||||
State(app_state): State<AppState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
if let Some(user_and_target_user_and_request) =
|
||||
user_and_target_user_extraction(request, &app_state.database_connection).await
|
||||
{
|
||||
let user = user_and_target_user_and_request.user;
|
||||
let target_user = user_and_target_user_and_request.target_user;
|
||||
let mut request = user_and_target_user_and_request.request;
|
||||
|
||||
if User::is_higher_or_self(&user, &target_user).await {
|
||||
let user = Arc::new(user);
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
}
|
||||
Err(StatusCode::FORBIDDEN)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue