generated from tipragot/rust
Reconnection system using secret for the relay server #46
108
Cargo.lock
generated
108
Cargo.lock
generated
|
@ -210,6 +210,12 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.79"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
version = "0.5.1"
|
||||
|
@ -232,7 +238,7 @@ dependencies = [
|
|||
"objc",
|
||||
"objc-foundation",
|
||||
"objc_id",
|
||||
"parking_lot",
|
||||
"parking_lot 0.12.1",
|
||||
"thiserror",
|
||||
"winapi",
|
||||
"x11rb",
|
||||
|
@ -554,7 +560,7 @@ dependencies = [
|
|||
"futures-io",
|
||||
"futures-lite 1.13.0",
|
||||
"js-sys",
|
||||
"parking_lot",
|
||||
"parking_lot 0.12.1",
|
||||
"ron",
|
||||
"serde",
|
||||
"thiserror",
|
||||
|
@ -1611,7 +1617,7 @@ dependencies = [
|
|||
"ndk-context",
|
||||
"oboe",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"parking_lot 0.12.1",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
|
@ -1645,6 +1651,15 @@ dependencies = [
|
|||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.19"
|
||||
|
@ -1692,7 +1707,7 @@ dependencies = [
|
|||
"hashbrown 0.14.3",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core",
|
||||
"parking_lot_core 0.9.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1808,7 +1823,7 @@ dependencies = [
|
|||
"ecolor",
|
||||
"emath",
|
||||
"nohash-hasher",
|
||||
"parking_lot",
|
||||
"parking_lot 0.12.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1990,6 +2005,16 @@ dependencies = [
|
|||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fs2"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.30"
|
||||
|
@ -2107,6 +2132,15 @@ dependencies = [
|
|||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fxhash"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
|
@ -3321,6 +3355,17 @@ version = "2.2.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99"
|
||||
dependencies = [
|
||||
"instant",
|
||||
"lock_api",
|
||||
"parking_lot_core 0.8.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.1"
|
||||
|
@ -3328,7 +3373,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
"parking_lot_core",
|
||||
"parking_lot_core 0.9.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"instant",
|
||||
"libc",
|
||||
"redox_syscall 0.2.16",
|
||||
"smallvec",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -3543,6 +3602,15 @@ version = "0.4.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0d463f2884048e7153449a55166f91028d5b0ea53c79377099ce4e8cf0cf9bb"
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.3.5"
|
||||
|
@ -3615,22 +3683,26 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
|
|||
name = "relay-client"
|
||||
version = "0.2.0"
|
||||
dependencies = [
|
||||
"home",
|
||||
"log",
|
||||
"mio",
|
||||
"rand",
|
||||
"tungstenite",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "relay-server"
|
||||
version = "0.2.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"dashmap",
|
||||
"futures",
|
||||
"lazy_static",
|
||||
"rand",
|
||||
"sled",
|
||||
"tokio",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -3905,6 +3977,22 @@ dependencies = [
|
|||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sled"
|
||||
version = "0.34.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f96b4737c2ce5987354855aed3797279def4ebf734436c6aa4552cf8e169935"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
"fs2",
|
||||
"fxhash",
|
||||
"libc",
|
||||
"log",
|
||||
"parking_lot 0.11.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slotmap"
|
||||
version = "1.0.7"
|
||||
|
@ -4596,7 +4684,7 @@ dependencies = [
|
|||
"js-sys",
|
||||
"log",
|
||||
"naga",
|
||||
"parking_lot",
|
||||
"parking_lot 0.12.1",
|
||||
"profiling",
|
||||
"raw-window-handle",
|
||||
"smallvec",
|
||||
|
@ -4621,7 +4709,7 @@ dependencies = [
|
|||
"codespan-reporting",
|
||||
"log",
|
||||
"naga",
|
||||
"parking_lot",
|
||||
"parking_lot 0.12.1",
|
||||
"profiling",
|
||||
"raw-window-handle",
|
||||
"rustc-hash",
|
||||
|
@ -4659,7 +4747,7 @@ dependencies = [
|
|||
"metal",
|
||||
"naga",
|
||||
"objc",
|
||||
"parking_lot",
|
||||
"parking_lot 0.12.1",
|
||||
"profiling",
|
||||
"range-alloc",
|
||||
"raw-window-handle",
|
||||
|
|
|
@ -8,4 +8,5 @@ FROM alpine:latest
|
|||
WORKDIR /app
|
||||
COPY --from=builder /app/target/release/relay-server .
|
||||
EXPOSE 80/tcp
|
||||
VOLUME [ "/data" ]
|
||||
CMD ["./relay-server"]
|
|
@ -14,5 +14,7 @@ workspace = true
|
|||
[dependencies]
|
||||
tungstenite = { version = "0.21.0", features = ["rustls-tls-native-roots"] }
|
||||
mio = { version = "0.8.10", features = ["net", "os-poll"] }
|
||||
uuid = "1.7.0"
|
||||
rand = "0.8.5"
|
||||
home = "0.5.9"
|
||||
log = "0.4.20"
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
//! A library containing a client to use a relay server.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fs;
|
||||
use std::io::{self};
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::mpsc::{channel, Receiver, Sender};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
|
@ -12,6 +14,7 @@ use rand::seq::SliceRandom;
|
|||
use tungstenite::handshake::MidHandshake;
|
||||
use tungstenite::stream::MaybeTlsStream;
|
||||
use tungstenite::{ClientHandshake, HandshakeError, Message, WebSocket};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// The state of a [Connection].
|
||||
#[derive(Debug)]
|
||||
|
@ -28,6 +31,12 @@ enum ConnectionState {
|
|||
/// The websocket handshake is in progress.
|
||||
Handshaking(MidHandshake<ClientHandshake<MaybeTlsStream<TcpStream>>>),
|
||||
|
||||
/// The websocket handshake is finished.
|
||||
Handshaked(WebSocket<MaybeTlsStream<TcpStream>>),
|
||||
|
||||
/// The [Connection] is registering with the relay server.
|
||||
Registering(WebSocket<MaybeTlsStream<TcpStream>>),
|
||||
|
||||
/// The [Connection] is connected.
|
||||
Active(WebSocket<MaybeTlsStream<TcpStream>>),
|
||||
}
|
||||
|
@ -40,6 +49,15 @@ pub struct Connection {
|
|||
/// The domain of the relay server.
|
||||
domain: String,
|
||||
|
||||
/// The path to the file where the identifier and secret key are stored.
|
||||
data_path: PathBuf,
|
||||
|
||||
/// The identifier of the connection for the relay server.
|
||||
identifier: Option<Uuid>,
|
||||
|
||||
/// The secret key used to authenticate with the relay server.
|
||||
secret: Option<Uuid>,
|
||||
|
||||
/// The receiver part of the send channel.
|
||||
///
|
||||
/// This is used in [Connection::update] to get messages that need to
|
||||
|
@ -56,13 +74,13 @@ pub struct Connection {
|
|||
///
|
||||
/// This is used in [Connection::read] to get messages that have been
|
||||
/// received from the relay server.
|
||||
receive_receiver: Receiver<(u32, Vec<u8>)>,
|
||||
receive_receiver: Receiver<(Uuid, Vec<u8>)>,
|
||||
|
||||
/// The sender part of the send channel.
|
||||
///
|
||||
/// This is used in [Connection::update] to store messages that have
|
||||
/// been received from the relay server.
|
||||
receive_sender: Sender<(u32, Vec<u8>)>,
|
||||
receive_sender: Sender<(Uuid, Vec<u8>)>,
|
||||
|
||||
/// The state of the connection.
|
||||
state: ConnectionState,
|
||||
|
@ -72,11 +90,45 @@ impl Connection {
|
|||
/// Create a new [Connection].
|
||||
pub fn new<'a>(domain: impl Into<Cow<'a, str>>) -> io::Result<Self> {
|
||||
let domain = domain.into();
|
||||
|
||||
// Loads the identifier and secret key from disk.
|
||||
let (data_path, identifier, secret) = {
|
||||
// Find the relay data file path.
|
||||
let mut path = home::home_dir().ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::NotFound, "could not find home directory")
|
||||
})?;
|
||||
path.push(".relay-data");
|
||||
|
||||
// Check if the file exists.
|
||||
match path.exists() {
|
||||
true => {
|
||||
// Read the file and parse the identifier and secret key.
|
||||
let contents = fs::read(&path)?;
|
||||
if contents.len() != 32 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"invalid data in .relay-data",
|
||||
));
|
||||
}
|
||||
let identifier = Uuid::from_slice(&contents[..16]).map_err(io::Error::other)?;
|
||||
let secret = Uuid::from_slice(&contents[16..]).map_err(io::Error::other)?;
|
||||
(path, Some(identifier), Some(secret))
|
||||
}
|
||||
false => (path, None, None),
|
||||
}
|
||||
};
|
||||
|
||||
// Create the communication channels.
|
||||
let (send_sender, send_receiver) = channel();
|
||||
let (receive_sender, receive_receiver) = channel();
|
||||
|
||||
// Create the connection and return it.
|
||||
Ok(Self {
|
||||
address_list: (domain.as_ref(), 443).to_socket_addrs()?.collect(),
|
||||
domain: domain.into_owned(),
|
||||
data_path,
|
||||
identifier,
|
||||
secret,
|
||||
send_receiver,
|
||||
send_sender,
|
||||
receive_receiver,
|
||||
|
@ -85,15 +137,20 @@ impl Connection {
|
|||
})
|
||||
}
|
||||
|
||||
/// Get the identifier of the connection.
|
||||
pub const fn identifier(&self) -> Option<Uuid> {
|
||||
self.identifier
|
||||
}
|
||||
|
||||
/// Send a message to the target client.
|
||||
pub fn send(&self, target_id: u32, message: Cow<[u8]>) {
|
||||
let mut data = message.into_owned();
|
||||
data.extend_from_slice(&target_id.to_be_bytes());
|
||||
pub fn send<'a>(&self, target_id: Uuid, message: impl Into<Cow<'a, [u8]>>) {
|
||||
let mut data = message.into().into_owned();
|
||||
data.extend_from_slice(target_id.as_bytes());
|
||||
self.send_sender.send(Message::Binary(data)).ok();
|
||||
}
|
||||
|
||||
/// Receive a message from the target client.
|
||||
pub fn read(&self) -> Option<(u32, Vec<u8>)> {
|
||||
/// Receive a message from the relay connection.
|
||||
pub fn read(&self) -> Option<(Uuid, Vec<u8>)> {
|
||||
self.receive_receiver.try_recv().ok()
|
||||
}
|
||||
|
||||
|
@ -151,7 +208,7 @@ impl Connection {
|
|||
/// Start the websocket handshake.
|
||||
fn start_handshake(&mut self, stream: TcpStream) -> ConnectionState {
|
||||
match tungstenite::client_tls(format!("wss://{}", self.domain), stream) {
|
||||
Ok((socket, _)) => ConnectionState::Active(socket),
|
||||
Ok((socket, _)) => ConnectionState::Handshaked(socket),
|
||||
Err(HandshakeError::Interrupted(handshake)) => ConnectionState::Handshaking(handshake),
|
||||
Err(HandshakeError::Failure(e)) => {
|
||||
warn!("handshake failed with the relay server: {e}");
|
||||
|
@ -166,7 +223,7 @@ impl Connection {
|
|||
handshake: MidHandshake<ClientHandshake<MaybeTlsStream<TcpStream>>>,
|
||||
) -> ConnectionState {
|
||||
match handshake.handshake() {
|
||||
Ok((socket, _)) => ConnectionState::Active(socket),
|
||||
Ok((socket, _)) => ConnectionState::Handshaked(socket),
|
||||
Err(HandshakeError::Interrupted(handshake)) => ConnectionState::Handshaking(handshake),
|
||||
Err(HandshakeError::Failure(e)) => {
|
||||
warn!("handshake failed with the relay server: {e}");
|
||||
|
@ -175,6 +232,77 @@ impl Connection {
|
|||
}
|
||||
}
|
||||
|
||||
/// Start authentication with the relay server.
|
||||
fn start_authentication(
|
||||
&mut self,
|
||||
mut socket: WebSocket<MaybeTlsStream<TcpStream>>,
|
||||
) -> ConnectionState {
|
||||
match (self.identifier, self.secret) {
|
||||
(Some(identifier), Some(secret)) => {
|
||||
// Create the authentication message.
|
||||
let mut data = Vec::with_capacity(32);
|
||||
data.extend(identifier.as_bytes());
|
||||
data.extend(secret.as_bytes());
|
||||
|
||||
// Send the authentication message.
|
||||
match socket.send(Message::Binary(data)) {
|
||||
Ok(()) => ConnectionState::Active(socket),
|
||||
Err(e) => {
|
||||
warn!("failed to send authentication message: {e}");
|
||||
ConnectionState::Disconnected
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Send empty authentication message to request a new identifier and secret key.
|
||||
match socket.send(Message::Binary(vec![])) {
|
||||
Ok(()) => ConnectionState::Registering(socket),
|
||||
Err(e) => {
|
||||
warn!("failed to send registration message: {e}");
|
||||
ConnectionState::Disconnected
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for the registration response.
|
||||
fn get_registration_response(
|
||||
&mut self,
|
||||
mut socket: WebSocket<MaybeTlsStream<TcpStream>>,
|
||||
) -> ConnectionState {
|
||||
match socket.read() {
|
||||
Ok(message) => {
|
||||
// Check the message length.
|
||||
let data = message.into_data();
|
||||
if data.len() != 32 {
|
||||
warn!("received malformed registration response");
|
||||
return ConnectionState::Disconnected;
|
||||
}
|
||||
|
||||
// Extract the client identifier and secret.
|
||||
self.identifier = Some(Uuid::from_slice(&data[..16]).expect("invalid identifier"));
|
||||
self.secret = Some(Uuid::from_slice(&data[16..]).expect("invalid secret"));
|
||||
|
||||
// Save the client identifier and secret.
|
||||
fs::write(&self.data_path, data).ok();
|
||||
|
||||
// Activate the connection.
|
||||
ConnectionState::Active(socket)
|
||||
}
|
||||
Err(tungstenite::Error::Io(ref e))
|
||||
if e.kind() == std::io::ErrorKind::WouldBlock
|
||||
|| e.kind() == std::io::ErrorKind::Interrupted =>
|
||||
{
|
||||
ConnectionState::Registering(socket)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("failed to receive registration response: {e}");
|
||||
ConnectionState::Disconnected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the [Connection] by receiving and sending messages.
|
||||
fn update_connection(
|
||||
&mut self,
|
||||
|
@ -203,18 +331,14 @@ impl Connection {
|
|||
Ok(message) => {
|
||||
// Check the message length.
|
||||
let mut data = message.into_data();
|
||||
if data.len() < 4 {
|
||||
if data.len() < 16 {
|
||||
warn!("received malformed message with length: {}", data.len());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extract the sender ID.
|
||||
let id_start = data.len() - 4;
|
||||
let sender_id = u32::from_be_bytes(
|
||||
data[id_start..]
|
||||
.try_into()
|
||||
.unwrap_or_else(|_| unreachable!()),
|
||||
);
|
||||
let id_start = data.len() - 16;
|
||||
let sender_id = Uuid::from_slice(&data[id_start..]).expect("invalid sender id");
|
||||
data.truncate(id_start);
|
||||
|
||||
// Send the message to the receive channel.
|
||||
|
@ -250,6 +374,8 @@ impl Connection {
|
|||
ConnectionState::Connecting(stream, start) => self.check_connection(stream, start),
|
||||
ConnectionState::Connected(stream) => self.start_handshake(stream),
|
||||
ConnectionState::Handshaking(handshake) => self.continue_handshake(handshake),
|
||||
ConnectionState::Handshaked(socket) => self.start_authentication(socket),
|
||||
ConnectionState::Registering(socket) => self.get_registration_response(socket),
|
||||
ConnectionState::Active(socket) => self.update_connection(socket),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,9 @@ workspace = true
|
|||
[dependencies]
|
||||
tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread"] }
|
||||
axum = { version = "0.7.4", features = ["ws"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
lazy_static = "1.4.0"
|
||||
futures = "0.3.30"
|
||||
dashmap = "5.5.3"
|
||||
rand = "0.8.5"
|
||||
anyhow = "1.0.79"
|
||||
sled = "0.34.7"
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
//! A relay server for bevnet.
|
||||
|
||||
use std::io;
|
||||
|
||||
use anyhow::bail;
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use axum::extract::WebSocketUpgrade;
|
||||
use axum::routing::get;
|
||||
|
@ -7,19 +10,25 @@ use axum::Router;
|
|||
use dashmap::DashMap;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use lazy_static::lazy_static;
|
||||
use rand::Rng;
|
||||
use sled::transaction::{ConflictableTransactionResult, TransactionalTree};
|
||||
use sled::{Db, IVec};
|
||||
use tokio::sync::mpsc::{channel, Receiver, Sender};
|
||||
use tokio::task::JoinHandle;
|
||||
use uuid::Uuid;
|
||||
|
||||
lazy_static! {
|
||||
static ref CLIENTS: DashMap<u32, Sender<Vec<u8>>> = DashMap::new();
|
||||
static ref CLIENTS: DashMap<Uuid, Sender<Vec<u8>>> = DashMap::new();
|
||||
static ref DB: Db = sled::open("/data/secrets.db").expect("unable to open the database");
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let app = Router::new().route(
|
||||
"/",
|
||||
get(|ws: WebSocketUpgrade| async { ws.on_upgrade(handle) }),
|
||||
get(|ws: WebSocketUpgrade| async {
|
||||
ws.on_upgrade(|socket| async {
|
||||
handle(socket).await.ok();
|
||||
})
|
||||
}),
|
||||
);
|
||||
let listener = tokio::net::TcpListener::bind("0.0.0.0:80")
|
||||
.await
|
||||
|
@ -27,80 +36,125 @@ async fn main() {
|
|||
axum::serve(listener, app).await.expect("failed to serve");
|
||||
}
|
||||
|
||||
/// Handle the websocket connection.
|
||||
async fn handle(socket: WebSocket) {
|
||||
// Generate a new ID for the client.
|
||||
let client_id: u32 = loop {
|
||||
let id = rand::thread_rng().gen();
|
||||
if !CLIENTS.contains_key(&id) {
|
||||
/// Create a new client and add it to the database.
|
||||
fn create_client(tx: &TransactionalTree) -> ConflictableTransactionResult<(Uuid, Uuid), io::Error> {
|
||||
// Generates a new identifier for the client.
|
||||
let client_id = loop {
|
||||
// Generates a new random identifier.
|
||||
let id = Uuid::new_v4();
|
||||
|
||||
// Check if the id isn't already in the database.
|
||||
if tx.get(id.as_bytes())?.is_none() {
|
||||
break id;
|
||||
}
|
||||
};
|
||||
println!("Client({}) connected", client_id);
|
||||
|
||||
// Add the client to the list of connected clients.
|
||||
let (sender, receiver) = channel(128);
|
||||
CLIENTS.insert(client_id, sender);
|
||||
// Generate a random secret for the client.
|
||||
let secret = Uuid::new_v4();
|
||||
|
||||
// Handle messages from the client.
|
||||
let result = handle_socket(socket, client_id, receiver).await;
|
||||
// Add the new client to the database.
|
||||
tx.insert(client_id.as_bytes(), secret.as_bytes())?;
|
||||
|
||||
// Remove the client from the list of connected clients.
|
||||
match result {
|
||||
Ok(_) => println!("Client({}) disconnected", client_id),
|
||||
Err(e) => {
|
||||
CLIENTS.remove(&client_id);
|
||||
println!("Client({}) disconnected: {}", client_id, e);
|
||||
}
|
||||
}
|
||||
// Returns the client identifier and his secret.
|
||||
Ok((client_id, secret))
|
||||
}
|
||||
|
||||
/// Error prone part of handling the websocket connection.
|
||||
async fn handle_socket(
|
||||
mut socket: WebSocket,
|
||||
client_id: u32,
|
||||
mut receiver: Receiver<Vec<u8>>,
|
||||
) -> Result<(), axum::Error> {
|
||||
// Send the client ID to the client.
|
||||
socket
|
||||
.send(Message::Binary(client_id.to_be_bytes().to_vec()))
|
||||
.await?;
|
||||
/// Handle the websocket connection.
|
||||
async fn handle(mut socket: WebSocket) -> anyhow::Result<()> {
|
||||
// Receive the first request from the client.
|
||||
let data = match socket.recv().await {
|
||||
Some(Ok(message)) => message.into_data(),
|
||||
_ => return Ok(()),
|
||||
};
|
||||
|
||||
// If the request is empty it means that the client want a new identifier and
|
||||
// secret, so we create them and send them to the client.
|
||||
let client_id = if data.is_empty() {
|
||||
// Generate the new client.
|
||||
let (client_id, secret) = DB.transaction(create_client)?;
|
||||
DB.flush_async().await?;
|
||||
println!("{client_id} created");
|
||||
|
||||
// Send the data to the client.
|
||||
let mut data = Vec::with_capacity(32);
|
||||
data.extend_from_slice(client_id.as_bytes());
|
||||
data.extend_from_slice(secret.as_bytes());
|
||||
socket.send(Message::Binary(data)).await?;
|
||||
|
||||
// Returns the client identifier.
|
||||
client_id
|
||||
}
|
||||
// Otherwise it means that the client want to reuse an identifier, so it will
|
||||
// send it along with his secret to prove that he is the right client.
|
||||
else {
|
||||
// Check for the message length to detect malformed messages.
|
||||
if data.len() != 32 {
|
||||
bail!("malformed message");
|
||||
}
|
||||
|
||||
// Get the client identifier and secret from the message.
|
||||
let client_id = Uuid::from_slice(&data[..16])?;
|
||||
let secret = Uuid::from_slice(&data[16..])?;
|
||||
|
||||
// Check with the database if the secret is correct.
|
||||
if DB.get(client_id.as_bytes())? != Some(IVec::from(secret.as_bytes())) {
|
||||
bail!("invalid secret")
|
||||
}
|
||||
|
||||
// Returns the client identifier.
|
||||
client_id
|
||||
};
|
||||
|
||||
// Handle the client connection.
|
||||
println!("{client_id} connected");
|
||||
let (sender, receiver) = channel(128);
|
||||
CLIENTS.insert(client_id, sender);
|
||||
handle_client(socket, client_id, receiver).await.ok();
|
||||
CLIENTS.remove(&client_id);
|
||||
println!("{client_id} disconnected");
|
||||
|
||||
// Returns success.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle the client connection.
|
||||
async fn handle_client(
|
||||
socket: WebSocket,
|
||||
client_id: Uuid,
|
||||
mut receiver: Receiver<Vec<u8>>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Split the socket into sender and receiver.
|
||||
let (mut writer, mut reader) = socket.split();
|
||||
|
||||
// Handle sending messages to the client.
|
||||
let sending_task: JoinHandle<Result<(), axum::Error>> = tokio::spawn(async move {
|
||||
tokio::spawn(async move {
|
||||
while let Some(message) = receiver.recv().await {
|
||||
writer
|
||||
.send(Message::Binary(message))
|
||||
.await
|
||||
.map_err(axum::Error::new)?;
|
||||
writer.send(Message::Binary(message)).await?;
|
||||
}
|
||||
Ok(())
|
||||
Ok::<(), axum::Error>(())
|
||||
});
|
||||
|
||||
// Handle messages from the client.
|
||||
while let Some(Ok(message)) = reader.next().await {
|
||||
// Get the target ID from the message.
|
||||
let mut data = message.into_data();
|
||||
let id_start = data.len() - 4;
|
||||
let target_id = u32::from_be_bytes(data[id_start..].try_into().map_err(axum::Error::new)?);
|
||||
if data.len() < 16 {
|
||||
bail!("malformed message");
|
||||
}
|
||||
let id_start = data.len() - 16;
|
||||
let target_id = Uuid::from_slice(&data[id_start..])?;
|
||||
|
||||
// Write the sender ID to the message.
|
||||
for (i, byte) in client_id.to_be_bytes().into_iter().enumerate() {
|
||||
for (i, &byte) in client_id.as_bytes().iter().enumerate() {
|
||||
data[id_start + i] = byte;
|
||||
}
|
||||
|
||||
// Send the message to the target client.
|
||||
if let Some(sender) = CLIENTS.get(&target_id) {
|
||||
sender.send(data).await.map_err(axum::Error::new)?;
|
||||
sender.send(data).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the client from the list of connected clients.
|
||||
CLIENTS.remove(&client_id);
|
||||
|
||||
// Wait for the sender to finish.
|
||||
sending_task.await.map_err(axum::Error::new)?
|
||||
// Returns success.
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue