diff --git a/crates/bevnet/src/lib.rs b/crates/bevnet/src/lib.rs index 7f79ceb..c3f05a6 100644 --- a/crates/bevnet/src/lib.rs +++ b/crates/bevnet/src/lib.rs @@ -45,6 +45,7 @@ use std::collections::{HashMap, LinkedList}; use std::io::{self, Read, Write}; use std::net::{IpAddr, Ipv4Addr, SocketAddrV4, TcpListener, TcpStream}; +use std::sync::Mutex; use aes_gcm::aead::{Aead, AeadCore, KeyInit, OsRng}; use aes_gcm::{Aes128Gcm, Key, Nonce}; @@ -59,7 +60,10 @@ pub struct Connection { stream: TcpStream, /// Contains the buffers that are not yet being sent. - send_buffers: LinkedList<(usize, Vec)>, + send_buffers: Mutex)>>, + + /// Contains all the received messages associated with their ids. + received_messages: Mutex>>>, /// The length of the next message to be received. /// @@ -88,7 +92,8 @@ impl Connection { stream.set_nonblocking(true)?; Ok(Self { stream, - send_buffers: LinkedList::new(), + send_buffers: Mutex::new(LinkedList::new()), + received_messages: Mutex::new(HashMap::new()), receive_message_len: None, receive_message_nonce: None, receive_filled_len: 0, @@ -125,21 +130,26 @@ impl Connection { ) } - /// Sends a message over the connection. - /// - /// Returns `true` if the message has been sent directly and `false` - /// if the message is still in the send queue. + /// Sends a message and its id over the connection. /// /// This function is not blocking. - pub fn send(&mut self, message: &[u8]) -> io::Result { + pub fn send(&self, message: &[u8], id: u16) -> io::Result<()> { + // Add the id to the message. + let mut message_with_id = Vec::with_capacity(message.len() + 2); + message_with_id.extend_from_slice(message); + message_with_id.extend_from_slice(&id.to_ne_bytes()); + // Encrypt the message. let nonce = Aes128Gcm::generate_nonce(OsRng); - let message = self.cipher.encrypt(&nonce, message).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidData, - format!("failed to encrypt message: {}", e), - ) - })?; + let message = self + .cipher + .encrypt(&nonce, message_with_id.as_slice()) + .map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to encrypt message: {}", e), + ) + })?; // Get the length of the message as a u16. let message_len: u16 = match message.len().try_into() { @@ -152,26 +162,37 @@ impl Connection { } }; - // Add a new buffer to the send queue. - self.send_buffers - .push_back((0, message_len.to_ne_bytes().to_vec())); - self.send_buffers.push_back((0, nonce.to_vec())); - self.send_buffers.push_back((0, message)); + // Lock the send buffers. + let mut send_buffers = self + .send_buffers + .lock() + .map_err(|e| io::Error::other(format!("failed to lock send buffers: {}", e)))?; - // Update the connection. - self.update() + // Add a new buffer to the send queue. + send_buffers.push_back((0, message_len.to_ne_bytes().to_vec())); + send_buffers.push_back((0, nonce.to_vec())); + send_buffers.push_back((0, message)); + drop(send_buffers); + + // Returning success. + Ok(()) } /// Updates the connection. /// - /// This function sends any pending messages that have not been sent yet. - /// It returns `true` if there is no remaining data to send after updating - /// the connection and `false` otherwise. + /// This function sends any pending messages that have not been sent yet and + /// receives any pending messages that have not been received yet. /// /// This function is not blocking. - pub fn update(&mut self) -> io::Result { - // Looping over the send buffers. - while let Some((offset, buffer)) = self.send_buffers.front_mut() { + pub fn update(&mut self) -> io::Result<()> { + // Lock the send buffers. + let send_buffers = self + .send_buffers + .get_mut() + .map_err(|e| io::Error::other(format!("failed to lock send buffers: {}", e)))?; + + // Looping over the send buffers to send the messages. + while let Some((offset, buffer)) = send_buffers.front_mut() { // Writing the buffer to the stream. match self.stream.write(&buffer[*offset..]) { Ok(n) => *offset += n, @@ -182,10 +203,33 @@ impl Connection { // Removing the buffer if it is fully sent. if *offset >= buffer.len() { - self.send_buffers.pop_front(); + send_buffers.pop_front(); } } + // Receiving as many messages as possible. + while let Some(mut message) = self.receive_any()? { + // Reading the message id. + let id = u16::from_ne_bytes(message[message.len() - 2..].try_into().map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to read message id: {}", e), + ) + })?); + + // Removing the id from the message. + message.pop(); + message.pop(); + + // Adding the message to the receive queue. + self.received_messages + .get_mut() + .map_err(|e| io::Error::other(format!("failed to lock receive queue: {}", e)))? + .entry(id) + .or_default() + .push_back(message); + } + // Returning success. Ok(self.send_buffers.is_empty()) } @@ -241,10 +285,12 @@ impl Connection { /// Receives a message from the connection. /// + /// The message should contain the message id at the end of the message. + /// /// If no message is available, returns `None`. /// /// This function is not blocking. - pub fn receive(&mut self) -> io::Result>> { + fn receive_any(&mut self) -> io::Result>> { // Receiving the message length. let message_len = match self.receive_message_len { Some(message_len) => message_len, @@ -306,6 +352,26 @@ impl Connection { // Returning the message. Ok(Some(message)) } + + /// Receives all the messages with the given id from the connection. + /// + /// This function is not blocking. + pub fn receive(&self, id: u16) -> io::Result>> { + // Locking the received messages. + let mut received_messages = self.received_messages.lock().map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("failed to lock received messages: {}", e), + ) + })?; + + // Getting the messages. + let messages = received_messages.remove(&id); + drop(received_messages); + + // Returning the received messages. + messages.map_or_else(|| Ok(LinkedList::new()), Ok) + } } /// A non-blocking tcp listener.