generated from tipragot/rust
Add id to messages and make send and receive non mutable
Some checks failed
Rust Checks / checks (push) Failing after 1m4s
Some checks failed
Rust Checks / checks (push) Failing after 1m4s
This commit is contained in:
parent
33b11b71c3
commit
9f9ac40a13
|
@ -45,6 +45,7 @@
|
||||||
use std::collections::{HashMap, LinkedList};
|
use std::collections::{HashMap, LinkedList};
|
||||||
use std::io::{self, Read, Write};
|
use std::io::{self, Read, Write};
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddrV4, TcpListener, TcpStream};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddrV4, TcpListener, TcpStream};
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
use aes_gcm::aead::{Aead, AeadCore, KeyInit, OsRng};
|
use aes_gcm::aead::{Aead, AeadCore, KeyInit, OsRng};
|
||||||
use aes_gcm::{Aes128Gcm, Key, Nonce};
|
use aes_gcm::{Aes128Gcm, Key, Nonce};
|
||||||
|
@ -59,7 +60,10 @@ pub struct Connection {
|
||||||
stream: TcpStream,
|
stream: TcpStream,
|
||||||
|
|
||||||
/// Contains the buffers that are not yet being sent.
|
/// Contains the buffers that are not yet being sent.
|
||||||
send_buffers: LinkedList<(usize, Vec<u8>)>,
|
send_buffers: Mutex<LinkedList<(usize, Vec<u8>)>>,
|
||||||
|
|
||||||
|
/// Contains all the received messages associated with their ids.
|
||||||
|
received_messages: Mutex<HashMap<u16, LinkedList<Vec<u8>>>>,
|
||||||
|
|
||||||
/// The length of the next message to be received.
|
/// The length of the next message to be received.
|
||||||
///
|
///
|
||||||
|
@ -88,7 +92,8 @@ impl Connection {
|
||||||
stream.set_nonblocking(true)?;
|
stream.set_nonblocking(true)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
stream,
|
stream,
|
||||||
send_buffers: LinkedList::new(),
|
send_buffers: Mutex::new(LinkedList::new()),
|
||||||
|
received_messages: Mutex::new(HashMap::new()),
|
||||||
receive_message_len: None,
|
receive_message_len: None,
|
||||||
receive_message_nonce: None,
|
receive_message_nonce: None,
|
||||||
receive_filled_len: 0,
|
receive_filled_len: 0,
|
||||||
|
@ -125,16 +130,21 @@ impl Connection {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sends a message over the connection.
|
/// Sends a message and its id over the connection.
|
||||||
///
|
|
||||||
/// Returns `true` if the message has been sent directly and `false`
|
|
||||||
/// if the message is still in the send queue.
|
|
||||||
///
|
///
|
||||||
/// This function is not blocking.
|
/// This function is not blocking.
|
||||||
pub fn send(&mut self, message: &[u8]) -> io::Result<bool> {
|
pub fn send(&self, message: &[u8], id: u16) -> io::Result<()> {
|
||||||
|
// Add the id to the message.
|
||||||
|
let mut message_with_id = Vec::with_capacity(message.len() + 2);
|
||||||
|
message_with_id.extend_from_slice(message);
|
||||||
|
message_with_id.extend_from_slice(&id.to_ne_bytes());
|
||||||
|
|
||||||
// Encrypt the message.
|
// Encrypt the message.
|
||||||
let nonce = Aes128Gcm::generate_nonce(OsRng);
|
let nonce = Aes128Gcm::generate_nonce(OsRng);
|
||||||
let message = self.cipher.encrypt(&nonce, message).map_err(|e| {
|
let message = self
|
||||||
|
.cipher
|
||||||
|
.encrypt(&nonce, message_with_id.as_slice())
|
||||||
|
.map_err(|e| {
|
||||||
io::Error::new(
|
io::Error::new(
|
||||||
io::ErrorKind::InvalidData,
|
io::ErrorKind::InvalidData,
|
||||||
format!("failed to encrypt message: {}", e),
|
format!("failed to encrypt message: {}", e),
|
||||||
|
@ -152,26 +162,37 @@ impl Connection {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Add a new buffer to the send queue.
|
// Lock the send buffers.
|
||||||
self.send_buffers
|
let mut send_buffers = self
|
||||||
.push_back((0, message_len.to_ne_bytes().to_vec()));
|
.send_buffers
|
||||||
self.send_buffers.push_back((0, nonce.to_vec()));
|
.lock()
|
||||||
self.send_buffers.push_back((0, message));
|
.map_err(|e| io::Error::other(format!("failed to lock send buffers: {}", e)))?;
|
||||||
|
|
||||||
// Update the connection.
|
// Add a new buffer to the send queue.
|
||||||
self.update()
|
send_buffers.push_back((0, message_len.to_ne_bytes().to_vec()));
|
||||||
|
send_buffers.push_back((0, nonce.to_vec()));
|
||||||
|
send_buffers.push_back((0, message));
|
||||||
|
drop(send_buffers);
|
||||||
|
|
||||||
|
// Returning success.
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Updates the connection.
|
/// Updates the connection.
|
||||||
///
|
///
|
||||||
/// This function sends any pending messages that have not been sent yet.
|
/// This function sends any pending messages that have not been sent yet and
|
||||||
/// It returns `true` if there is no remaining data to send after updating
|
/// receives any pending messages that have not been received yet.
|
||||||
/// the connection and `false` otherwise.
|
|
||||||
///
|
///
|
||||||
/// This function is not blocking.
|
/// This function is not blocking.
|
||||||
pub fn update(&mut self) -> io::Result<bool> {
|
pub fn update(&mut self) -> io::Result<()> {
|
||||||
// Looping over the send buffers.
|
// Lock the send buffers.
|
||||||
while let Some((offset, buffer)) = self.send_buffers.front_mut() {
|
let send_buffers = self
|
||||||
|
.send_buffers
|
||||||
|
.get_mut()
|
||||||
|
.map_err(|e| io::Error::other(format!("failed to lock send buffers: {}", e)))?;
|
||||||
|
|
||||||
|
// Looping over the send buffers to send the messages.
|
||||||
|
while let Some((offset, buffer)) = send_buffers.front_mut() {
|
||||||
// Writing the buffer to the stream.
|
// Writing the buffer to the stream.
|
||||||
match self.stream.write(&buffer[*offset..]) {
|
match self.stream.write(&buffer[*offset..]) {
|
||||||
Ok(n) => *offset += n,
|
Ok(n) => *offset += n,
|
||||||
|
@ -182,10 +203,33 @@ impl Connection {
|
||||||
|
|
||||||
// Removing the buffer if it is fully sent.
|
// Removing the buffer if it is fully sent.
|
||||||
if *offset >= buffer.len() {
|
if *offset >= buffer.len() {
|
||||||
self.send_buffers.pop_front();
|
send_buffers.pop_front();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Receiving as many messages as possible.
|
||||||
|
while let Some(mut message) = self.receive_any()? {
|
||||||
|
// Reading the message id.
|
||||||
|
let id = u16::from_ne_bytes(message[message.len() - 2..].try_into().map_err(|e| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidData,
|
||||||
|
format!("failed to read message id: {}", e),
|
||||||
|
)
|
||||||
|
})?);
|
||||||
|
|
||||||
|
// Removing the id from the message.
|
||||||
|
message.pop();
|
||||||
|
message.pop();
|
||||||
|
|
||||||
|
// Adding the message to the receive queue.
|
||||||
|
self.received_messages
|
||||||
|
.get_mut()
|
||||||
|
.map_err(|e| io::Error::other(format!("failed to lock receive queue: {}", e)))?
|
||||||
|
.entry(id)
|
||||||
|
.or_default()
|
||||||
|
.push_back(message);
|
||||||
|
}
|
||||||
|
|
||||||
// Returning success.
|
// Returning success.
|
||||||
Ok(self.send_buffers.is_empty())
|
Ok(self.send_buffers.is_empty())
|
||||||
}
|
}
|
||||||
|
@ -241,10 +285,12 @@ impl Connection {
|
||||||
|
|
||||||
/// Receives a message from the connection.
|
/// Receives a message from the connection.
|
||||||
///
|
///
|
||||||
|
/// The message should contain the message id at the end of the message.
|
||||||
|
///
|
||||||
/// If no message is available, returns `None`.
|
/// If no message is available, returns `None`.
|
||||||
///
|
///
|
||||||
/// This function is not blocking.
|
/// This function is not blocking.
|
||||||
pub fn receive(&mut self) -> io::Result<Option<Vec<u8>>> {
|
fn receive_any(&mut self) -> io::Result<Option<Vec<u8>>> {
|
||||||
// Receiving the message length.
|
// Receiving the message length.
|
||||||
let message_len = match self.receive_message_len {
|
let message_len = match self.receive_message_len {
|
||||||
Some(message_len) => message_len,
|
Some(message_len) => message_len,
|
||||||
|
@ -306,6 +352,26 @@ impl Connection {
|
||||||
// Returning the message.
|
// Returning the message.
|
||||||
Ok(Some(message))
|
Ok(Some(message))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Receives all the messages with the given id from the connection.
|
||||||
|
///
|
||||||
|
/// This function is not blocking.
|
||||||
|
pub fn receive(&self, id: u16) -> io::Result<LinkedList<Vec<u8>>> {
|
||||||
|
// Locking the received messages.
|
||||||
|
let mut received_messages = self.received_messages.lock().map_err(|e| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::Other,
|
||||||
|
format!("failed to lock received messages: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Getting the messages.
|
||||||
|
let messages = received_messages.remove(&id);
|
||||||
|
drop(received_messages);
|
||||||
|
|
||||||
|
// Returning the received messages.
|
||||||
|
messages.map_or_else(|| Ok(LinkedList::new()), Ok)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A non-blocking tcp listener.
|
/// A non-blocking tcp listener.
|
||||||
|
|
Loading…
Reference in a new issue