bevnet/src/tcp.rs

234 lines
7 KiB
Rust
Raw Normal View History

2023-04-26 22:39:27 +00:00
use crate::packet::{Packet, PacketReceiver};
use std::{
io::{self, Read, Write},
net::{Shutdown, SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
sync::{
atomic::{AtomicBool, Ordering},
mpsc::{channel, Receiver, Sender},
Arc, Mutex,
},
thread,
};
/// Used to send [Packet] to the sending thread.
///
/// Each message is a `(packet identifier, serialized payload)` pair, as queued by
/// `Connection::send`. NOTE(review): `Sender::send` only needs `&self`; the [Mutex] is
/// presumably there so that [Connection] is `Sync` (older `Sender` was not) — confirm.
type ConnectionSender = Arc<Mutex<Sender<(u32, Vec<u8>)>>>;
/// A TCP [Connection] that can send and receive [Packet].
pub struct Connection {
    /// Whether or not the [Connection] is currently connected.
    /// Shared with both background threads; either thread clears it when it stops, and
    /// `Drop` clears it as well.
    connected: Arc<AtomicBool>,
    /// Used to store the received [Packet]s.
    /// Filled by the receiving thread, drained by `recv` / `clear`.
    packets: Arc<PacketReceiver>,
    /// Used to send [Packet] to the sending thread.
    /// Dropping it (with the [Connection]) makes the sending thread's `recv` fail, ending
    /// that thread.
    sender: ConnectionSender,
    /// The [TcpStream] of the [Connection].
    /// The threads work on clones of it; this handle is kept so `Drop` can shut the
    /// socket down.
    stream: TcpStream,
    /// The address of the [Connection].
    /// This is the remote peer's address (from `TcpStream::peer_addr` in `new`).
    address: SocketAddr,
}
impl Connection {
    /// Creates a new [Connection] with the given [TcpStream].
    ///
    /// Two background threads are spawned:
    /// - one reads frames from the socket into the [Packet] cache;
    /// - the other writes the frames queued by [Connection::send].
    ///
    /// A frame is a big-endian `u32` payload length, then a big-endian `u32` packet
    /// identifier, then the payload bytes.
    ///
    /// # Errors
    /// Returns an error if the peer address cannot be read or the stream cannot be
    /// cloned. No thread is spawned in that case.
    pub fn new(stream: TcpStream) -> io::Result<Self> {
        // Run every fallible step before spawning anything. A thread started before a
        // later `?` returned would keep a clone of the socket open (blocked on it forever),
        // with no `Connection` left whose `Drop` could shut it down.
        let address = stream.peer_addr()?;
        let read_stream = stream.try_clone()?;
        let write_stream = stream.try_clone()?;

        let connected = Arc::new(AtomicBool::new(true));
        let packets = Arc::new(PacketReceiver::new());

        // Receiving part
        let thread_packets = Arc::clone(&packets);
        let thread_connected = Arc::clone(&connected);
        thread::spawn(move || {
            // Buffered so a frame's header and payload don't each cost a separate syscall.
            let mut reader = io::BufReader::new(read_stream);
            while thread_connected.load(Ordering::Relaxed) {
                // Fails on EOF, on socket errors, and once `Drop` shuts the socket down.
                let (id, buffer) = match read_frame(&mut reader) {
                    Ok(frame) => frame,
                    Err(_) => break,
                };
                thread_packets.insert(id, buffer);
            }
            // Close the connection
            thread_connected.store(false, Ordering::Relaxed);
        });

        // Sending part
        let (sender, receiver) = channel::<(u32, Vec<u8>)>();
        let thread_connected = Arc::clone(&connected);
        thread::spawn(move || {
            // Buffered so small frames go out in one write instead of one per field.
            let mut writer = io::BufWriter::new(write_stream);
            while thread_connected.load(Ordering::Relaxed) {
                // Blocks until a packet is queued; fails once the `Connection` (and with it
                // the `Sender`) has been dropped.
                let (id, buffer) = match receiver.recv() {
                    Ok(data) => data,
                    Err(_) => break,
                };
                if write_frame(&mut writer, id, &buffer).is_err() {
                    break;
                }
            }
            // Close the connection
            thread_connected.store(false, Ordering::Relaxed);
        });

        Ok(Self {
            connected,
            packets,
            sender: Arc::new(Mutex::new(sender)),
            stream,
            address,
        })
    }

    /// Creates a [Connection] to the given address.
    ///
    /// # Errors
    /// Returns an error if connecting fails, or if [Connection::new] fails.
    pub fn connect<A: ToSocketAddrs>(address: A) -> io::Result<Self> {
        Self::new(TcpStream::connect(address)?)
    }

    /// Returns whether or not the [Connection] is currently connected.
    ///
    /// Becomes `false` once either background thread stops, or when the [Connection] is
    /// dropped.
    pub fn connected(&self) -> bool {
        self.connected.load(Ordering::Relaxed)
    }

    /// Clears the [Packet] cache.
    pub fn clear(&self) {
        self.packets.clear();
    }

    /// Gets all the received [Packet]s of a certain type.
    pub fn recv<P: Packet>(&self) -> Vec<P> {
        self.packets.extract()
    }

    /// Sends the given [Packet] to the [Connection].
    ///
    /// Sending is best-effort:
    /// - it does nothing if the [Connection] is closed;
    /// - it silently drops a packet whose serialized size exceeds `u32::MAX` bytes.
    ///
    /// Oversized packets are dropped because the frame header stores the length as a
    /// `u32`, and a truncated length would desynchronize the stream for the peer.
    ///
    /// # Panics
    /// Panics if the packet cannot be serialized with `bincode`.
    pub fn send<P: Packet>(&self, packet: P) {
        let data = bincode::serialize(&packet).expect("Failed to serialize packet");
        if data.len() as u64 > u64::from(u32::MAX) {
            return;
        }
        if let Ok(sender) = self.sender.lock() {
            // An error only means the sending thread has already exited.
            let _ = sender.send((P::ID, data));
        }
    }

    /// Returns the address of the [Connection].
    ///
    /// This is the remote peer's address.
    pub fn address(&self) -> SocketAddr {
        self.address
    }
}

/// Reads one frame from `reader` and returns its `(packet identifier, payload)`.
///
/// The payload length comes from the peer, so it is not trusted for allocation. At most
/// 64 KiB is reserved up front, and the buffer then grows only as bytes actually arrive.
/// A bogus header therefore cannot force a multi-gigabyte allocation.
///
/// # Errors
/// Propagates read errors. Returns [io::ErrorKind::UnexpectedEof] if the stream ends
/// before the whole frame has been read.
fn read_frame<R: Read>(reader: &mut R) -> io::Result<(u32, Vec<u8>)> {
    // Upper bound on the capacity reserved before any payload byte has been read.
    const MAX_PAYLOAD_PREALLOCATION: usize = 64 * 1024;

    let mut int_buffer = [0; 4];
    // Read the length of the packet
    reader.read_exact(&mut int_buffer)?;
    let len = u32::from_be_bytes(int_buffer);
    // Read the packet identifier
    reader.read_exact(&mut int_buffer)?;
    let id = u32::from_be_bytes(int_buffer);
    // Read the packet
    let mut buffer = Vec::with_capacity((len as usize).min(MAX_PAYLOAD_PREALLOCATION));
    reader.by_ref().take(u64::from(len)).read_to_end(&mut buffer)?;
    if buffer.len() as u64 != u64::from(len) {
        return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
    }
    Ok((id, buffer))
}

/// Writes one frame (length, identifier, payload) to `writer` and flushes it.
///
/// # Errors
/// Returns [io::ErrorKind::InvalidInput], without writing anything, if `payload` is
/// longer than `u32::MAX` bytes. Otherwise propagates write and flush errors.
fn write_frame<W: Write>(writer: &mut W, id: u32, payload: &[u8]) -> io::Result<()> {
    if payload.len() as u64 > u64::from(u32::MAX) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "packet payload does not fit in a u32 length header",
        ));
    }
    let len = payload.len() as u32;
    writer.write_all(&len.to_be_bytes())?;
    writer.write_all(&id.to_be_bytes())?;
    writer.write_all(payload)?;
    writer.flush()
}
impl Drop for Connection {
    /// Closes the [Connection].
    ///
    /// Clears the shared `connected` flag, then shuts the socket down in both directions.
    /// The receiving thread reads from a clone of this same socket, so its pending read
    /// returns instead of blocking forever. A failing shutdown is deliberately ignored.
    fn drop(&mut self) {
        let Self {
            connected, stream, ..
        } = self;
        connected.store(false, Ordering::Relaxed);
        let _ = stream.shutdown(Shutdown::Both);
    }
}
/// A TCP [Listener] that can accept [Connection]s.
///
/// Connections are accepted on a background thread and queued until picked up with
/// `Listener::accept`.
pub struct Listener {
    /// Whether the [Listener] is listening.
    /// Cleared by the background accept thread when it stops.
    listening: Arc<AtomicBool>,
    /// The receiving part of the [Listener].
    /// Receives the [Connection]s created by the background accept thread.
    receiver: Arc<Mutex<Receiver<Connection>>>,
    /// The address the [Listener] is bound to.
    address: SocketAddr,
}
impl Listener {
    /// Creates a new [Listener] bound to the given address.
    ///
    /// A background thread accepts incoming TCP connections. It wraps each one in a
    /// [Connection] and queues it for [Listener::accept].
    ///
    /// Clients that fail individually are skipped rather than stopping the listener. This
    /// covers an aborted or reset accept, or a [Connection::new] failure (e.g. the client
    /// disconnected before its peer address could be read).
    ///
    /// Note: the background thread only notices that the [Listener] was dropped when it
    /// tries to queue the next accepted connection. Until then it keeps the socket bound.
    ///
    /// # Errors
    /// Returns an error if binding fails or the bound address cannot be read.
    pub fn bind<A: ToSocketAddrs>(address: A) -> io::Result<Self> {
        let listener = TcpListener::bind(address)?;
        let address = listener.local_addr()?;
        let listening = Arc::new(AtomicBool::new(true));
        let listening_thread = Arc::clone(&listening);
        let (sender, receiver) = channel();
        thread::spawn(move || {
            for stream in listener.incoming() {
                let stream = match stream {
                    Ok(stream) => stream,
                    // These concern only the one pending client (it gave up before being
                    // accepted, or the call was interrupted): keep serving the others.
                    Err(e)
                        if matches!(
                            e.kind(),
                            io::ErrorKind::ConnectionAborted
                                | io::ErrorKind::ConnectionReset
                                | io::ErrorKind::Interrupted
                        ) =>
                    {
                        continue
                    }
                    Err(_) => break,
                };
                // Failing to set up one client must not stop the listener for everyone else.
                let connection = match Connection::new(stream) {
                    Ok(connection) => connection,
                    Err(_) => continue,
                };
                // The receiving half is gone, i.e. the `Listener` was dropped.
                if sender.send(connection).is_err() {
                    break;
                }
            }
            listening_thread.store(false, Ordering::Relaxed);
        });
        Ok(Self {
            listening,
            receiver: Arc::new(Mutex::new(receiver)),
            address,
        })
    }

    /// Returns whether or not the [Listener] is listening.
    ///
    /// Becomes `false` once the background accept thread has stopped.
    pub fn listening(&self) -> bool {
        self.listening.load(Ordering::Relaxed)
    }

    /// Receives the next [Connection] from the [Listener].
    ///
    /// Returns `None` if no accepted connection is currently queued. It does not wait for
    /// a new one to arrive.
    pub fn accept(&self) -> Option<Connection> {
        self.receiver
            .lock()
            .ok()
            .and_then(|receiver| receiver.try_recv().ok())
    }

    /// Returns the address the [Listener] is bound to.
    pub fn address(&self) -> SocketAddr {
        self.address
    }
}