Reconnection system using secret for the relay server #46

Merged
CoCo_Sol merged 7 commits from reconnect into main 2024-02-12 22:41:02 +00:00
6 changed files with 349 additions and 76 deletions

108
Cargo.lock generated
View file

@ -210,6 +210,12 @@ dependencies = [
"libc",
]
[[package]]
name = "anyhow"
version = "1.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
[[package]]
name = "approx"
version = "0.5.1"
@ -232,7 +238,7 @@ dependencies = [
"objc",
"objc-foundation",
"objc_id",
"parking_lot",
"parking_lot 0.12.1",
"thiserror",
"winapi",
"x11rb",
@ -554,7 +560,7 @@ dependencies = [
"futures-io",
"futures-lite 1.13.0",
"js-sys",
"parking_lot",
"parking_lot 0.12.1",
"ron",
"serde",
"thiserror",
@ -1611,7 +1617,7 @@ dependencies = [
"ndk-context",
"oboe",
"once_cell",
"parking_lot",
"parking_lot 0.12.1",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
@ -1645,6 +1651,15 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.19"
@ -1692,7 +1707,7 @@ dependencies = [
"hashbrown 0.14.3",
"lock_api",
"once_cell",
"parking_lot_core",
"parking_lot_core 0.9.9",
]
[[package]]
@ -1808,7 +1823,7 @@ dependencies = [
"ecolor",
"emath",
"nohash-hasher",
"parking_lot",
"parking_lot 0.12.1",
]
[[package]]
@ -1990,6 +2005,16 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fs2"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213"
dependencies = [
"libc",
"winapi",
]
[[package]]
name = "futures"
version = "0.3.30"
@ -2107,6 +2132,15 @@ dependencies = [
"slab",
]
[[package]]
name = "fxhash"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
dependencies = [
"byteorder",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@ -3321,6 +3355,17 @@ version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
[[package]]
name = "parking_lot"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99"
dependencies = [
"instant",
"lock_api",
"parking_lot_core 0.8.6",
]
[[package]]
name = "parking_lot"
version = "0.12.1"
@ -3328,7 +3373,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
dependencies = [
"lock_api",
"parking_lot_core",
"parking_lot_core 0.9.9",
]
[[package]]
name = "parking_lot_core"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc"
dependencies = [
"cfg-if",
"instant",
"libc",
"redox_syscall 0.2.16",
"smallvec",
"winapi",
]
[[package]]
@ -3543,6 +3602,15 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0d463f2884048e7153449a55166f91028d5b0ea53c79377099ce4e8cf0cf9bb"
[[package]]
name = "redox_syscall"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "redox_syscall"
version = "0.3.5"
@ -3615,22 +3683,26 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
name = "relay-client"
version = "0.2.0"
dependencies = [
"home",
"log",
"mio",
"rand",
"tungstenite",
"uuid",
]
[[package]]
name = "relay-server"
version = "0.2.0"
dependencies = [
"anyhow",
"axum",
"dashmap",
"futures",
"lazy_static",
"rand",
"sled",
"tokio",
"uuid",
]
[[package]]
@ -3905,6 +3977,22 @@ dependencies = [
"autocfg",
]
[[package]]
name = "sled"
version = "0.34.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f96b4737c2ce5987354855aed3797279def4ebf734436c6aa4552cf8e169935"
dependencies = [
"crc32fast",
"crossbeam-epoch",
"crossbeam-utils",
"fs2",
"fxhash",
"libc",
"log",
"parking_lot 0.11.2",
]
[[package]]
name = "slotmap"
version = "1.0.7"
@ -4596,7 +4684,7 @@ dependencies = [
"js-sys",
"log",
"naga",
"parking_lot",
"parking_lot 0.12.1",
"profiling",
"raw-window-handle",
"smallvec",
@ -4621,7 +4709,7 @@ dependencies = [
"codespan-reporting",
"log",
"naga",
"parking_lot",
"parking_lot 0.12.1",
"profiling",
"raw-window-handle",
"rustc-hash",
@ -4659,7 +4747,7 @@ dependencies = [
"metal",
"naga",
"objc",
"parking_lot",
"parking_lot 0.12.1",
"profiling",
"range-alloc",
"raw-window-handle",

View file

@ -8,4 +8,5 @@ FROM alpine:latest
WORKDIR /app
COPY --from=builder /app/target/release/relay-server .
EXPOSE 80/tcp
VOLUME [ "/data" ]
CMD ["./relay-server"]

View file

@ -14,5 +14,7 @@ workspace = true
[dependencies]
tungstenite = { version = "0.21.0", features = ["rustls-tls-native-roots"] }
mio = { version = "0.8.10", features = ["net", "os-poll"] }
uuid = "1.7.0"
rand = "0.8.5"
home = "0.5.9"
log = "0.4.20"

View file

@ -1,8 +1,10 @@
//! A library containing a client to use a relay server.
use std::borrow::Cow;
use std::fs;
use std::io::{self};
use std::net::{SocketAddr, ToSocketAddrs};
use std::path::PathBuf;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::time::{Duration, Instant};
@ -12,6 +14,7 @@ use rand::seq::SliceRandom;
use tungstenite::handshake::MidHandshake;
use tungstenite::stream::MaybeTlsStream;
use tungstenite::{ClientHandshake, HandshakeError, Message, WebSocket};
use uuid::Uuid;
/// The state of a [Connection].
#[derive(Debug)]
@ -28,6 +31,12 @@ enum ConnectionState {
/// The websocket handshake is in progress.
Handshaking(MidHandshake<ClientHandshake<MaybeTlsStream<TcpStream>>>),
/// The websocket handshake is finished.
Handshaked(WebSocket<MaybeTlsStream<TcpStream>>),
/// The [Connection] is registering with the relay server.
Registering(WebSocket<MaybeTlsStream<TcpStream>>),
/// The [Connection] is connected.
Active(WebSocket<MaybeTlsStream<TcpStream>>),
}
@ -40,6 +49,15 @@ pub struct Connection {
/// The domain of the relay server.
domain: String,
/// The path to the file where the identifier and secret key are stored.
data_path: PathBuf,
/// The identifier of the connection for the relay server.
identifier: Option<Uuid>,
/// The secret key used to authenticate with the relay server.
secret: Option<Uuid>,
/// The receiver part of the send channel.
///
/// This is used in [Connection::update] to get messages that need to
@ -56,13 +74,13 @@ pub struct Connection {
///
/// This is used in [Connection::read] to get messages that have been
/// received from the relay server.
receive_receiver: Receiver<(u32, Vec<u8>)>,
receive_receiver: Receiver<(Uuid, Vec<u8>)>,
/// The sender part of the send channel.
///
/// This is used in [Connection::update] to store messages that have
/// been received from the relay server.
receive_sender: Sender<(u32, Vec<u8>)>,
receive_sender: Sender<(Uuid, Vec<u8>)>,
/// The state of the connection.
state: ConnectionState,
@ -72,11 +90,45 @@ impl Connection {
/// Create a new [Connection].
pub fn new<'a>(domain: impl Into<Cow<'a, str>>) -> io::Result<Self> {
let domain = domain.into();
// Loads the identifier and secret key from disk.
let (data_path, identifier, secret) = {
// Find the relay data file path.
let mut path = home::home_dir().ok_or_else(|| {
io::Error::new(io::ErrorKind::NotFound, "could not find home directory")
})?;
path.push(".relay-data");
// Check if the file exists.
match path.exists() {
true => {
// Read the file and parse the identifier and secret key.
let contents = fs::read(&path)?;
if contents.len() != 32 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid data in .relay-data",
));
}
let identifier = Uuid::from_slice(&contents[..16]).map_err(io::Error::other)?;
let secret = Uuid::from_slice(&contents[16..]).map_err(io::Error::other)?;
(path, Some(identifier), Some(secret))
}
false => (path, None, None),
}
};
// Create the communication channels.
let (send_sender, send_receiver) = channel();
let (receive_sender, receive_receiver) = channel();
// Create the connection and return it.
Ok(Self {
address_list: (domain.as_ref(), 443).to_socket_addrs()?.collect(),
domain: domain.into_owned(),
data_path,
identifier,
secret,
send_receiver,
send_sender,
receive_receiver,
@ -85,15 +137,20 @@ impl Connection {
})
}
/// Get the identifier of the connection.
pub const fn identifier(&self) -> Option<Uuid> {
self.identifier
}
/// Send a message to the target client.
pub fn send(&self, target_id: u32, message: Cow<[u8]>) {
let mut data = message.into_owned();
data.extend_from_slice(&target_id.to_be_bytes());
pub fn send<'a>(&self, target_id: Uuid, message: impl Into<Cow<'a, [u8]>>) {
let mut data = message.into().into_owned();
data.extend_from_slice(target_id.as_bytes());
self.send_sender.send(Message::Binary(data)).ok();
}
/// Receive a message from the target client.
pub fn read(&self) -> Option<(u32, Vec<u8>)> {
/// Receive a message from the relay connection.
pub fn read(&self) -> Option<(Uuid, Vec<u8>)> {
self.receive_receiver.try_recv().ok()
}
@ -151,7 +208,7 @@ impl Connection {
/// Start the websocket handshake.
fn start_handshake(&mut self, stream: TcpStream) -> ConnectionState {
match tungstenite::client_tls(format!("wss://{}", self.domain), stream) {
Ok((socket, _)) => ConnectionState::Active(socket),
Ok((socket, _)) => ConnectionState::Handshaked(socket),
Err(HandshakeError::Interrupted(handshake)) => ConnectionState::Handshaking(handshake),
Err(HandshakeError::Failure(e)) => {
warn!("handshake failed with the relay server: {e}");
@ -166,7 +223,7 @@ impl Connection {
handshake: MidHandshake<ClientHandshake<MaybeTlsStream<TcpStream>>>,
) -> ConnectionState {
match handshake.handshake() {
Ok((socket, _)) => ConnectionState::Active(socket),
Ok((socket, _)) => ConnectionState::Handshaked(socket),
Err(HandshakeError::Interrupted(handshake)) => ConnectionState::Handshaking(handshake),
Err(HandshakeError::Failure(e)) => {
warn!("handshake failed with the relay server: {e}");
@ -175,6 +232,77 @@ impl Connection {
}
}
/// Start authentication with the relay server.
fn start_authentication(
&mut self,
mut socket: WebSocket<MaybeTlsStream<TcpStream>>,
) -> ConnectionState {
match (self.identifier, self.secret) {
(Some(identifier), Some(secret)) => {
// Create the authentication message.
let mut data = Vec::with_capacity(32);
data.extend(identifier.as_bytes());
data.extend(secret.as_bytes());
// Send the authentication message.
match socket.send(Message::Binary(data)) {
Ok(()) => ConnectionState::Active(socket),
Err(e) => {
warn!("failed to send authentication message: {e}");
ConnectionState::Disconnected
}
}
}
_ => {
// Send empty authentication message to request a new identifier and secret key.
match socket.send(Message::Binary(vec![])) {
Ok(()) => ConnectionState::Registering(socket),
Err(e) => {
warn!("failed to send registration message: {e}");
ConnectionState::Disconnected
}
}
}
}
}
/// Wait for the registration response.
fn get_registration_response(
&mut self,
mut socket: WebSocket<MaybeTlsStream<TcpStream>>,
) -> ConnectionState {
match socket.read() {
Ok(message) => {
// Check the message length.
let data = message.into_data();
if data.len() != 32 {
warn!("received malformed registration response");
return ConnectionState::Disconnected;
}
// Extract the client identifier and secret.
self.identifier = Some(Uuid::from_slice(&data[..16]).expect("invalid identifier"));
self.secret = Some(Uuid::from_slice(&data[16..]).expect("invalid secret"));
// Save the client identifier and secret.
fs::write(&self.data_path, data).ok();
// Activate the connection.
ConnectionState::Active(socket)
}
Err(tungstenite::Error::Io(ref e))
if e.kind() == std::io::ErrorKind::WouldBlock
|| e.kind() == std::io::ErrorKind::Interrupted =>
{
ConnectionState::Registering(socket)
}
Err(e) => {
warn!("failed to receive registration response: {e}");
ConnectionState::Disconnected
}
}
}
/// Update the [Connection] by receiving and sending messages.
fn update_connection(
&mut self,
@ -203,18 +331,14 @@ impl Connection {
Ok(message) => {
// Check the message length.
let mut data = message.into_data();
if data.len() < 4 {
if data.len() < 16 {
warn!("received malformed message with length: {}", data.len());
continue;
}
// Extract the sender ID.
let id_start = data.len() - 4;
let sender_id = u32::from_be_bytes(
data[id_start..]
.try_into()
.unwrap_or_else(|_| unreachable!()),
);
let id_start = data.len() - 16;
let sender_id = Uuid::from_slice(&data[id_start..]).expect("invalid sender id");
data.truncate(id_start);
// Send the message to the receive channel.
@ -250,6 +374,8 @@ impl Connection {
ConnectionState::Connecting(stream, start) => self.check_connection(stream, start),
ConnectionState::Connected(stream) => self.start_handshake(stream),
ConnectionState::Handshaking(handshake) => self.continue_handshake(handshake),
ConnectionState::Handshaked(socket) => self.start_authentication(socket),
ConnectionState::Registering(socket) => self.get_registration_response(socket),
ConnectionState::Active(socket) => self.update_connection(socket),
}
}

View file

@ -14,7 +14,9 @@ workspace = true
[dependencies]
tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread"] }
axum = { version = "0.7.4", features = ["ws"] }
uuid = { version = "1.7.0", features = ["v4"] }
lazy_static = "1.4.0"
futures = "0.3.30"
dashmap = "5.5.3"
rand = "0.8.5"
anyhow = "1.0.79"
sled = "0.34.7"

View file

@ -1,5 +1,8 @@
//! A relay server for bevnet.
use std::io;
use anyhow::bail;
use axum::extract::ws::{Message, WebSocket};
use axum::extract::WebSocketUpgrade;
use axum::routing::get;
@ -7,19 +10,25 @@ use axum::Router;
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use lazy_static::lazy_static;
use rand::Rng;
use sled::transaction::{ConflictableTransactionResult, TransactionalTree};
use sled::{Db, IVec};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task::JoinHandle;
use uuid::Uuid;
lazy_static! {
static ref CLIENTS: DashMap<u32, Sender<Vec<u8>>> = DashMap::new();
static ref CLIENTS: DashMap<Uuid, Sender<Vec<u8>>> = DashMap::new();
static ref DB: Db = sled::open("/data/secrets.db").expect("unable to open the database");
}
#[tokio::main]
async fn main() {
let app = Router::new().route(
"/",
get(|ws: WebSocketUpgrade| async { ws.on_upgrade(handle) }),
get(|ws: WebSocketUpgrade| async {
ws.on_upgrade(|socket| async {
handle(socket).await.ok();
})
}),
);
let listener = tokio::net::TcpListener::bind("0.0.0.0:80")
.await
@ -27,80 +36,125 @@ async fn main() {
axum::serve(listener, app).await.expect("failed to serve");
}
/// Handle the websocket connection.
async fn handle(socket: WebSocket) {
// Generate a new ID for the client.
let client_id: u32 = loop {
let id = rand::thread_rng().gen();
if !CLIENTS.contains_key(&id) {
/// Create a new client and add it to the database.
fn create_client(tx: &TransactionalTree) -> ConflictableTransactionResult<(Uuid, Uuid), io::Error> {
// Generates a new identifier for the client.
let client_id = loop {
// Generates a new random identifier.
let id = Uuid::new_v4();
// Check if the id isn't already in the database.
if tx.get(id.as_bytes())?.is_none() {
break id;
}
};
println!("Client({}) connected", client_id);
// Add the client to the list of connected clients.
let (sender, receiver) = channel(128);
CLIENTS.insert(client_id, sender);
// Generate a random secret for the client.
let secret = Uuid::new_v4();
// Handle messages from the client.
let result = handle_socket(socket, client_id, receiver).await;
// Add the new client to the database.
tx.insert(client_id.as_bytes(), secret.as_bytes())?;
// Remove the client from the list of connected clients.
match result {
Ok(_) => println!("Client({}) disconnected", client_id),
Err(e) => {
CLIENTS.remove(&client_id);
println!("Client({}) disconnected: {}", client_id, e);
}
}
// Returns the client identifier and his secret.
Ok((client_id, secret))
}
/// Error prone part of handling the websocket connection.
async fn handle_socket(
mut socket: WebSocket,
client_id: u32,
mut receiver: Receiver<Vec<u8>>,
) -> Result<(), axum::Error> {
// Send the client ID to the client.
socket
.send(Message::Binary(client_id.to_be_bytes().to_vec()))
.await?;
/// Handle the websocket connection.
async fn handle(mut socket: WebSocket) -> anyhow::Result<()> {
// Receive the first request from the client.
let data = match socket.recv().await {
Some(Ok(message)) => message.into_data(),
_ => return Ok(()),
};
// If the request is empty it means that the client want a new identifier and
// secret, so we create them and send them to the client.
let client_id = if data.is_empty() {
// Generate the new client.
let (client_id, secret) = DB.transaction(create_client)?;
DB.flush_async().await?;
println!("{client_id} created");
// Send the data to the client.
let mut data = Vec::with_capacity(32);
data.extend_from_slice(client_id.as_bytes());
data.extend_from_slice(secret.as_bytes());
socket.send(Message::Binary(data)).await?;
// Returns the client identifier.
client_id
}
// Otherwise it means that the client want to reuse an identifier, so it will
// send it along with his secret to prove that he is the right client.
else {
// Check for the message length to detect malformed messages.
if data.len() != 32 {
bail!("malformed message");
}
// Get the client identifier and secret from the message.
let client_id = Uuid::from_slice(&data[..16])?;
let secret = Uuid::from_slice(&data[16..])?;
// Check with the database if the secret is correct.
if DB.get(client_id.as_bytes())? != Some(IVec::from(secret.as_bytes())) {
bail!("invalid secret")
}
// Returns the client identifier.
client_id
};
// Handle the client connection.
println!("{client_id} connected");
let (sender, receiver) = channel(128);
CLIENTS.insert(client_id, sender);
handle_client(socket, client_id, receiver).await.ok();
CLIENTS.remove(&client_id);
println!("{client_id} disconnected");
// Returns success.
Ok(())
}
/// Handle the client connection.
async fn handle_client(
socket: WebSocket,
client_id: Uuid,
mut receiver: Receiver<Vec<u8>>,
) -> anyhow::Result<()> {
// Split the socket into sender and receiver.
let (mut writer, mut reader) = socket.split();
// Handle sending messages to the client.
let sending_task: JoinHandle<Result<(), axum::Error>> = tokio::spawn(async move {
tokio::spawn(async move {
while let Some(message) = receiver.recv().await {
writer
.send(Message::Binary(message))
.await
.map_err(axum::Error::new)?;
writer.send(Message::Binary(message)).await?;
}
Ok(())
Ok::<(), axum::Error>(())
});
// Handle messages from the client.
while let Some(Ok(message)) = reader.next().await {
// Get the target ID from the message.
let mut data = message.into_data();
let id_start = data.len() - 4;
let target_id = u32::from_be_bytes(data[id_start..].try_into().map_err(axum::Error::new)?);
if data.len() < 16 {
bail!("malformed message");
}
let id_start = data.len() - 16;
let target_id = Uuid::from_slice(&data[id_start..])?;
// Write the sender ID to the message.
for (i, byte) in client_id.to_be_bytes().into_iter().enumerate() {
for (i, &byte) in client_id.as_bytes().iter().enumerate() {
data[id_start + i] = byte;
}
// Send the message to the target client.
if let Some(sender) = CLIENTS.get(&target_id) {
sender.send(data).await.map_err(axum::Error::new)?;
sender.send(data).await?;
}
}
// Remove the client from the list of connected clients.
CLIENTS.remove(&client_id);
// Wait for the sender to finish.
sending_task.await.map_err(axum::Error::new)?
// Returns success.
Ok(())
}