From 82a77dbdbbf6a75e212c182182246c577a7802c7 Mon Sep 17 00:00:00 2001 From: Tipragot Date: Mon, 12 Feb 2024 14:31:54 +0000 Subject: [PATCH] Non blocking relay connection (#44) Reviewed-on: https://git.tipragot.fr/corentin/border-wars/pulls/44 Reviewed-by: Corentin Co-authored-by: Tipragot Co-committed-by: Tipragot --- Cargo.lock | 137 ++++++++++++++++++ crates/relay-client/Cargo.toml | 18 +++ crates/relay-client/src/lib.rs | 256 +++++++++++++++++++++++++++++++++ 3 files changed, 411 insertions(+) create mode 100644 crates/relay-client/Cargo.toml create mode 100644 crates/relay-client/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 9e70aac..0198598 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3285,6 +3285,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + [[package]] name = "orbclient" version = "0.3.47" @@ -3605,6 +3611,16 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "relay-client" +version = "0.2.0" +dependencies = [ + "log", + "mio", + "rand", + "tungstenite", +] + [[package]] name = "relay-server" version = "0.2.0" @@ -3623,6 +3639,20 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "216080ab382b992234dda86873c18d4c48358f5cfcb70fd693d7f6f2131b628b" +[[package]] +name = "ring" +version = "0.17.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" +dependencies = [ + "cc", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys 0.48.0", +] + [[package]] name = "rodio" version = "0.17.3" @@ -3657,6 +3687,60 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustls" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" +dependencies = [ + "log", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4" +dependencies = [ + "base64 0.21.7", + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a716eb65e3158e90e17cd93d855216e27bde02745ab842f2cab4a39dba1bacf" + +[[package]] +name = "rustls-webpki" +version = "0.102.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -3689,12 +3773,44 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "2.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "serde" version = "1.0.196" @@ -3826,6 +3942,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "spirv" version = "0.2.0+1.5.4" @@ -4212,6 +4334,9 @@ dependencies = [ "httparse", "log", "rand", + "rustls", + "rustls-native-certs", + "rustls-pki-types", "sha1", "thiserror", "url", @@ -4277,6 +4402,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.0" @@ -4985,3 +5116,9 @@ dependencies = [ "quote", "syn 2.0.48", ] + +[[package]] +name = "zeroize" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/crates/relay-client/Cargo.toml b/crates/relay-client/Cargo.toml new file mode 100644 index 0000000..3df6c3a --- /dev/null +++ b/crates/relay-client/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "relay-client" +version = "0.2.0" +edition = "2021" +license = "GPL-3.0-or-later" +description = "A client to use a relay server." +authors = ["Tipragot "] +keywords = ["bevy", "network", "game"] +categories = ["network-programming", "game-development"] + +[lints] +workspace = true + +[dependencies] +tungstenite = { version = "0.21.0", features = ["rustls-tls-native-roots"] } +mio = { version = "0.8.10", features = ["net", "os-poll"] } +rand = "0.8.5" +log = "0.4.20" diff --git a/crates/relay-client/src/lib.rs b/crates/relay-client/src/lib.rs new file mode 100644 index 0000000..ffb9fb5 --- /dev/null +++ b/crates/relay-client/src/lib.rs @@ -0,0 +1,256 @@ +//! A library containing a client to use a relay server. + +use std::borrow::Cow; +use std::io::{self}; +use std::net::{SocketAddr, ToSocketAddrs}; +use std::sync::mpsc::{channel, Receiver, Sender}; +use std::time::{Duration, Instant}; + +use log::warn; +use mio::net::TcpStream; +use rand::seq::SliceRandom; +use tungstenite::handshake::MidHandshake; +use tungstenite::stream::MaybeTlsStream; +use tungstenite::{ClientHandshake, HandshakeError, Message, WebSocket}; + +/// The state of a [Connection]. +#[derive(Debug)] +enum ConnectionState { + /// The [Connection] is not connected. + Disconnected, + + /// The underlying [TcpStream] is connecting. + Connecting(TcpStream, Instant), + + /// The underlying [TcpStream] is connected. + Connected(TcpStream), + + /// The websocket handshake is in progress. + Handshaking(MidHandshake>>), + + /// The [Connection] is connected. + Active(WebSocket>), +} + +/// A connection to a relay server. +pub struct Connection { + /// The address list corresponding to the relay server. + address_list: Vec, + + /// The domain of the relay server. + domain: String, + + /// The receiver part of the send channel. + /// + /// This is used in [Connection::update] to get messages that need to + /// be sent to the relay server. + send_receiver: Receiver, + + /// The sender part of the receive channel. + /// + /// This is used in [Connection::send] to store messages that need to + /// be sent to the relay server. + send_sender: Sender, + + /// The receiver part of the receive channel. + /// + /// This is used in [Connection::read] to get messages that have been + /// received from the relay server. + receive_receiver: Receiver<(u32, Vec)>, + + /// The sender part of the send channel. + /// + /// This is used in [Connection::update] to store messages that have + /// been received from the relay server. + receive_sender: Sender<(u32, Vec)>, + + /// The state of the connection. + state: ConnectionState, +} + +impl Connection { + /// Create a new [Connection]. + pub fn new<'a>(domain: impl Into>) -> io::Result { + let domain = domain.into(); + let (send_sender, send_receiver) = channel(); + let (receive_sender, receive_receiver) = channel(); + Ok(Self { + address_list: (domain.as_ref(), 443).to_socket_addrs()?.collect(), + domain: domain.into_owned(), + send_receiver, + send_sender, + receive_receiver, + receive_sender, + state: ConnectionState::Disconnected, + }) + } + + /// Send a message to the target client. + pub fn send(&self, target_id: u32, message: Cow<[u8]>) { + let mut data = message.into_owned(); + data.extend_from_slice(&target_id.to_be_bytes()); + self.send_sender.send(Message::Binary(data)).ok(); + } + + /// Receive a message from the target client. + pub fn read(&self) -> Option<(u32, Vec)> { + self.receive_receiver.try_recv().ok() + } + + /// Create a new [TcpStream] to the relay server. + fn create_stream(&mut self) -> ConnectionState { + // Take a random relay address. + let Some(address) = self.address_list.choose(&mut rand::thread_rng()) else { + warn!("no relay address available"); + return ConnectionState::Disconnected; + }; + + // Create the new TCP stream. + match TcpStream::connect(address.to_owned()) { + Ok(stream) => ConnectionState::Connecting(stream, Instant::now()), + Err(e) => { + warn!("failed to start connection to the relay server: {e}"); + ConnectionState::Disconnected + } + } + } + + /// Check if the [TcpStream] of the [Connection] is connected. + fn check_connection(&mut self, stream: TcpStream, start: Instant) -> ConnectionState { + // Check for connection errors. + if let Err(e) = stream.take_error() { + warn!("failed to connect to the relay server: {e}"); + return ConnectionState::Disconnected; + } + + // Check if the stream is connected. + let connected = match stream.peek(&mut [0]) { + Ok(_) => true, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, + Err(ref e) if e.kind() == io::ErrorKind::NotConnected => false, + Err(e) => { + warn!("failed to connect to the relay server: {e}"); + return ConnectionState::Disconnected; + } + }; + + // Check if the connection has timed out. + let elapsed = start.elapsed(); + if elapsed > Duration::from_secs(5) { + warn!("connection to the relay server timed out"); + return ConnectionState::Disconnected; + } + + // Update the connection state if connected. + match connected { + true => ConnectionState::Connected(stream), + false => ConnectionState::Connecting(stream, start), + } + } + + /// Start the websocket handshake. + fn start_handshake(&mut self, stream: TcpStream) -> ConnectionState { + match tungstenite::client_tls(format!("wss://{}", self.domain), stream) { + Ok((socket, _)) => ConnectionState::Active(socket), + Err(HandshakeError::Interrupted(handshake)) => ConnectionState::Handshaking(handshake), + Err(HandshakeError::Failure(e)) => { + warn!("handshake failed with the relay server: {e}"); + ConnectionState::Disconnected + } + } + } + + /// Continue the websocket handshake. + fn continue_handshake( + &mut self, + handshake: MidHandshake>>, + ) -> ConnectionState { + match handshake.handshake() { + Ok((socket, _)) => ConnectionState::Active(socket), + Err(HandshakeError::Interrupted(handshake)) => ConnectionState::Handshaking(handshake), + Err(HandshakeError::Failure(e)) => { + warn!("handshake failed with the relay server: {e}"); + ConnectionState::Disconnected + } + } + } + + /// Update the [Connection] by receiving and sending messages. + fn update_connection( + &mut self, + mut socket: WebSocket>, + ) -> ConnectionState { + // Send messages from the send channel to the socket. + while let Ok(message) = self.send_receiver.try_recv() { + match socket.send(message) { + Ok(()) => (), + Err(tungstenite::Error::Io(ref e)) + if e.kind() == std::io::ErrorKind::WouldBlock + || e.kind() == std::io::ErrorKind::Interrupted => + { + break; + } + Err(e) => { + warn!("relay connection closed: {e}"); + return ConnectionState::Disconnected; + } + } + } + + // Receive messages from the socket and send them to the receive channel. + loop { + match socket.read() { + Ok(message) => { + // Check the message length. + let mut data = message.into_data(); + if data.len() < 4 { + warn!("received malformed message with length: {}", data.len()); + continue; + } + + // Extract the sender ID. + let id_start = data.len() - 4; + let sender_id = u32::from_be_bytes( + data[id_start..] + .try_into() + .unwrap_or_else(|_| unreachable!()), + ); + data.truncate(id_start); + + // Send the message to the receive channel. + self.receive_sender.send((sender_id, data)).ok(); + } + Err(tungstenite::Error::Io(ref e)) + if e.kind() == std::io::ErrorKind::WouldBlock + || e.kind() == std::io::ErrorKind::Interrupted => + { + break; + } + Err(e) => { + warn!("relay connection closed: {e}"); + return ConnectionState::Disconnected; + } + } + } + + // Keep the connection connected. + ConnectionState::Active(socket) + } + + /// Update the [Connection]. + /// + /// This function will connect to the relay server if it's not already + /// connected, and will send and receive messages from the relay server + /// if it's connected. + /// + /// This function will not block the current thread. + pub fn update(&mut self) { + self.state = match std::mem::replace(&mut self.state, ConnectionState::Disconnected) { + ConnectionState::Disconnected => self.create_stream(), + ConnectionState::Connecting(stream, start) => self.check_connection(stream, start), + ConnectionState::Connected(stream) => self.start_handshake(stream), + ConnectionState::Handshaking(handshake) => self.continue_handshake(handshake), + ConnectionState::Active(socket) => self.update_connection(socket), + } + } +}