diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index c3520e7..040e20a 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -10,6 +10,7 @@ pub struct User { #[derive(Debug, Clone, Serialize, Deserialize)] pub enum Error { + UnregisteredUser, Window, Permission(String), MediaPlay(String), @@ -43,6 +44,7 @@ impl std::error::Error for Error { impl Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Error::UnregisteredUser => write!(f, "No Registered User"), Error::Window => write!(f, "Window"), Error::Permission(permission_cause) => write!(f, "Permission | {}", permission_cause), Error::MediaPlay(err_val) => write!(f, "Media Play | {}", err_val), diff --git a/server/src/signal.rs b/server/src/signal.rs index d1c4f55..be2bc02 100644 --- a/server/src/signal.rs +++ b/server/src/signal.rs @@ -1,11 +1,15 @@ -use std::{collections::VecDeque, sync::LazyLock, time::Duration}; +use std::{ + collections::VecDeque, + sync::{Arc, LazyLock}, + time::Duration, +}; use axum::{Router, http::StatusCode, response::IntoResponse, routing::get}; use fastwebsockets::{ Frame, OpCode, WebSocketError, upgrade::{IncomingUpgrade, UpgradeFut}, }; -use protocol::{Signal, SignalType}; +use protocol::{Error, Signal, SignalType}; use tokio::{ net::TcpListener, sync::{RwLock, broadcast}, @@ -46,6 +50,22 @@ async fn signal(websocket: IncomingUpgrade) -> impl IntoResponse { response } +async fn remove_user_from_online_users(user: &String) -> Result<(), Error> { + let mut target_index = None; + let mut online_users = ONLINE_USERS.write().await; + for (index, online_user) in online_users.iter().enumerate() { + if online_user.user == *user { + target_index = Some(index); + } + } + if let Some(target_index) = target_index { + online_users.remove(target_index).expect("Should Not"); + Ok(()) + } else { + Err(Error::UnregisteredUser) + } +} + async fn websocket_handler(websocket: UpgradeFut) { let mut websocket = websocket.await.unwrap(); websocket.set_auto_pong(true); @@ -53,23 +73,36 @@ async fn websocket_handler(websocket: UpgradeFut) { websocket.set_auto_close(true); let (mut websocket_receiver, mut websocker_sender) = websocket.split(tokio::io::split); - let mut user = String::default(); + let user = Arc::new(RwLock::new(String::default())); let (message_sender, message_receiver) = broadcast::channel(100); + let user_for_receiver_first_connection_disconnect_check = user.clone(); if let Ok(received_frame) = websocket_receiver - .read_frame::<_, WebSocketError>(&mut move |_| async { Ok(()) }) + .read_frame::<_, WebSocketError>(&mut move |_| { + let user_for_receiver_disconnect_check = + user_for_receiver_first_connection_disconnect_check.clone(); + async move { + let _ = remove_user_from_online_users( + &*user_for_receiver_disconnect_check.read().await, + ) + .await; + Ok(()) + } + }) .await { if let OpCode::Text = received_frame.opcode { let signal = serde_json::from_slice::(&received_frame.payload.to_vec()).unwrap(); - if signal.get_signal_type() == SignalType::Auth && user == String::default() { + if signal.get_signal_type() == SignalType::Auth + && *user.read().await == String::default() + { let new_user = UserMessages { user: signal.get_data(), message_receiver, }; - user = new_user.user.to_owned(); + *user.write().await = new_user.user.to_owned(); ONLINE_USERS.write().await.push_back(new_user); } else { return; @@ -77,6 +110,7 @@ async fn websocket_handler(websocket: UpgradeFut) { } } + let user_for_sender = user.clone(); tokio::spawn(async move { while ONLINE_USERS.read().await.len() < 2 { sleep(Duration::from_secs(1)).await; @@ -84,7 +118,9 @@ async fn websocket_handler(websocket: UpgradeFut) { loop { let mut user_messages = ONLINE_USERS.write().await; for user_message in user_messages.iter_mut() { - if user_message.user != user && user_message.message_receiver.len() > 0 { + if user_message.user != *user_for_sender.read().await + && user_message.message_receiver.len() > 0 + { while let Ok(message) = user_message.message_receiver.recv().await { if let Err(err_val) = websocker_sender .write_frame(Frame::text(fastwebsockets::Payload::Owned( @@ -93,6 +129,10 @@ async fn websocket_handler(websocket: UpgradeFut) { .await { eprintln!("Error: WebSocket Send | {}", err_val); + let _ = remove_user_from_online_users( + &user_for_sender.read().await.to_owned(), + ) + .await; break; } if user_message.message_receiver.len() < 1 { @@ -106,7 +146,10 @@ async fn websocket_handler(websocket: UpgradeFut) { }); while let Ok(received_frame) = websocket_receiver - .read_frame::<_, WebSocketError>(&mut move |_| async { Ok(()) }) + .read_frame::<_, WebSocketError>(&mut |_| async { + let _ = remove_user_from_online_users(&user.read().await.to_owned()).await; + Ok(()) + }) .await { if let OpCode::Text = received_frame.opcode { @@ -116,8 +159,10 @@ async fn websocket_handler(websocket: UpgradeFut) { if signal.get_signal_type() != SignalType::Auth { if let Err(err_val) = message_sender.send(signal) { eprintln!("Error: WebSocket Channel Send | {}", err_val); + let _ = remove_user_from_online_users(&user.read().await.to_owned()).await; } } else { + let _ = remove_user_from_online_users(&user.read().await.to_owned()).await; return; } }