Make the relay connection Send and Sync (#48)
All checks were successful
Rust Checks / checks (push) Successful in 1m34s

Reviewed-on: corentin/border-wars#48
Reviewed-by: Corentin <solois.corentin@gmail.com>
Co-authored-by: Tipragot <contact@tipragot.fr>
Co-committed-by: Tipragot <contact@tipragot.fr>
This commit is contained in:
Tipragot 2024-02-13 13:09:42 +00:00 committed by Corentin
parent e1a191a539
commit 1c35d2d335

View file

@ -1,11 +1,12 @@
//! A library containing a client to use a relay server. //! A library containing a client to use a relay server.
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::LinkedList;
use std::fs; use std::fs;
use std::io::{self}; use std::io::{self};
use std::net::{SocketAddr, ToSocketAddrs}; use std::net::{SocketAddr, ToSocketAddrs};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::mpsc::{channel, Receiver, Sender}; use std::sync::Mutex;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use log::warn; use log::warn;
@ -58,29 +59,8 @@ pub struct Connection {
/// The secret key used to authenticate with the relay server. /// The secret key used to authenticate with the relay server.
secret: Option<Uuid>, secret: Option<Uuid>,
/// The receiver part of the send channel. /// A list of messages that needs to be sent.
/// to_send: Mutex<LinkedList<Message>>,
/// This is used in [Connection::update] to get messages that need to
/// be sent to the relay server.
send_receiver: Receiver<Message>,
/// The sender part of the receive channel.
///
/// This is used in [Connection::send] to store messages that need to
/// be sent to the relay server.
send_sender: Sender<Message>,
/// The receiver part of the receive channel.
///
/// This is used in [Connection::read] to get messages that have been
/// received from the relay server.
receive_receiver: Receiver<(Uuid, Vec<u8>)>,
/// The sender part of the send channel.
///
/// This is used in [Connection::update] to store messages that have
/// been received from the relay server.
receive_sender: Sender<(Uuid, Vec<u8>)>,
/// The state of the connection. /// The state of the connection.
state: ConnectionState, state: ConnectionState,
@ -118,10 +98,6 @@ impl Connection {
} }
}; };
// Create the communication channels.
let (send_sender, send_receiver) = channel();
let (receive_sender, receive_receiver) = channel();
// Create the connection and return it. // Create the connection and return it.
Ok(Self { Ok(Self {
address_list: (domain.as_ref(), 443).to_socket_addrs()?.collect(), address_list: (domain.as_ref(), 443).to_socket_addrs()?.collect(),
@ -129,10 +105,7 @@ impl Connection {
data_path, data_path,
identifier, identifier,
secret, secret,
send_receiver, to_send: Mutex::new(LinkedList::new()),
send_sender,
receive_receiver,
receive_sender,
state: ConnectionState::Disconnected, state: ConnectionState::Disconnected,
}) })
} }
@ -146,12 +119,9 @@ impl Connection {
pub fn send<'a>(&self, target_id: Uuid, message: impl Into<Cow<'a, [u8]>>) { pub fn send<'a>(&self, target_id: Uuid, message: impl Into<Cow<'a, [u8]>>) {
let mut data = message.into().into_owned(); let mut data = message.into().into_owned();
data.extend_from_slice(target_id.as_bytes()); data.extend_from_slice(target_id.as_bytes());
self.send_sender.send(Message::Binary(data)).ok(); if let Ok(mut to_send) = self.to_send.lock() {
} to_send.push_back(Message::binary(data));
}
/// Receive a message from the relay connection.
pub fn read(&self) -> Option<(Uuid, Vec<u8>)> {
self.receive_receiver.try_recv().ok()
} }
/// Create a new [TcpStream] to the relay server. /// Create a new [TcpStream] to the relay server.
@ -307,9 +277,16 @@ impl Connection {
fn update_connection( fn update_connection(
&mut self, &mut self,
mut socket: WebSocket<MaybeTlsStream<TcpStream>>, mut socket: WebSocket<MaybeTlsStream<TcpStream>>,
messages: &mut LinkedList<(Uuid, Vec<u8>)>,
) -> ConnectionState { ) -> ConnectionState {
// Unlock the sending list.
let Ok(mut to_send) = self.to_send.lock() else {
warn!("sending list closed");
return ConnectionState::Disconnected;
};
// Send messages from the send channel to the socket. // Send messages from the send channel to the socket.
while let Ok(message) = self.send_receiver.try_recv() { while let Some(message) = to_send.pop_front() {
match socket.send(message) { match socket.send(message) {
Ok(()) => (), Ok(()) => (),
Err(tungstenite::Error::Io(ref e)) Err(tungstenite::Error::Io(ref e))
@ -341,8 +318,8 @@ impl Connection {
let sender_id = Uuid::from_slice(&data[id_start..]).expect("invalid sender id"); let sender_id = Uuid::from_slice(&data[id_start..]).expect("invalid sender id");
data.truncate(id_start); data.truncate(id_start);
// Send the message to the receive channel. // Add the message to the message list.
self.receive_sender.send((sender_id, data)).ok(); messages.push_back((sender_id, data));
} }
Err(tungstenite::Error::Io(ref e)) Err(tungstenite::Error::Io(ref e))
if e.kind() == std::io::ErrorKind::WouldBlock if e.kind() == std::io::ErrorKind::WouldBlock
@ -361,14 +338,15 @@ impl Connection {
ConnectionState::Active(socket) ConnectionState::Active(socket)
} }
/// Update the [Connection]. /// Update the [Connection] and return the received messages.
/// ///
/// This function will connect to the relay server if it's not already /// This function will connect to the relay server if it's not already
/// connected, and will send and receive messages from the relay server /// connected, and will send and receive messages from the relay server
/// if it's connected. /// if it's connected.
/// ///
/// This function will not block the current thread. /// This function will not block the current thread.
pub fn update(&mut self) { pub fn update(&mut self) -> LinkedList<(Uuid, Vec<u8>)> {
let mut messages = LinkedList::new();
self.state = match std::mem::replace(&mut self.state, ConnectionState::Disconnected) { self.state = match std::mem::replace(&mut self.state, ConnectionState::Disconnected) {
ConnectionState::Disconnected => self.create_stream(), ConnectionState::Disconnected => self.create_stream(),
ConnectionState::Connecting(stream, start) => self.check_connection(stream, start), ConnectionState::Connecting(stream, start) => self.check_connection(stream, start),
@ -376,7 +354,8 @@ impl Connection {
ConnectionState::Handshaking(handshake) => self.continue_handshake(handshake), ConnectionState::Handshaking(handshake) => self.continue_handshake(handshake),
ConnectionState::Handshaked(socket) => self.start_authentication(socket), ConnectionState::Handshaked(socket) => self.start_authentication(socket),
ConnectionState::Registering(socket) => self.get_registration_response(socket), ConnectionState::Registering(socket) => self.get_registration_response(socket),
ConnectionState::Active(socket) => self.update_connection(socket), ConnectionState::Active(socket) => self.update_connection(socket, &mut messages),
} };
messages
} }
} }