Add id to messages and make send and receive non mutable
Some checks failed
Rust Checks / checks (push) Failing after 1m4s

This commit is contained in:
Tipragot 2024-02-10 13:05:15 +01:00
parent 33b11b71c3
commit 9f9ac40a13

View file

@ -45,6 +45,7 @@
use std::collections::{HashMap, LinkedList}; use std::collections::{HashMap, LinkedList};
use std::io::{self, Read, Write}; use std::io::{self, Read, Write};
use std::net::{IpAddr, Ipv4Addr, SocketAddrV4, TcpListener, TcpStream}; use std::net::{IpAddr, Ipv4Addr, SocketAddrV4, TcpListener, TcpStream};
use std::sync::Mutex;
use aes_gcm::aead::{Aead, AeadCore, KeyInit, OsRng}; use aes_gcm::aead::{Aead, AeadCore, KeyInit, OsRng};
use aes_gcm::{Aes128Gcm, Key, Nonce}; use aes_gcm::{Aes128Gcm, Key, Nonce};
@ -59,7 +60,10 @@ pub struct Connection {
stream: TcpStream, stream: TcpStream,
/// Contains the buffers that are not yet being sent. /// Contains the buffers that are not yet being sent.
send_buffers: LinkedList<(usize, Vec<u8>)>, send_buffers: Mutex<LinkedList<(usize, Vec<u8>)>>,
/// Contains all the received messages associated with their ids.
received_messages: Mutex<HashMap<u16, LinkedList<Vec<u8>>>>,
/// The length of the next message to be received. /// The length of the next message to be received.
/// ///
@ -88,7 +92,8 @@ impl Connection {
stream.set_nonblocking(true)?; stream.set_nonblocking(true)?;
Ok(Self { Ok(Self {
stream, stream,
send_buffers: LinkedList::new(), send_buffers: Mutex::new(LinkedList::new()),
received_messages: Mutex::new(HashMap::new()),
receive_message_len: None, receive_message_len: None,
receive_message_nonce: None, receive_message_nonce: None,
receive_filled_len: 0, receive_filled_len: 0,
@ -125,16 +130,21 @@ impl Connection {
) )
} }
/// Sends a message over the connection. /// Sends a message and its id over the connection.
///
/// Returns `true` if the message has been sent directly and `false`
/// if the message is still in the send queue.
/// ///
/// This function is not blocking. /// This function is not blocking.
pub fn send(&mut self, message: &[u8]) -> io::Result<bool> { pub fn send(&self, message: &[u8], id: u16) -> io::Result<()> {
// Add the id to the message.
let mut message_with_id = Vec::with_capacity(message.len() + 2);
message_with_id.extend_from_slice(message);
message_with_id.extend_from_slice(&id.to_ne_bytes());
// Encrypt the message. // Encrypt the message.
let nonce = Aes128Gcm::generate_nonce(OsRng); let nonce = Aes128Gcm::generate_nonce(OsRng);
let message = self.cipher.encrypt(&nonce, message).map_err(|e| { let message = self
.cipher
.encrypt(&nonce, message_with_id.as_slice())
.map_err(|e| {
io::Error::new( io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
format!("failed to encrypt message: {}", e), format!("failed to encrypt message: {}", e),
@ -152,26 +162,37 @@ impl Connection {
} }
}; };
// Add a new buffer to the send queue. // Lock the send buffers.
self.send_buffers let mut send_buffers = self
.push_back((0, message_len.to_ne_bytes().to_vec())); .send_buffers
self.send_buffers.push_back((0, nonce.to_vec())); .lock()
self.send_buffers.push_back((0, message)); .map_err(|e| io::Error::other(format!("failed to lock send buffers: {}", e)))?;
// Update the connection. // Add a new buffer to the send queue.
self.update() send_buffers.push_back((0, message_len.to_ne_bytes().to_vec()));
send_buffers.push_back((0, nonce.to_vec()));
send_buffers.push_back((0, message));
drop(send_buffers);
// Returning success.
Ok(())
} }
/// Updates the connection. /// Updates the connection.
/// ///
/// This function sends any pending messages that have not been sent yet. /// This function sends any pending messages that have not been sent yet and
/// It returns `true` if there is no remaining data to send after updating /// receives any pending messages that have not been received yet.
/// the connection and `false` otherwise.
/// ///
/// This function is not blocking. /// This function is not blocking.
pub fn update(&mut self) -> io::Result<bool> { pub fn update(&mut self) -> io::Result<()> {
// Looping over the send buffers. // Lock the send buffers.
while let Some((offset, buffer)) = self.send_buffers.front_mut() { let send_buffers = self
.send_buffers
.get_mut()
.map_err(|e| io::Error::other(format!("failed to lock send buffers: {}", e)))?;
// Looping over the send buffers to send the messages.
while let Some((offset, buffer)) = send_buffers.front_mut() {
// Writing the buffer to the stream. // Writing the buffer to the stream.
match self.stream.write(&buffer[*offset..]) { match self.stream.write(&buffer[*offset..]) {
Ok(n) => *offset += n, Ok(n) => *offset += n,
@ -182,10 +203,33 @@ impl Connection {
// Removing the buffer if it is fully sent. // Removing the buffer if it is fully sent.
if *offset >= buffer.len() { if *offset >= buffer.len() {
self.send_buffers.pop_front(); send_buffers.pop_front();
} }
} }
// Receiving as many messages as possible.
while let Some(mut message) = self.receive_any()? {
// Reading the message id.
let id = u16::from_ne_bytes(message[message.len() - 2..].try_into().map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("failed to read message id: {}", e),
)
})?);
// Removing the id from the message.
message.pop();
message.pop();
// Adding the message to the receive queue.
self.received_messages
.get_mut()
.map_err(|e| io::Error::other(format!("failed to lock receive queue: {}", e)))?
.entry(id)
.or_default()
.push_back(message);
}
// Returning success. // Returning success.
Ok(self.send_buffers.is_empty()) Ok(self.send_buffers.is_empty())
} }
@ -241,10 +285,12 @@ impl Connection {
/// Receives a message from the connection. /// Receives a message from the connection.
/// ///
/// The message should contain the message id at the end of the message.
///
/// If no message is available, returns `None`. /// If no message is available, returns `None`.
/// ///
/// This function is not blocking. /// This function is not blocking.
pub fn receive(&mut self) -> io::Result<Option<Vec<u8>>> { fn receive_any(&mut self) -> io::Result<Option<Vec<u8>>> {
// Receiving the message length. // Receiving the message length.
let message_len = match self.receive_message_len { let message_len = match self.receive_message_len {
Some(message_len) => message_len, Some(message_len) => message_len,
@ -306,6 +352,26 @@ impl Connection {
// Returning the message. // Returning the message.
Ok(Some(message)) Ok(Some(message))
} }
/// Receives all the messages with the given id from the connection.
///
/// This function is not blocking.
pub fn receive(&self, id: u16) -> io::Result<LinkedList<Vec<u8>>> {
// Locking the received messages.
let mut received_messages = self.received_messages.lock().map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("failed to lock received messages: {}", e),
)
})?;
// Getting the messages.
let messages = received_messages.remove(&id);
drop(received_messages);
// Returning the received messages.
messages.map_or_else(|| Ok(LinkedList::new()), Ok)
}
} }
/// A non-blocking tcp listener. /// A non-blocking tcp listener.