diff --git a/src/lib.rs b/src/lib.rs index e69de29..906f3fd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -0,0 +1,193 @@ +use bevy::{prelude::*, utils::HashMap}; +use std::{ + io::{self, ErrorKind}, + net::{TcpStream, ToSocketAddrs}, + sync::{mpsc::TryRecvError, Arc}, +}; +use tcp::{Connection, Listener, Packet, RawPacket}; + +mod tcp; + +/// A Bevy resource that store the packets handlers for the client. +#[derive(Resource)] +pub struct ClientPacketHandler(Arc>>); + +/// A connection to a remote server. +#[derive(Resource)] +pub struct ServerConnection(Connection); + +impl ServerConnection { + /// Connects to a remote server. + pub fn connect(addr: A) -> io::Result { + Ok(Self(Connection::new(TcpStream::connect(addr)?)?)) + } + + /// Send a [Packet] to the remote server. + pub fn send(&self, packet: P) { + self.0.send(packet).ok(); + } +} + +/// A Bevy resource that store the packets handlers for the server. +#[derive(Resource)] +struct ServerPacketHandler( + Arc>>, +); + +/// A [ClientConnection] listener. +#[derive(Resource)] +pub struct ClientListener(Listener); + +impl ClientListener { + /// Creates a new TCP listener on the given address. + pub fn bind(addr: A) -> io::Result { + Ok(Self(Listener::bind(addr)?)) + } +} + +/// A connection to a remote client. +#[derive(Component)] +pub struct ClientConnection(Arc); + +impl ClientConnection { + /// Sends a [Packet] to the remote client. + pub fn send(&self, packet: P) { + self.0.send(packet).ok(); + } +} + +/// A Bevy system to handle incoming [ClientConnection]s and remove the +/// [ClientListener] resource if it is no longer listening. +fn accept_connections(mut commands: Commands, listener: Option>) { + if let Some(listener) = listener { + match listener.0.accept() { + Ok(connection) => { + commands.spawn(ClientConnection(Arc::new(connection))); + } + Err(error) => match error.kind() { + ErrorKind::WouldBlock => {} + _ => commands.remove_resource::(), + }, + } + } +} + +/// A Bevy system to handle incoming [Packet]s from +/// the [ClientConnection]s and the [ServerConnection]. +/// It removes them if they are no longer connected. +fn receive_packets(world: &mut World) { + // Handle client packets + let mut packets = Vec::new(); + let mut to_remove = false; + if let Some(connection) = world.get_resource::() { + if let Some(receiver) = connection.0.recv() { + loop { + match receiver.try_recv() { + Ok(packet) => packets.push(packet), + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => { + to_remove = true; + break; + } + } + } + } + } + if let Some(handlers) = world + .get_resource_mut::() + .map(|handlers| Arc::clone(&handlers.0)) + { + for packet in packets.into_iter() { + if let Some(handler) = handlers.get(&packet.packet_id()) { + (handler)(packet, world); + } + } + } + if to_remove { + world.remove_resource::(); + } + + // Handle server packets + let mut packets = Vec::new(); + let mut to_remove = Vec::new(); + for (entity, connection) in world.query::<(Entity, &ClientConnection)>().iter(world) { + if let Some(receiver) = connection.0.recv() { + loop { + match receiver.try_recv() { + Ok(packet) => { + packets.push((entity, ClientConnection(Arc::clone(&connection.0)), packet)); + } + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => { + to_remove.push(entity); + break; + } + } + } + } + } + if let Some(handlers) = world + .get_resource_mut::() + .map(|handlers| Arc::clone(&handlers.0)) + { + for (entity, connection, packet) in packets.into_iter() { + if let Some(handler) = handlers.get(&packet.packet_id()) { + (handler)(entity, connection, packet, world); + } + } + } + for entity in to_remove { + world.despawn(entity); + } +} + +/// A network plugin. +pub struct NetworkPlugin; + +impl Plugin for NetworkPlugin { + fn build(&self, app: &mut App) { + app.insert_resource(ServerPacketHandler(Arc::new(HashMap::new()))); + app.insert_resource(ClientPacketHandler(Arc::new(HashMap::new()))); + app.add_system(accept_connections); + app.add_system(receive_packets); + } +} + +/// A extension trait to add packet handler. +pub trait NetworkAppExt { + /// Add a client packet handler. + fn add_client_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(RawPacket, &mut World) + Send + Sync + 'static; + + /// Add a server packet handler. + fn add_server_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(Entity, ClientConnection, RawPacket, &mut World) + Send + Sync + 'static; +} + +impl NetworkAppExt for App { + fn add_client_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(RawPacket, &mut World) + Send + Sync + 'static, + { + Arc::get_mut(&mut self.world.resource_mut::().0) + .unwrap() + .insert(P::packet_id(), Box::new(handler)); + self + } + + fn add_server_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(Entity, ClientConnection, RawPacket, &mut World) + Send + Sync + 'static, + { + Arc::get_mut(&mut self.world.resource_mut::().0) + .unwrap() + .insert(P::packet_id(), Box::new(handler)); + self + } +} diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 0000000..08af2a7 --- /dev/null +++ b/src/tcp.rs @@ -0,0 +1,153 @@ +use serde::{de::DeserializeOwned, Serialize}; +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + io::{self, ErrorKind, Read, Write}, + net::{Shutdown, TcpListener, TcpStream, ToSocketAddrs}, + sync::{ + mpsc::{channel, Receiver, Sender}, + Mutex, MutexGuard, + }, + thread, +}; + +/// A packet that can be sent over a [Connection]. +pub trait Packet: DeserializeOwned + Serialize + Send + Sync { + /// Returns a unique identifier for this packet. + /// + /// This function uses [std::any::type_name] to get a string + /// representation of the type of the object and returns the + /// hash of that string. This is not perfect... but I didn't + /// find a better solution. + fn packet_id() -> u64 { + let mut hasher = DefaultHasher::new(); + std::any::type_name::().hash(&mut hasher); + hasher.finish() + } +} + +impl Packet for T {} + +/// A raw packet. +pub struct RawPacket { + /// The identifier for this packet. + packet_id: u64, + + /// The serialized packet. + data: Vec, +} + +impl RawPacket { + /// Returns the identifier for this packet. + pub fn packet_id(&self) -> u64 { + self.packet_id + } + + /// Deserializes this packet to the given [Packet] type. + pub fn deserialize(&self) -> Option

{ + bincode::deserialize(&self.data).ok() + } +} + +/// A TCP connection that can send and receive [Packet]s. +pub struct Connection { + /// The [TcpStream] of the connection. + /// + /// It is used to send [Packet]s and to stop the receive + /// thread when the [Connection] is dropped. + stream: Mutex, + + /// The [Receiver] of the received [RawPacket]s. + receiver: Mutex>, +} + +impl Connection { + /// Creates a new TCP connection. + pub fn new(stream: TcpStream) -> io::Result { + let (sender, receiver) = channel(); + let thread_stream = stream.try_clone()?; + thread::spawn(move || Self::receive_loop(thread_stream, sender)); + Ok(Self { + stream: Mutex::new(stream), + receiver: Mutex::new(receiver), + }) + } + + /// The [Packet] receiving loop. + fn receive_loop(mut stream: TcpStream, sender: Sender) { + let mut len_buffer = [0; 4]; + let mut id_buffer = [0; 8]; + loop { + // Read the length of the packet + if stream.read_exact(&mut len_buffer).is_err() { + return; + } + let packet_len = u32::from_le_bytes(len_buffer); + + // Read the packet identifier + if stream.read_exact(&mut id_buffer).is_err() { + return; + } + let packet_id = u64::from_le_bytes(id_buffer); + + // Read the packet data + let mut data = vec![0; packet_len as usize]; + if stream.read_exact(&mut data).is_err() { + return; + } + + // Store the packet + if sender.send(RawPacket { packet_id, data }).is_err() { + return; + } + } + } + + /// Sends a [Packet] over this connection. + pub fn send(&self, packet: P) -> io::Result<()> { + let data = bincode::serialize(&packet).map_err(|e| io::Error::new(ErrorKind::Other, e))?; + let mut stream = self + .stream + .lock() + .map_err(|_| io::Error::new(ErrorKind::Other, "Failed to lock stream"))?; + stream.write_all(&data.len().to_le_bytes())?; + stream.write_all(&P::packet_id().to_le_bytes())?; + stream.write_all(&data) + } + + /// Gets the [RawPacket] receiver of this connection. + pub fn recv(&self) -> Option>> { + self.receiver.lock().ok() + } +} + +impl Drop for Connection { + fn drop(&mut self) { + self.stream + .lock() + .map(|stream| stream.shutdown(Shutdown::Both)) + .ok(); + } +} + +/// A [Connection] listener. +pub struct Listener { + /// The [TcpListener] of the listener. + listener: TcpListener, +} + +impl Listener { + /// Creates a new TCP listener on the given address. + pub fn bind(addr: A) -> io::Result { + let listener = TcpListener::bind(addr)?; + listener.set_nonblocking(true)?; + Ok(Self { listener }) + } + + /// Accepts a new [Connection]. + pub fn accept(&self) -> io::Result { + self.listener + .accept() + .and_then(|(stream, _)| Connection::new(stream)) + } +}