Reconnection system using secret for the relay server #46

Merged
CoCo_Sol merged 7 commits from reconnect into main 2024-02-12 22:41:02 +00:00
6 changed files with 349 additions and 76 deletions

108
Cargo.lock generated
View file

@ -210,6 +210,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "anyhow"
version = "1.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
[[package]] [[package]]
name = "approx" name = "approx"
version = "0.5.1" version = "0.5.1"
@ -232,7 +238,7 @@ dependencies = [
"objc", "objc",
"objc-foundation", "objc-foundation",
"objc_id", "objc_id",
"parking_lot", "parking_lot 0.12.1",
"thiserror", "thiserror",
"winapi", "winapi",
"x11rb", "x11rb",
@ -554,7 +560,7 @@ dependencies = [
"futures-io", "futures-io",
"futures-lite 1.13.0", "futures-lite 1.13.0",
"js-sys", "js-sys",
"parking_lot", "parking_lot 0.12.1",
"ron", "ron",
"serde", "serde",
"thiserror", "thiserror",
@ -1611,7 +1617,7 @@ dependencies = [
"ndk-context", "ndk-context",
"oboe", "oboe",
"once_cell", "once_cell",
"parking_lot", "parking_lot 0.12.1",
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"web-sys", "web-sys",
@ -1645,6 +1651,15 @@ dependencies = [
"crossbeam-utils", "crossbeam-utils",
] ]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]] [[package]]
name = "crossbeam-utils" name = "crossbeam-utils"
version = "0.8.19" version = "0.8.19"
@ -1692,7 +1707,7 @@ dependencies = [
"hashbrown 0.14.3", "hashbrown 0.14.3",
"lock_api", "lock_api",
"once_cell", "once_cell",
"parking_lot_core", "parking_lot_core 0.9.9",
] ]
[[package]] [[package]]
@ -1808,7 +1823,7 @@ dependencies = [
"ecolor", "ecolor",
"emath", "emath",
"nohash-hasher", "nohash-hasher",
"parking_lot", "parking_lot 0.12.1",
] ]
[[package]] [[package]]
@ -1990,6 +2005,16 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "fs2"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213"
dependencies = [
"libc",
"winapi",
]
[[package]] [[package]]
name = "futures" name = "futures"
version = "0.3.30" version = "0.3.30"
@ -2107,6 +2132,15 @@ dependencies = [
"slab", "slab",
] ]
[[package]]
name = "fxhash"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
dependencies = [
"byteorder",
]
[[package]] [[package]]
name = "generic-array" name = "generic-array"
version = "0.14.7" version = "0.14.7"
@ -3321,6 +3355,17 @@ version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
[[package]]
name = "parking_lot"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99"
dependencies = [
"instant",
"lock_api",
"parking_lot_core 0.8.6",
]
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.12.1" version = "0.12.1"
@ -3328,7 +3373,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
dependencies = [ dependencies = [
"lock_api", "lock_api",
"parking_lot_core", "parking_lot_core 0.9.9",
]
[[package]]
name = "parking_lot_core"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc"
dependencies = [
"cfg-if",
"instant",
"libc",
"redox_syscall 0.2.16",
"smallvec",
"winapi",
] ]
[[package]] [[package]]
@ -3543,6 +3602,15 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0d463f2884048e7153449a55166f91028d5b0ea53c79377099ce4e8cf0cf9bb" checksum = "a0d463f2884048e7153449a55166f91028d5b0ea53c79377099ce4e8cf0cf9bb"
[[package]]
name = "redox_syscall"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
dependencies = [
"bitflags 1.3.2",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.3.5" version = "0.3.5"
@ -3615,22 +3683,26 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
name = "relay-client" name = "relay-client"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"home",
"log", "log",
"mio", "mio",
"rand", "rand",
"tungstenite", "tungstenite",
"uuid",
] ]
[[package]] [[package]]
name = "relay-server" name = "relay-server"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"anyhow",
"axum", "axum",
"dashmap", "dashmap",
"futures", "futures",
"lazy_static", "lazy_static",
"rand", "sled",
"tokio", "tokio",
"uuid",
] ]
[[package]] [[package]]
@ -3905,6 +3977,22 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "sled"
version = "0.34.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f96b4737c2ce5987354855aed3797279def4ebf734436c6aa4552cf8e169935"
dependencies = [
"crc32fast",
"crossbeam-epoch",
"crossbeam-utils",
"fs2",
"fxhash",
"libc",
"log",
"parking_lot 0.11.2",
]
[[package]] [[package]]
name = "slotmap" name = "slotmap"
version = "1.0.7" version = "1.0.7"
@ -4596,7 +4684,7 @@ dependencies = [
"js-sys", "js-sys",
"log", "log",
"naga", "naga",
"parking_lot", "parking_lot 0.12.1",
"profiling", "profiling",
"raw-window-handle", "raw-window-handle",
"smallvec", "smallvec",
@ -4621,7 +4709,7 @@ dependencies = [
"codespan-reporting", "codespan-reporting",
"log", "log",
"naga", "naga",
"parking_lot", "parking_lot 0.12.1",
"profiling", "profiling",
"raw-window-handle", "raw-window-handle",
"rustc-hash", "rustc-hash",
@ -4659,7 +4747,7 @@ dependencies = [
"metal", "metal",
"naga", "naga",
"objc", "objc",
"parking_lot", "parking_lot 0.12.1",
"profiling", "profiling",
"range-alloc", "range-alloc",
"raw-window-handle", "raw-window-handle",

View file

@ -8,4 +8,5 @@ FROM alpine:latest
WORKDIR /app WORKDIR /app
COPY --from=builder /app/target/release/relay-server . COPY --from=builder /app/target/release/relay-server .
EXPOSE 80/tcp EXPOSE 80/tcp
VOLUME [ "/data" ]
CMD ["./relay-server"] CMD ["./relay-server"]

View file

@ -14,5 +14,7 @@ workspace = true
[dependencies] [dependencies]
tungstenite = { version = "0.21.0", features = ["rustls-tls-native-roots"] } tungstenite = { version = "0.21.0", features = ["rustls-tls-native-roots"] }
mio = { version = "0.8.10", features = ["net", "os-poll"] } mio = { version = "0.8.10", features = ["net", "os-poll"] }
uuid = "1.7.0"
rand = "0.8.5" rand = "0.8.5"
home = "0.5.9"
log = "0.4.20" log = "0.4.20"

View file

@ -1,8 +1,10 @@
//! A library containing a client to use a relay server. //! A library containing a client to use a relay server.
use std::borrow::Cow; use std::borrow::Cow;
use std::fs;
use std::io::{self}; use std::io::{self};
use std::net::{SocketAddr, ToSocketAddrs}; use std::net::{SocketAddr, ToSocketAddrs};
use std::path::PathBuf;
use std::sync::mpsc::{channel, Receiver, Sender}; use std::sync::mpsc::{channel, Receiver, Sender};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@ -12,6 +14,7 @@ use rand::seq::SliceRandom;
use tungstenite::handshake::MidHandshake; use tungstenite::handshake::MidHandshake;
use tungstenite::stream::MaybeTlsStream; use tungstenite::stream::MaybeTlsStream;
use tungstenite::{ClientHandshake, HandshakeError, Message, WebSocket}; use tungstenite::{ClientHandshake, HandshakeError, Message, WebSocket};
use uuid::Uuid;
/// The state of a [Connection]. /// The state of a [Connection].
#[derive(Debug)] #[derive(Debug)]
@ -28,6 +31,12 @@ enum ConnectionState {
/// The websocket handshake is in progress. /// The websocket handshake is in progress.
Handshaking(MidHandshake<ClientHandshake<MaybeTlsStream<TcpStream>>>), Handshaking(MidHandshake<ClientHandshake<MaybeTlsStream<TcpStream>>>),
/// The websocket handshake is finished.
Handshaked(WebSocket<MaybeTlsStream<TcpStream>>),
/// The [Connection] is registering with the relay server.
Registering(WebSocket<MaybeTlsStream<TcpStream>>),
/// The [Connection] is connected. /// The [Connection] is connected.
Active(WebSocket<MaybeTlsStream<TcpStream>>), Active(WebSocket<MaybeTlsStream<TcpStream>>),
} }
@ -40,6 +49,15 @@ pub struct Connection {
/// The domain of the relay server. /// The domain of the relay server.
domain: String, domain: String,
/// The path to the file where the identifier and secret key are stored.
data_path: PathBuf,
/// The identifier of the connection for the relay server.
identifier: Option<Uuid>,
/// The secret key used to authenticate with the relay server.
secret: Option<Uuid>,
/// The receiver part of the send channel. /// The receiver part of the send channel.
/// ///
/// This is used in [Connection::update] to get messages that need to /// This is used in [Connection::update] to get messages that need to
@ -56,13 +74,13 @@ pub struct Connection {
/// ///
/// This is used in [Connection::read] to get messages that have been /// This is used in [Connection::read] to get messages that have been
/// received from the relay server. /// received from the relay server.
receive_receiver: Receiver<(u32, Vec<u8>)>, receive_receiver: Receiver<(Uuid, Vec<u8>)>,
/// The sender part of the send channel. /// The sender part of the send channel.
/// ///
/// This is used in [Connection::update] to store messages that have /// This is used in [Connection::update] to store messages that have
/// been received from the relay server. /// been received from the relay server.
receive_sender: Sender<(u32, Vec<u8>)>, receive_sender: Sender<(Uuid, Vec<u8>)>,
/// The state of the connection. /// The state of the connection.
state: ConnectionState, state: ConnectionState,
@ -72,11 +90,45 @@ impl Connection {
/// Create a new [Connection]. /// Create a new [Connection].
pub fn new<'a>(domain: impl Into<Cow<'a, str>>) -> io::Result<Self> { pub fn new<'a>(domain: impl Into<Cow<'a, str>>) -> io::Result<Self> {
let domain = domain.into(); let domain = domain.into();
// Loads the identifier and secret key from disk.
let (data_path, identifier, secret) = {
// Find the relay data file path.
let mut path = home::home_dir().ok_or_else(|| {
io::Error::new(io::ErrorKind::NotFound, "could not find home directory")
})?;
path.push(".relay-data");
// Check if the file exists.
match path.exists() {
true => {
// Read the file and parse the identifier and secret key.
let contents = fs::read(&path)?;
if contents.len() != 32 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid data in .relay-data",
));
}
let identifier = Uuid::from_slice(&contents[..16]).map_err(io::Error::other)?;
let secret = Uuid::from_slice(&contents[16..]).map_err(io::Error::other)?;
(path, Some(identifier), Some(secret))
}
false => (path, None, None),
}
};
// Create the communication channels.
let (send_sender, send_receiver) = channel(); let (send_sender, send_receiver) = channel();
let (receive_sender, receive_receiver) = channel(); let (receive_sender, receive_receiver) = channel();
// Create the connection and return it.
Ok(Self { Ok(Self {
address_list: (domain.as_ref(), 443).to_socket_addrs()?.collect(), address_list: (domain.as_ref(), 443).to_socket_addrs()?.collect(),
domain: domain.into_owned(), domain: domain.into_owned(),
data_path,
identifier,
secret,
send_receiver, send_receiver,
send_sender, send_sender,
receive_receiver, receive_receiver,
@ -85,15 +137,20 @@ impl Connection {
}) })
} }
/// Get the identifier of the connection.
pub const fn identifier(&self) -> Option<Uuid> {
self.identifier
}
/// Send a message to the target client. /// Send a message to the target client.
pub fn send(&self, target_id: u32, message: Cow<[u8]>) { pub fn send<'a>(&self, target_id: Uuid, message: impl Into<Cow<'a, [u8]>>) {
let mut data = message.into_owned(); let mut data = message.into().into_owned();
data.extend_from_slice(&target_id.to_be_bytes()); data.extend_from_slice(target_id.as_bytes());
self.send_sender.send(Message::Binary(data)).ok(); self.send_sender.send(Message::Binary(data)).ok();
} }
/// Receive a message from the target client. /// Receive a message from the relay connection.
pub fn read(&self) -> Option<(u32, Vec<u8>)> { pub fn read(&self) -> Option<(Uuid, Vec<u8>)> {
self.receive_receiver.try_recv().ok() self.receive_receiver.try_recv().ok()
} }
@ -151,7 +208,7 @@ impl Connection {
/// Start the websocket handshake. /// Start the websocket handshake.
fn start_handshake(&mut self, stream: TcpStream) -> ConnectionState { fn start_handshake(&mut self, stream: TcpStream) -> ConnectionState {
match tungstenite::client_tls(format!("wss://{}", self.domain), stream) { match tungstenite::client_tls(format!("wss://{}", self.domain), stream) {
Ok((socket, _)) => ConnectionState::Active(socket), Ok((socket, _)) => ConnectionState::Handshaked(socket),
Err(HandshakeError::Interrupted(handshake)) => ConnectionState::Handshaking(handshake), Err(HandshakeError::Interrupted(handshake)) => ConnectionState::Handshaking(handshake),
Err(HandshakeError::Failure(e)) => { Err(HandshakeError::Failure(e)) => {
warn!("handshake failed with the relay server: {e}"); warn!("handshake failed with the relay server: {e}");
@ -166,7 +223,7 @@ impl Connection {
handshake: MidHandshake<ClientHandshake<MaybeTlsStream<TcpStream>>>, handshake: MidHandshake<ClientHandshake<MaybeTlsStream<TcpStream>>>,
) -> ConnectionState { ) -> ConnectionState {
match handshake.handshake() { match handshake.handshake() {
Ok((socket, _)) => ConnectionState::Active(socket), Ok((socket, _)) => ConnectionState::Handshaked(socket),
Err(HandshakeError::Interrupted(handshake)) => ConnectionState::Handshaking(handshake), Err(HandshakeError::Interrupted(handshake)) => ConnectionState::Handshaking(handshake),
Err(HandshakeError::Failure(e)) => { Err(HandshakeError::Failure(e)) => {
warn!("handshake failed with the relay server: {e}"); warn!("handshake failed with the relay server: {e}");
@ -175,6 +232,77 @@ impl Connection {
} }
} }
/// Start authentication with the relay server.
fn start_authentication(
&mut self,
mut socket: WebSocket<MaybeTlsStream<TcpStream>>,
) -> ConnectionState {
match (self.identifier, self.secret) {
(Some(identifier), Some(secret)) => {
// Create the authentication message.
let mut data = Vec::with_capacity(32);
data.extend(identifier.as_bytes());
data.extend(secret.as_bytes());
// Send the authentication message.
match socket.send(Message::Binary(data)) {
Ok(()) => ConnectionState::Active(socket),
Err(e) => {
warn!("failed to send authentication message: {e}");
ConnectionState::Disconnected
}
}
}
_ => {
// Send empty authentication message to request a new identifier and secret key.
match socket.send(Message::Binary(vec![])) {
Ok(()) => ConnectionState::Registering(socket),
Err(e) => {
warn!("failed to send registration message: {e}");
ConnectionState::Disconnected
}
}
}
}
}
/// Wait for the registration response.
fn get_registration_response(
&mut self,
mut socket: WebSocket<MaybeTlsStream<TcpStream>>,
) -> ConnectionState {
match socket.read() {
Ok(message) => {
// Check the message length.
let data = message.into_data();
if data.len() != 32 {
warn!("received malformed registration response");
return ConnectionState::Disconnected;
}
// Extract the client identifier and secret.
self.identifier = Some(Uuid::from_slice(&data[..16]).expect("invalid identifier"));
self.secret = Some(Uuid::from_slice(&data[16..]).expect("invalid secret"));
// Save the client identifier and secret.
fs::write(&self.data_path, data).ok();
// Activate the connection.
ConnectionState::Active(socket)
}
Err(tungstenite::Error::Io(ref e))
if e.kind() == std::io::ErrorKind::WouldBlock
|| e.kind() == std::io::ErrorKind::Interrupted =>
{
ConnectionState::Registering(socket)
}
Err(e) => {
warn!("failed to receive registration response: {e}");
ConnectionState::Disconnected
}
}
}
/// Update the [Connection] by receiving and sending messages. /// Update the [Connection] by receiving and sending messages.
fn update_connection( fn update_connection(
&mut self, &mut self,
@ -203,18 +331,14 @@ impl Connection {
Ok(message) => { Ok(message) => {
// Check the message length. // Check the message length.
let mut data = message.into_data(); let mut data = message.into_data();
if data.len() < 4 { if data.len() < 16 {
warn!("received malformed message with length: {}", data.len()); warn!("received malformed message with length: {}", data.len());
continue; continue;
} }
// Extract the sender ID. // Extract the sender ID.
let id_start = data.len() - 4; let id_start = data.len() - 16;
let sender_id = u32::from_be_bytes( let sender_id = Uuid::from_slice(&data[id_start..]).expect("invalid sender id");
data[id_start..]
.try_into()
.unwrap_or_else(|_| unreachable!()),
);
data.truncate(id_start); data.truncate(id_start);
// Send the message to the receive channel. // Send the message to the receive channel.
@ -250,6 +374,8 @@ impl Connection {
ConnectionState::Connecting(stream, start) => self.check_connection(stream, start), ConnectionState::Connecting(stream, start) => self.check_connection(stream, start),
ConnectionState::Connected(stream) => self.start_handshake(stream), ConnectionState::Connected(stream) => self.start_handshake(stream),
ConnectionState::Handshaking(handshake) => self.continue_handshake(handshake), ConnectionState::Handshaking(handshake) => self.continue_handshake(handshake),
ConnectionState::Handshaked(socket) => self.start_authentication(socket),
ConnectionState::Registering(socket) => self.get_registration_response(socket),
ConnectionState::Active(socket) => self.update_connection(socket), ConnectionState::Active(socket) => self.update_connection(socket),
} }
} }

View file

@ -14,7 +14,9 @@ workspace = true
[dependencies] [dependencies]
tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread"] }
axum = { version = "0.7.4", features = ["ws"] } axum = { version = "0.7.4", features = ["ws"] }
uuid = { version = "1.7.0", features = ["v4"] }
lazy_static = "1.4.0" lazy_static = "1.4.0"
futures = "0.3.30" futures = "0.3.30"
dashmap = "5.5.3" dashmap = "5.5.3"
rand = "0.8.5" anyhow = "1.0.79"
sled = "0.34.7"

View file

@ -1,5 +1,8 @@
//! A relay server for bevnet. //! A relay server for bevnet.
use std::io;
use anyhow::bail;
use axum::extract::ws::{Message, WebSocket}; use axum::extract::ws::{Message, WebSocket};
use axum::extract::WebSocketUpgrade; use axum::extract::WebSocketUpgrade;
use axum::routing::get; use axum::routing::get;
@ -7,19 +10,25 @@ use axum::Router;
use dashmap::DashMap; use dashmap::DashMap;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use rand::Rng; use sled::transaction::{ConflictableTransactionResult, TransactionalTree};
use sled::{Db, IVec};
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task::JoinHandle; use uuid::Uuid;
lazy_static! { lazy_static! {
static ref CLIENTS: DashMap<u32, Sender<Vec<u8>>> = DashMap::new(); static ref CLIENTS: DashMap<Uuid, Sender<Vec<u8>>> = DashMap::new();
static ref DB: Db = sled::open("/data/secrets.db").expect("unable to open the database");
} }
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let app = Router::new().route( let app = Router::new().route(
"/", "/",
get(|ws: WebSocketUpgrade| async { ws.on_upgrade(handle) }), get(|ws: WebSocketUpgrade| async {
ws.on_upgrade(|socket| async {
handle(socket).await.ok();
})
}),
); );
let listener = tokio::net::TcpListener::bind("0.0.0.0:80") let listener = tokio::net::TcpListener::bind("0.0.0.0:80")
.await .await
@ -27,80 +36,125 @@ async fn main() {
axum::serve(listener, app).await.expect("failed to serve"); axum::serve(listener, app).await.expect("failed to serve");
} }
/// Handle the websocket connection. /// Create a new client and add it to the database.
async fn handle(socket: WebSocket) { fn create_client(tx: &TransactionalTree) -> ConflictableTransactionResult<(Uuid, Uuid), io::Error> {
// Generate a new ID for the client. // Generates a new identifier for the client.
let client_id: u32 = loop { let client_id = loop {
let id = rand::thread_rng().gen(); // Generates a new random identifier.
if !CLIENTS.contains_key(&id) { let id = Uuid::new_v4();
// Check if the id isn't already in the database.
if tx.get(id.as_bytes())?.is_none() {
break id; break id;
} }
}; };
println!("Client({}) connected", client_id);
// Add the client to the list of connected clients. // Generate a random secret for the client.
let (sender, receiver) = channel(128); let secret = Uuid::new_v4();
CLIENTS.insert(client_id, sender);
// Handle messages from the client. // Add the new client to the database.
let result = handle_socket(socket, client_id, receiver).await; tx.insert(client_id.as_bytes(), secret.as_bytes())?;
// Remove the client from the list of connected clients. // Returns the client identifier and his secret.
match result { Ok((client_id, secret))
Ok(_) => println!("Client({}) disconnected", client_id),
Err(e) => {
CLIENTS.remove(&client_id);
println!("Client({}) disconnected: {}", client_id, e);
}
}
} }
/// Error prone part of handling the websocket connection. /// Handle the websocket connection.
async fn handle_socket( async fn handle(mut socket: WebSocket) -> anyhow::Result<()> {
mut socket: WebSocket, // Receive the first request from the client.
client_id: u32, let data = match socket.recv().await {
mut receiver: Receiver<Vec<u8>>, Some(Ok(message)) => message.into_data(),
) -> Result<(), axum::Error> { _ => return Ok(()),
// Send the client ID to the client. };
socket
.send(Message::Binary(client_id.to_be_bytes().to_vec()))
.await?;
// If the request is empty it means that the client want a new identifier and
// secret, so we create them and send them to the client.
let client_id = if data.is_empty() {
// Generate the new client.
let (client_id, secret) = DB.transaction(create_client)?;
DB.flush_async().await?;
println!("{client_id} created");
// Send the data to the client.
let mut data = Vec::with_capacity(32);
data.extend_from_slice(client_id.as_bytes());
data.extend_from_slice(secret.as_bytes());
socket.send(Message::Binary(data)).await?;
// Returns the client identifier.
client_id
}
// Otherwise it means that the client want to reuse an identifier, so it will
// send it along with his secret to prove that he is the right client.
else {
// Check for the message length to detect malformed messages.
if data.len() != 32 {
bail!("malformed message");
}
// Get the client identifier and secret from the message.
let client_id = Uuid::from_slice(&data[..16])?;
let secret = Uuid::from_slice(&data[16..])?;
// Check with the database if the secret is correct.
if DB.get(client_id.as_bytes())? != Some(IVec::from(secret.as_bytes())) {
bail!("invalid secret")
}
// Returns the client identifier.
client_id
};
// Handle the client connection.
println!("{client_id} connected");
let (sender, receiver) = channel(128);
CLIENTS.insert(client_id, sender);
handle_client(socket, client_id, receiver).await.ok();
CLIENTS.remove(&client_id);
println!("{client_id} disconnected");
// Returns success.
Ok(())
}
/// Handle the client connection.
async fn handle_client(
socket: WebSocket,
client_id: Uuid,
mut receiver: Receiver<Vec<u8>>,
) -> anyhow::Result<()> {
// Split the socket into sender and receiver. // Split the socket into sender and receiver.
let (mut writer, mut reader) = socket.split(); let (mut writer, mut reader) = socket.split();
// Handle sending messages to the client. // Handle sending messages to the client.
let sending_task: JoinHandle<Result<(), axum::Error>> = tokio::spawn(async move { tokio::spawn(async move {
while let Some(message) = receiver.recv().await { while let Some(message) = receiver.recv().await {
writer writer.send(Message::Binary(message)).await?;
.send(Message::Binary(message))
.await
.map_err(axum::Error::new)?;
} }
Ok(()) Ok::<(), axum::Error>(())
}); });
// Handle messages from the client. // Handle messages from the client.
while let Some(Ok(message)) = reader.next().await { while let Some(Ok(message)) = reader.next().await {
// Get the target ID from the message. // Get the target ID from the message.
let mut data = message.into_data(); let mut data = message.into_data();
let id_start = data.len() - 4; if data.len() < 16 {
let target_id = u32::from_be_bytes(data[id_start..].try_into().map_err(axum::Error::new)?); bail!("malformed message");
}
let id_start = data.len() - 16;
let target_id = Uuid::from_slice(&data[id_start..])?;
// Write the sender ID to the message. // Write the sender ID to the message.
for (i, byte) in client_id.to_be_bytes().into_iter().enumerate() { for (i, &byte) in client_id.as_bytes().iter().enumerate() {
data[id_start + i] = byte; data[id_start + i] = byte;
} }
// Send the message to the target client. // Send the message to the target client.
if let Some(sender) = CLIENTS.get(&target_id) { if let Some(sender) = CLIENTS.get(&target_id) {
sender.send(data).await.map_err(axum::Error::new)?; sender.send(data).await?;
} }
} }
// Remove the client from the list of connected clients. // Returns success.
CLIENTS.remove(&client_id); Ok(())
// Wait for the sender to finish.
sending_task.await.map_err(axum::Error::new)?
} }