refactor: ♻️ new permission strategy part 2

This commit is contained in:
Ahmet Kaan GÜMÜŞ 2025-01-19 23:47:09 +03:00
parent 56aa04e32a
commit bcfcd2c6f0
15 changed files with 393 additions and 73 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@
debug/
target/
Cargo.lock
cspell.json

View file

@ -1,6 +1,6 @@
-- Add up migration script here
CREATE TABLE IF NOT EXISTS "user"(
id BIGSERIAL PRIMARY KEY NOT NULL UNIQUE,
user_id BIGSERIAL PRIMARY KEY NOT NULL UNIQUE,
name VARCHAR(256) NOT NULL,
surname VARCHAR(256) NOT NULL,
gender boolean NOT NULL,

View file

@ -1,6 +1,6 @@
-- Add up migration script here
CREATE TABLE IF NOT EXISTS "post"(
creation_time TIMESTAMPTZ PRIMARY KEY UNIQUE NOT NULL DEFAULT NOW(),
user_id BIGSERIAL NOT NULL REFERENCES "user"(id),
user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
post VARCHAR(8192) NOT NULL UNIQUE
);

View file

@ -2,6 +2,6 @@
CREATE TABLE IF NOT EXISTS "comment"(
creation_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(),
post_creation_time TIMESTAMPTZ NOT NULL REFERENCES "post"(creation_time),
user_id BIGSERIAL NOT NULL REFERENCES "user"(id),
user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
comment VARCHAR(8192) NOT NULL
);

View file

@ -2,6 +2,6 @@
CREATE TABLE IF NOT EXISTS "post_interaction"(
interaction_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(),
post_creation_time TIMESTAMPTZ NOT NULL REFERENCES "post"(creation_time),
user_id BIGSERIAL NOT NULL REFERENCES "user"(id),
user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
interaction_id BIGSERIAL NOT NULL REFERENCES "interaction"(id)
);

View file

@ -2,6 +2,6 @@
CREATE TABLE IF NOT EXISTS "comment_interaction"(
interaction_time TIMESTAMPTZ PRIMARY KEY NOT NULL UNIQUE DEFAULT NOW(),
comment_creation_time TIMESTAMPTZ NOT NULL REFERENCES "comment"(creation_time),
user_id BIGSERIAL NOT NULL REFERENCES "user"(id),
user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
interaction_id BIGSERIAL NOT NULL REFERENCES "interaction"(id)
);

View file

@ -1,6 +1,6 @@
-- Add up migration script here
CREATE TABLE IF NOT EXISTS "user_contact"(
user_id BIGSERIAL NOT NULL REFERENCES "user"(id),
user_id BIGSERIAL NOT NULL REFERENCES "user"(user_id),
contact_id BIGSERIAL NOT NULL REFERENCES "contact"(id),
PRIMARY KEY (user_id, contact_id)
);

View file

@ -1,6 +1,6 @@
-- Add up migration script here
CREATE TABLE IF NOT EXISTS "login" (
user_id BIGSERIAL NOT NULL REFERENCES "user" (id),
user_id BIGSERIAL NOT NULL REFERENCES "user" (user_id),
token VARCHAR(1024) NOT NULL,
token_creation_time TIMESTAMPTZ NOT NULL DEFAULT NOW (),
PRIMARY KEY (user_id, token)

View file

@ -31,7 +31,7 @@ pub async fn read(id: &i64, database_connection: &Pool<Postgres>) -> Result<User
sqlx::query_as!(
User,
r#"
SELECT * FROM "user" WHERE "id" = $1
SELECT * FROM "user" WHERE "user_id" = $1
"#,
id
)
@ -50,7 +50,7 @@ pub async fn update(
) -> Result<User, sqlx::Error> {
sqlx::query_as!(User,
r#"
UPDATE "user" SET "name" = $2, "surname" = $3, "gender" = $4, "birth_date" = $5, "role_id" = $6 WHERE "id" = $1
UPDATE "user" SET "name" = $2, "surname" = $3, "gender" = $4, "birth_date" = $5, "role_id" = $6 WHERE "user_id" = $1
RETURNING *
"#, id, name, surname, gender, birth_date, role_id).fetch_one(database_connection).await
}
@ -59,7 +59,7 @@ pub async fn delete(id: &i64, database_connection: &Pool<Postgres>) -> Result<Us
sqlx::query_as!(
User,
r#"
DELETE FROM "user" WHERE "id" = $1
DELETE FROM "user" WHERE "user_id" = $1
RETURNING *
"#,
id
@ -157,13 +157,13 @@ pub async fn read_all_for_gender(
pub async fn read_all_id(database_connection: &Pool<Postgres>) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!(
r#"
SELECT "id" FROM "user"
SELECT "user_id" FROM "user"
"#,
)
.fetch_all(database_connection)
.await?
.iter()
.map(|record| record.id)
.map(|record| record.user_id)
.collect::<Vec<i64>>())
}
@ -173,14 +173,14 @@ pub async fn read_all_id_for_name(
) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!(
r#"
SELECT "id" FROM "user" WHERE "name" = $1
SELECT "user_id" FROM "user" WHERE "name" = $1
"#,
name
)
.fetch_all(database_connection)
.await?
.iter()
.map(|record| record.id)
.map(|record| record.user_id)
.collect::<Vec<i64>>())
}
@ -190,14 +190,14 @@ pub async fn read_all_id_for_surname(
) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!(
r#"
SELECT "id" FROM "user" WHERE "surname" = $1
SELECT "user_id" FROM "user" WHERE "surname" = $1
"#,
surname
)
.fetch_all(database_connection)
.await?
.iter()
.map(|record| record.id)
.map(|record| record.user_id)
.collect::<Vec<i64>>())
}
@ -207,14 +207,14 @@ pub async fn read_all_id_for_birth_date(
) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!(
r#"
SELECT "id" FROM "user" WHERE "birth_date" = $1
SELECT "user_id" FROM "user" WHERE "birth_date" = $1
"#,
birth_date
)
.fetch_all(database_connection)
.await?
.iter()
.map(|record| record.id)
.map(|record| record.user_id)
.collect::<Vec<i64>>())
}
@ -224,14 +224,14 @@ pub async fn read_all_id_for_role(
) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!(
r#"
SELECT "id" FROM "user" WHERE "role_id" = $1
SELECT "user_id" FROM "user" WHERE "role_id" = $1
"#,
role_id
)
.fetch_all(database_connection)
.await?
.iter()
.map(|record| record.id)
.map(|record| record.user_id)
.collect::<Vec<i64>>())
}
@ -241,21 +241,21 @@ pub async fn read_all_id_for_gender(
) -> Result<Vec<i64>, sqlx::Error> {
Ok(sqlx::query!(
r#"
SELECT "id" FROM "user" WHERE "gender" = $1
SELECT "user_id" FROM "user" WHERE "gender" = $1
"#,
gender
)
.fetch_all(database_connection)
.await?
.iter()
.map(|record| record.id)
.map(|record| record.user_id)
.collect::<Vec<i64>>())
}
pub async fn count_all(database_connection: &Pool<Postgres>) -> Result<u64, sqlx::Error> {
sqlx::query!(
r#"
SELECT COUNT(id) FROM "user"
SELECT COUNT(user_id) FROM "user"
"#,
)
.fetch_one(database_connection)
@ -272,7 +272,7 @@ pub async fn count_all_for_name(
) -> Result<u64, sqlx::Error> {
sqlx::query!(
r#"
SELECT COUNT(id) FROM "user" WHERE "name" = $1
SELECT COUNT(user_id) FROM "user" WHERE "name" = $1
"#,
name
)
@ -290,7 +290,7 @@ pub async fn count_all_for_surname(
) -> Result<u64, sqlx::Error> {
sqlx::query!(
r#"
SELECT COUNT(id) FROM "user" WHERE "surname" = $1
SELECT COUNT(user_id) FROM "user" WHERE "surname" = $1
"#,
surname
)
@ -308,7 +308,7 @@ pub async fn count_all_for_birth_date(
) -> Result<u64, sqlx::Error> {
sqlx::query!(
r#"
SELECT COUNT(id) FROM "user" WHERE "birth_date" = $1
SELECT COUNT(user_id) FROM "user" WHERE "birth_date" = $1
"#,
birth_date
)
@ -326,7 +326,7 @@ pub async fn count_all_for_role(
) -> Result<u64, sqlx::Error> {
sqlx::query!(
r#"
SELECT COUNT(id) FROM "user" WHERE "role_id" = $1
SELECT COUNT(user_id) FROM "user" WHERE "role_id" = $1
"#,
role_id
)
@ -344,7 +344,7 @@ pub async fn count_all_for_gender(
) -> Result<u64, sqlx::Error> {
sqlx::query!(
r#"
SELECT COUNT(id) FROM "user" WHERE "gender" = $1
SELECT COUNT(user_id) FROM "user" WHERE "gender" = $1
"#,
gender
)

View file

@ -51,3 +51,22 @@ impl std::error::Error for ForumMailError {
self.source()
}
}
#[derive(Debug, Serialize, Deserialize)]
pub enum ForumAuthError {
TokenRefreshTimeOver,
}
impl std::fmt::Display for ForumAuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ForumAuthError::TokenRefreshTimeOver => write!(f, "Token Refresh Time is Over"),
}
}
}
impl std::error::Error for ForumAuthError {
fn cause(&self) -> Option<&dyn std::error::Error> {
self.source()
}
}

View file

@ -7,6 +7,8 @@ use crate::{
ONE_TIME_PASSWORDS,
};
use super::user::User;
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct OneTimePassword {
pub user_id: i64,
@ -18,19 +20,15 @@ impl OneTimePassword {
RwLock::new(vec![])
}
pub async fn new(
user_id: &i64,
user_name: &String,
user_email: &String,
) -> Result<(), ForumMailError> {
pub async fn new(user: &User, user_email: &String) -> Result<(), ForumMailError> {
let one_time_password = "123".to_owned();
let new_self = Self {
user_id: *user_id,
user_id: user.user_id,
one_time_password,
};
let mail_template =
MailTemplate::OneTimePassword(MailFieldsOneTimePassword::new(user_name, &new_self));
MailTemplate::OneTimePassword(MailFieldsOneTimePassword::new(&user.name, &new_self));
mail_template.send_mail(user_email).await?;

View file

@ -2,18 +2,20 @@ use std::sync::LazyLock;
use chrono::{DateTime, Utc};
use jwt_simple::{
claims::Claims,
claims::{Claims, JWTClaims},
common::VerificationOptions,
prelude::{HS256Key, MACLike},
};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Postgres};
use crate::{database::login, SERVER_CONFIG};
use crate::{database::login, error::ForumAuthError, SERVER_CONFIG};
use super::user::User;
static TOKEN_META: LazyLock<TokenMeta> = LazyLock::new(TokenMeta::init);
struct TokenMeta {
pub struct TokenMeta {
token_key: HS256Key,
token_verification_options: Option<VerificationOptions>,
}
@ -30,27 +32,28 @@ impl TokenMeta {
}
}
async fn create_token() -> Option<String> {
async fn create_token(user_id: &i64) -> Option<String> {
let key = &TOKEN_META.token_key;
let claims = Claims::create(jwt_simple::prelude::Duration::from_mins(
let claims = Claims::with_custom_claims(
*user_id,
jwt_simple::prelude::Duration::from_mins(
SERVER_CONFIG.login_token_expiration_time_limit as u64,
));
),
);
let token = key.authenticate(claims).unwrap();
match TokenMeta::verify_token(&token).await {
true => Some(token),
false => None,
Ok(_) => Some(token),
Err(_) => None,
}
}
async fn verify_token(token: &String) -> bool {
pub async fn verify_token(token: &String) -> Result<JWTClaims<i64>, jwt_simple::Error> {
let token_meta = &TOKEN_META;
token_meta
.token_key
.verify_token::<jwt_simple::prelude::NoCustomClaims>(
token,
token_meta.token_verification_options.clone(),
)
.is_ok()
.verify_token::<i64>(token, token_meta.token_verification_options.clone())
}
}
@ -66,7 +69,9 @@ impl Login {
user_id: &i64,
database_connection: &Pool<Postgres>,
) -> Result<Login, sqlx::Error> {
let token = TokenMeta::create_token()
User::read(user_id, database_connection).await?;
let token = TokenMeta::create_token(user_id)
.await
.expect("Should not panic if it isn't configured wrong");
login::create(user_id, &token, database_connection).await
@ -77,6 +82,8 @@ impl Login {
token: &String,
database_connection: &Pool<Postgres>,
) -> Result<Login, sqlx::Error> {
User::read(user_id, database_connection).await?;
login::read(user_id, token, database_connection).await
}
@ -84,21 +91,21 @@ impl Login {
user_id: &i64,
token: &String,
database_connection: &Pool<Postgres>,
) -> Result<Login, sqlx::Error> {
) -> Result<Login, Box<dyn std::error::Error>> {
let login = Login::read(user_id, token, database_connection).await?;
match TokenMeta::verify_token(token).await {
true => Ok(login),
false => {
Ok(_) => Ok(login),
Err(_) => {
if DateTime::<Utc>::default()
.signed_duration_since(&login.token_creation_time)
.num_minutes()
<= SERVER_CONFIG.login_token_refresh_time_limit as i64
{
Login::delete(user_id, token, database_connection).await?;
Login::create(user_id, database_connection).await
} else {
let login = Login::create(user_id, database_connection).await?;
Ok(login)
} else {
Err(Box::new(ForumAuthError::TokenRefreshTimeOver))
}
}
}

View file

@ -13,7 +13,7 @@ pub struct Contact {
#[derive(Debug, Serialize, Deserialize)]
pub struct User {
pub id: i64,
pub user_id: i64,
pub name: String,
pub surname: String,
pub gender: bool,
@ -33,12 +33,15 @@ impl User {
user::create(name, surname, gender, birth_date, database_connection).await
}
pub async fn read(id: &i64, database_connection: &Pool<Postgres>) -> Result<User, sqlx::Error> {
user::read(id, database_connection).await
pub async fn read(
user_id: &i64,
database_connection: &Pool<Postgres>,
) -> Result<User, sqlx::Error> {
user::read(user_id, database_connection).await
}
pub async fn update(
id: &i64,
user_id: &i64,
name: &String,
surname: &String,
gender: &bool,
@ -47,7 +50,7 @@ impl User {
database_connection: &Pool<Postgres>,
) -> Result<User, sqlx::Error> {
user::update(
id,
user_id,
name,
surname,
gender,
@ -59,10 +62,10 @@ impl User {
}
pub async fn delete(
id: &i64,
user_id: &i64,
database_connection: &Pool<Postgres>,
) -> Result<User, sqlx::Error> {
user::delete(id, database_connection).await
user::delete(user_id, database_connection).await
}
pub async fn read_all(database_connection: &Pool<Postgres>) -> Result<Vec<User>, sqlx::Error> {
@ -183,4 +186,62 @@ impl User {
) -> Result<u64, sqlx::Error> {
user::count_all_for_gender(gender, database_connection).await
}
pub async fn is_builder(user: &User) -> bool {
if user.role_id == 0 {
true
} else {
false
}
}
pub async fn is_admin(user: &User) -> bool {
if user.role_id == 1 {
true
} else {
false
}
}
pub async fn is_banned(user: &User) -> bool {
if user.role_id == -1 {
true
} else {
false
}
}
pub async fn is_builder_or_admin(user: &User) -> bool {
if user.role_id == 0 || user.role_id == 1 {
true
} else {
false
}
}
pub async fn is_self(user: &User, target_user: &User) -> bool {
if user.user_id == target_user.user_id {
true
} else {
false
}
}
pub async fn is_higher(user: &User, target_user: &User) -> bool {
if user.user_id >= 0 {
if user.user_id < target_user.user_id {
return true;
}
}
false
}
pub async fn is_higher_or_self(user: &User, target_user: &User) -> bool {
if User::is_self(user, target_user).await {
true
} else {
User::is_higher(user, target_user).await
}
}
}

View file

@ -3,6 +3,7 @@ pub mod comment_interaction;
pub mod contact;
pub mod interaction;
pub mod login;
pub mod middleware;
pub mod post;
pub mod post_interaction;
pub mod role;
@ -18,6 +19,10 @@ use crate::{database, AppState};
pub async fn route(concurrency_limit: &usize, State(app_state): State<AppState>) -> Router {
Router::new()
.route("/", get(alive))
.route_layer(axum::middleware::from_fn_with_state(
app_state.clone(),
middleware::pass,
))
.nest(
"/roles",
role::route(axum::extract::State(app_state.clone())),
@ -59,7 +64,7 @@ pub async fn route(concurrency_limit: &usize, State(app_state): State<AppState>)
.with_state(app_state)
}
async fn alive(State(app_state): State<AppState>) -> impl IntoResponse {
pub async fn alive(State(app_state): State<AppState>) -> impl IntoResponse {
match database::is_alive(&app_state.database_connection).await {
true => StatusCode::OK,
false => StatusCode::SERVICE_UNAVAILABLE,

229
src/routing/middleware.rs Normal file
View file

@ -0,0 +1,229 @@
use std::sync::Arc;
use axum::{
body::{to_bytes, Body},
extract::{Request, State},
http::{self, StatusCode},
middleware::Next,
response::IntoResponse,
};
use sqlx::{Pool, Postgres};
use crate::{
feature::{login::TokenMeta, user::User},
AppState,
};
#[derive(Debug)]
struct UserAndRequest {
user: User,
request: Request,
}
#[derive(Debug)]
struct UserAndTargetUserAndRequest {
user: User,
target_user: User,
request: Request,
}
async fn user_extraction(
request: Request,
database_connection: &Pool<Postgres>,
) -> Option<UserAndRequest> {
if let Some(authorization_header) = request.headers().get(http::header::AUTHORIZATION) {
if let Ok(authorization_header) = authorization_header.to_str() {
if let Some((bearer, authorization_header)) = authorization_header.split_once(' ') {
if bearer == "bearer" {
if let Ok(claims) =
TokenMeta::verify_token(&authorization_header.to_string()).await
{
return Some(UserAndRequest {
user: User::read(&claims.custom, database_connection).await.ok()?,
request,
});
}
}
}
}
}
None
}
async fn target_user_extraction_from_json(
json: &serde_json::Value,
database_connection: &Pool<Postgres>,
) -> Option<User> {
if let Some(target_user_id) = json.get("user_id") {
if target_user_id.is_i64() {
if let Some(target_user_id) = target_user_id.as_i64() {
return User::read(&target_user_id, database_connection).await.ok();
}
}
}
None
}
async fn user_and_target_user_extraction(
request: Request,
database_connection: &Pool<Postgres>,
) -> Option<UserAndTargetUserAndRequest> {
let user_and_request = user_extraction(request, database_connection).await?;
let user = user_and_request.user;
let request = user_and_request.request;
let (parts, body) = request.into_parts();
let bytes = to_bytes(body, usize::MAX).await.ok()?;
let json: serde_json::Value = serde_json::from_slice(&bytes).ok()?;
let body = Body::from(json.to_string());
let request = Request::from_parts(parts, body);
Some(UserAndTargetUserAndRequest {
user,
target_user: target_user_extraction_from_json(&json, database_connection).await?,
request,
})
}
pub async fn pass(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
match user_extraction(request, &app_state.database_connection).await {
Some(user_and_request) => {
let user = Arc::new(user_and_request.user);
let mut request = user_and_request.request;
request.extensions_mut().insert(user);
Ok(next.run(request).await)
}
None => Err(StatusCode::FORBIDDEN),
}
}
pub async fn pass_builder(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_request) = user_extraction(request, &app_state.database_connection).await {
let user = user_and_request.user;
let mut request = user_and_request.request;
if User::is_builder(&user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}
pub async fn pass_admin(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_request) = user_extraction(request, &app_state.database_connection).await {
let user = user_and_request.user;
let mut request = user_and_request.request;
if User::is_admin(&user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}
pub async fn pass_builder_or_admin(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_request) = user_extraction(request, &app_state.database_connection).await {
let user = user_and_request.user;
let mut request = user_and_request.request;
if User::is_builder_or_admin(&user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}
pub async fn pass_self(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_target_user_and_request) =
user_and_target_user_extraction(request, &app_state.database_connection).await
{
let user = user_and_target_user_and_request.user;
let target_user = user_and_target_user_and_request.target_user;
let mut request = user_and_target_user_and_request.request;
if User::is_self(&user, &target_user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}
pub async fn pass_higher(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_target_user_and_request) =
user_and_target_user_extraction(request, &app_state.database_connection).await
{
let user = user_and_target_user_and_request.user;
let target_user = user_and_target_user_and_request.target_user;
let mut request = user_and_target_user_and_request.request;
if User::is_higher(&user, &target_user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}
pub async fn pass_higher_or_self(
State(app_state): State<AppState>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
if let Some(user_and_target_user_and_request) =
user_and_target_user_extraction(request, &app_state.database_connection).await
{
let user = user_and_target_user_and_request.user;
let target_user = user_and_target_user_and_request.target_user;
let mut request = user_and_target_user_and_request.request;
if User::is_higher_or_self(&user, &target_user).await {
let user = Arc::new(user);
request.extensions_mut().insert(user);
return Ok(next.run(request).await);
}
}
Err(StatusCode::FORBIDDEN)
}