From a69973b3fae36ac9420367451ffa2dfdc27c68f9 Mon Sep 17 00:00:00 2001 From: Tipragot Date: Sat, 29 Apr 2023 01:29:58 +0200 Subject: [PATCH] Put base system in the same file using features --- Cargo.toml | 6 ++ src/lib.rs | 206 ++++++++++++++++++++++++++++++++++++++++++++++++++++- src/tcp.rs | 16 ++--- 3 files changed, 218 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9128e5b..ff61668 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,9 @@ repository = "https://git.tipragot.fr/tipragot/bevnet" serde = { version = "1.0.160", features = ["derive"] } bincode = "1.3.3" bevy = "0.10.1" + +[features] +default = ["server", "sync"] +server = [] +client = [] +sync = [] diff --git a/src/lib.rs b/src/lib.rs index 817daed..51a7a32 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,11 @@ +use bevy::prelude::*; use serde::{de::DeserializeOwned, Serialize}; use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, }; +use std::{collections::HashMap, io, net::ToSocketAddrs, sync::Arc}; -pub mod client; -pub mod server; mod tcp; /// A packet that can be sent over a [Connection]. @@ -24,3 +24,205 @@ pub trait Packet: DeserializeOwned + Serialize + Send + Sync { } impl Packet for T {} + +#[cfg(feature = "server")] +/// A function that handle a received [Packet]s. +pub type PacketHandler = Box, &mut World) + Send + Sync>; + +#[cfg(feature = "client")] +/// A function that handle a received [Packet]s. +pub type PacketHandler = Box, &mut World) + Send + Sync>; + +/// A Bevy resource that store the packets handlers. +#[derive(Resource)] +struct HandlerManager(Arc>); + +#[cfg(feature = "server")] +/// A Bevy resource that listens for incoming [Connection]s. +#[derive(Resource)] +pub struct Listener(tcp::Listener); + +#[cfg(feature = "server")] +impl Listener { + /// Creates a new listener on the given address. + pub fn bind(addr: A) -> io::Result { + Ok(Self(tcp::Listener::bind(addr)?)) + } +} + +#[cfg(feature = "server")] +/// A connection to a remote client. +#[derive(Component)] +pub struct Connection(Arc); + +#[cfg(feature = "client")] +/// A connection to a remote server. +#[derive(Resource)] +pub struct Connection(tcp::Connection); + +impl Connection { + #[cfg(feature = "client")] + /// Connects to a remote server. + pub fn connect(addr: A) -> io::Result { + Ok(Self(tcp::Connection::connect(addr)?)) + } + + /// Sends a packet through this connection. + pub fn send(&self, packet: &P) { + let mut data = bincode::serialize(packet).expect("Failed to serialize packet"); + data.extend(P::packet_id().to_be_bytes()); + self.0.send(data); + } +} + +/// A plugin that manage the network connections. +pub struct NetworkPlugin; + +impl NetworkPlugin { + #[cfg(feature = "server")] + /// Accept new [Connection]s. + fn accept_connections(mut commands: Commands, listener: Option>) { + if let Some(listener) = listener { + if let Some(connection) = listener.0.accept() { + commands.spawn(Connection(Arc::new(connection))); + } + } + } + + #[cfg(feature = "server")] + /// Handles a received [Packet]s. + fn handle_packets(world: &mut World) { + // Get all received packets + let mut packets = Vec::new(); + for (entity, connection) in world.query::<(Entity, &Connection)>().iter(world) { + while let Some(mut packet) = connection.0.recv() { + if packet.len() < 8 { + println!("Invalid packet received: {:?}", packet); + } else { + let id_buffer = packet.split_off(packet.len() - 8); + let packet_id = u64::from_be_bytes(id_buffer.try_into().unwrap()); + packets.push(( + entity, + Connection(Arc::clone(&connection.0)), + packet_id, + packet, + )); + } + } + } + + // Get the packet handlers + let handlers = Arc::clone(&world.resource_mut::().0); + + // Handle all received packets + for (entity, connection, packet_id, packet) in packets { + if let Some(handler) = handlers.get(&packet_id) { + handler(entity, connection, packet, world); + } + } + } + + #[cfg(feature = "client")] + /// Handles a received [Packet] on the server. + pub fn handle_packets(world: &mut World) { + // Get all received packets + let mut packets = Vec::new(); + if let Some(connection) = world.get_resource::() { + while let Some(mut packet) = connection.0.recv() { + if packet.len() < 8 { + println!("Invalid packet received: {:?}", packet); + } else { + let id_buffer = packet.split_off(packet.len() - 8); + let packet_id = u64::from_be_bytes(id_buffer.try_into().unwrap()); + packets.push((packet_id, packet)); + } + } + } else { + return; + } + + // Get the packet handlers + let handlers = Arc::clone(&world.resource_mut::().0); + + // Handle all received packets + for (packet_id, packet) in packets { + if let Some(handler) = handlers.get(&packet_id) { + handler(packet, world); + } + } + } + + #[cfg(feature = "server")] + /// Remove disconnected [Connection]s. + fn remove_disconnected(mut commands: Commands, connections: Query<(Entity, &Connection)>) { + for (entity, connection) in connections.iter() { + if connection.0.closed() { + commands.entity(entity).remove::(); + } + } + } + + #[cfg(feature = "client")] + /// Remove [ServerConnection] if it's disconnected. + pub fn remove_disconnected(mut commands: Commands, connection: Option>) { + if let Some(connection) = connection { + if connection.0.closed() { + commands.remove_resource::(); + } + } + } +} + +impl Plugin for NetworkPlugin { + fn build(&self, app: &mut App) { + app.insert_resource(HandlerManager(Arc::new(HashMap::new()))); + app.add_system(NetworkPlugin::handle_packets); + app.add_system(NetworkPlugin::remove_disconnected); + + #[cfg(feature = "server")] + app.add_system(NetworkPlugin::accept_connections); + } +} + +/// An extension to add packet handlers. +pub trait NetworkExt { + #[cfg(feature = "server")] + /// Add a new packet handler. + fn add_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(Entity, Connection, Vec, &mut World) + Send + Sync + 'static; + + #[cfg(feature = "client")] + /// Add a new packet handler. + fn add_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(Vec, &mut World) + Send + Sync + 'static; +} + +impl NetworkExt for App { + #[cfg(feature = "server")] + fn add_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(Entity, Connection, Vec, &mut World) + Send + Sync + 'static, + { + Arc::get_mut(&mut self.world.resource_mut::().0) + .unwrap() + .insert(P::packet_id(), Box::new(handler)); + self + } + + #[cfg(feature = "client")] + fn add_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(Vec, &mut World) + Send + Sync + 'static, + { + Arc::get_mut(&mut self.world.resource_mut::().0) + .unwrap() + .insert(P::packet_id(), Box::new(handler)); + self + } +} diff --git a/src/tcp.rs b/src/tcp.rs index 413f287..d6611bf 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -1,6 +1,6 @@ use std::{ io::{self, Read, Write}, - net::{Shutdown, TcpListener, TcpStream, ToSocketAddrs}, + net::{Shutdown, TcpStream, ToSocketAddrs}, sync::{ atomic::{AtomicBool, Ordering}, mpsc::{channel, Receiver, Sender}, @@ -51,6 +51,7 @@ impl Connection { }) } + #[cfg(feature = "client")] /// Creates a new connection to the given address. pub fn connect(addr: A) -> io::Result { Self::new(TcpStream::connect(addr)?) @@ -82,13 +83,7 @@ impl Connection { /// The sending loop for this connection. fn sending_loop(mut stream: TcpStream, receiver: Receiver>, closed: Arc) { - loop { - // Get the next packet to send - let packet = match receiver.recv() { - Ok(packet) => packet, - Err(_) => break, - }; - + while let Ok(packet) = receiver.recv() { // Send the length of the packet let len_buffer = u32::to_be_bytes(packet.len() as u32); if stream.write_all(&len_buffer).is_err() { @@ -128,12 +123,17 @@ impl Drop for Connection { } } +#[cfg(feature = "server")] +use std::net::TcpListener; + +#[cfg(feature = "server")] /// A [Connection] listener. pub struct Listener { /// The [TcpListener] of the listener. listener: TcpListener, } +#[cfg(feature = "server")] impl Listener { /// Creates a new TCP listener on the given address. pub fn bind(addr: A) -> io::Result {