generated from tipragot/rust
Add id to messages and make send and receive non mutable
Some checks failed
Rust Checks / checks (push) Failing after 1m4s
Some checks failed
Rust Checks / checks (push) Failing after 1m4s
This commit is contained in:
parent
33b11b71c3
commit
9f9ac40a13
|
@ -45,6 +45,7 @@
|
|||
use std::collections::{HashMap, LinkedList};
|
||||
use std::io::{self, Read, Write};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddrV4, TcpListener, TcpStream};
|
||||
use std::sync::Mutex;
|
||||
|
||||
use aes_gcm::aead::{Aead, AeadCore, KeyInit, OsRng};
|
||||
use aes_gcm::{Aes128Gcm, Key, Nonce};
|
||||
|
@ -59,7 +60,10 @@ pub struct Connection {
|
|||
stream: TcpStream,
|
||||
|
||||
/// Contains the buffers that are not yet being sent.
|
||||
send_buffers: LinkedList<(usize, Vec<u8>)>,
|
||||
send_buffers: Mutex<LinkedList<(usize, Vec<u8>)>>,
|
||||
|
||||
/// Contains all the received messages associated with their ids.
|
||||
received_messages: Mutex<HashMap<u16, LinkedList<Vec<u8>>>>,
|
||||
|
||||
/// The length of the next message to be received.
|
||||
///
|
||||
|
@ -88,7 +92,8 @@ impl Connection {
|
|||
stream.set_nonblocking(true)?;
|
||||
Ok(Self {
|
||||
stream,
|
||||
send_buffers: LinkedList::new(),
|
||||
send_buffers: Mutex::new(LinkedList::new()),
|
||||
received_messages: Mutex::new(HashMap::new()),
|
||||
receive_message_len: None,
|
||||
receive_message_nonce: None,
|
||||
receive_filled_len: 0,
|
||||
|
@ -125,16 +130,21 @@ impl Connection {
|
|||
)
|
||||
}
|
||||
|
||||
/// Sends a message over the connection.
|
||||
///
|
||||
/// Returns `true` if the message has been sent directly and `false`
|
||||
/// if the message is still in the send queue.
|
||||
/// Sends a message and its id over the connection.
|
||||
///
|
||||
/// This function is not blocking.
|
||||
pub fn send(&mut self, message: &[u8]) -> io::Result<bool> {
|
||||
pub fn send(&self, message: &[u8], id: u16) -> io::Result<()> {
|
||||
// Add the id to the message.
|
||||
let mut message_with_id = Vec::with_capacity(message.len() + 2);
|
||||
message_with_id.extend_from_slice(message);
|
||||
message_with_id.extend_from_slice(&id.to_ne_bytes());
|
||||
|
||||
// Encrypt the message.
|
||||
let nonce = Aes128Gcm::generate_nonce(OsRng);
|
||||
let message = self.cipher.encrypt(&nonce, message).map_err(|e| {
|
||||
let message = self
|
||||
.cipher
|
||||
.encrypt(&nonce, message_with_id.as_slice())
|
||||
.map_err(|e| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("failed to encrypt message: {}", e),
|
||||
|
@ -152,26 +162,37 @@ impl Connection {
|
|||
}
|
||||
};
|
||||
|
||||
// Add a new buffer to the send queue.
|
||||
self.send_buffers
|
||||
.push_back((0, message_len.to_ne_bytes().to_vec()));
|
||||
self.send_buffers.push_back((0, nonce.to_vec()));
|
||||
self.send_buffers.push_back((0, message));
|
||||
// Lock the send buffers.
|
||||
let mut send_buffers = self
|
||||
.send_buffers
|
||||
.lock()
|
||||
.map_err(|e| io::Error::other(format!("failed to lock send buffers: {}", e)))?;
|
||||
|
||||
// Update the connection.
|
||||
self.update()
|
||||
// Add a new buffer to the send queue.
|
||||
send_buffers.push_back((0, message_len.to_ne_bytes().to_vec()));
|
||||
send_buffers.push_back((0, nonce.to_vec()));
|
||||
send_buffers.push_back((0, message));
|
||||
drop(send_buffers);
|
||||
|
||||
// Returning success.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Updates the connection.
|
||||
///
|
||||
/// This function sends any pending messages that have not been sent yet.
|
||||
/// It returns `true` if there is no remaining data to send after updating
|
||||
/// the connection and `false` otherwise.
|
||||
/// This function sends any pending messages that have not been sent yet and
|
||||
/// receives any pending messages that have not been received yet.
|
||||
///
|
||||
/// This function is not blocking.
|
||||
pub fn update(&mut self) -> io::Result<bool> {
|
||||
// Looping over the send buffers.
|
||||
while let Some((offset, buffer)) = self.send_buffers.front_mut() {
|
||||
pub fn update(&mut self) -> io::Result<()> {
|
||||
// Lock the send buffers.
|
||||
let send_buffers = self
|
||||
.send_buffers
|
||||
.get_mut()
|
||||
.map_err(|e| io::Error::other(format!("failed to lock send buffers: {}", e)))?;
|
||||
|
||||
// Looping over the send buffers to send the messages.
|
||||
while let Some((offset, buffer)) = send_buffers.front_mut() {
|
||||
// Writing the buffer to the stream.
|
||||
match self.stream.write(&buffer[*offset..]) {
|
||||
Ok(n) => *offset += n,
|
||||
|
@ -182,10 +203,33 @@ impl Connection {
|
|||
|
||||
// Removing the buffer if it is fully sent.
|
||||
if *offset >= buffer.len() {
|
||||
self.send_buffers.pop_front();
|
||||
send_buffers.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
// Receiving as many messages as possible.
|
||||
while let Some(mut message) = self.receive_any()? {
|
||||
// Reading the message id.
|
||||
let id = u16::from_ne_bytes(message[message.len() - 2..].try_into().map_err(|e| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("failed to read message id: {}", e),
|
||||
)
|
||||
})?);
|
||||
|
||||
// Removing the id from the message.
|
||||
message.pop();
|
||||
message.pop();
|
||||
|
||||
// Adding the message to the receive queue.
|
||||
self.received_messages
|
||||
.get_mut()
|
||||
.map_err(|e| io::Error::other(format!("failed to lock receive queue: {}", e)))?
|
||||
.entry(id)
|
||||
.or_default()
|
||||
.push_back(message);
|
||||
}
|
||||
|
||||
// Returning success.
|
||||
Ok(self.send_buffers.is_empty())
|
||||
}
|
||||
|
@ -241,10 +285,12 @@ impl Connection {
|
|||
|
||||
/// Receives a message from the connection.
|
||||
///
|
||||
/// The message should contain the message id at the end of the message.
|
||||
///
|
||||
/// If no message is available, returns `None`.
|
||||
///
|
||||
/// This function is not blocking.
|
||||
pub fn receive(&mut self) -> io::Result<Option<Vec<u8>>> {
|
||||
fn receive_any(&mut self) -> io::Result<Option<Vec<u8>>> {
|
||||
// Receiving the message length.
|
||||
let message_len = match self.receive_message_len {
|
||||
Some(message_len) => message_len,
|
||||
|
@ -306,6 +352,26 @@ impl Connection {
|
|||
// Returning the message.
|
||||
Ok(Some(message))
|
||||
}
|
||||
|
||||
/// Receives all the messages with the given id from the connection.
|
||||
///
|
||||
/// This function is not blocking.
|
||||
pub fn receive(&self, id: u16) -> io::Result<LinkedList<Vec<u8>>> {
|
||||
// Locking the received messages.
|
||||
let mut received_messages = self.received_messages.lock().map_err(|e| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("failed to lock received messages: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Getting the messages.
|
||||
let messages = received_messages.remove(&id);
|
||||
drop(received_messages);
|
||||
|
||||
// Returning the received messages.
|
||||
messages.map_or_else(|| Ok(LinkedList::new()), Ok)
|
||||
}
|
||||
}
|
||||
|
||||
/// A non-blocking tcp listener.
|
||||
|
|
Loading…
Reference in a new issue