diff --git a/examples/ping_pong.rs b/examples/ping_pong.rs new file mode 100644 index 0000000..4a666b7 --- /dev/null +++ b/examples/ping_pong.rs @@ -0,0 +1,67 @@ +use bevnet::{ + client::{ClientAppExt, ClientPlugin, ServerConnection}, + server::{ClientListener, ServerAppExt, ServerPlugin}, +}; +use bevy::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +struct Ping; + +#[derive(Serialize, Deserialize)] +struct Pong; + +fn start_server(mut commands: Commands, keys: Res>) { + if keys.just_pressed(KeyCode::B) { + println!("Starting server..."); + match ClientListener::bind("127.0.0.1:8000") { + Ok(listener) => { + commands.insert_resource(listener); + println!("Server started"); + } + Err(e) => println!("Failed to start server: {}", e), + } + } +} + +fn connect(mut commands: Commands, keys: Res>) { + if keys.just_pressed(KeyCode::C) { + println!("Connecting to server..."); + match ServerConnection::connect("127.0.0.1:8000") { + Ok(connection) => { + commands.insert_resource(connection); + println!("Connected to server"); + } + Err(e) => println!("Failed to connect: {}", e), + } + } +} + +fn send_ping(connection: Option>, keys: Res>) { + if keys.just_pressed(KeyCode::S) { + println!("Sending ping..."); + if let Some(connection) = connection { + connection.send(Ping); + println!("Ping sent"); + } + } +} + +fn main() { + App::new() + .add_plugins(DefaultPlugins) + .add_plugin(ServerPlugin) + .add_system(start_server) + .add_server_packet_handler::(|entity, connection, _, _| { + println!("Received ping from {:?}", entity); + connection.send(Pong); + println!("Sent pong to {:?}", entity); + }) + .add_plugin(ClientPlugin) + .add_system(connect) + .add_client_packet_handler::(|_, _| { + println!("Received pong"); + }) + .add_system(send_ping) + .run(); +} diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..d368e0f --- /dev/null +++ b/src/client.rs @@ -0,0 +1,103 @@ +use std::{collections::HashMap, io, net::ToSocketAddrs, sync::Arc}; + +use bevy::prelude::*; + +use crate::{tcp::Connection, Packet}; + +/// A function that handle a received [Packet] on the client. +pub type ClientPacketHandler = Box, &mut World) + Send + Sync>; + +/// A Bevy resource that store the packets handlers for the client. +#[derive(Resource)] +pub struct ClientHandlerManager(Arc>); + +/// A connection to a remote server. +#[derive(Resource)] +pub struct ServerConnection(Connection); + +impl ServerConnection { + /// Connects to a remote server. + pub fn connect(addr: A) -> io::Result { + Ok(Self(Connection::connect(addr)?)) + } + + /// Sends a packet through this connection. + pub fn send(&self, packet: P) { + let mut data = bincode::serialize(&packet).expect("Failed to serialize packet"); + data.extend(P::packet_id().to_be_bytes()); + self.0.send(data); + } +} + +/// A plugin that manage the network connections for a server. +pub struct ClientPlugin; + +impl ClientPlugin { + /// Handles a received [Packet] on the server. + pub fn handle_packets(world: &mut World) { + // Get all received packets + let mut packets = Vec::new(); + if let Some(connection) = world.get_resource::() { + while let Some(mut packet) = connection.0.recv() { + if packet.len() < 8 { + println!("Invalid packet received: {:?}", packet); + } else { + let id_buffer = packet.split_off(packet.len() - 8); + let packet_id = u64::from_be_bytes(id_buffer.try_into().unwrap()); + packets.push((packet_id, packet)); + } + } + } else { + return; + } + + // Get the packet handlers + let handlers = Arc::clone(&world.resource_mut::().0); + + // Handle all received packets + for (packet_id, packet) in packets { + if let Some(handler) = handlers.get(&packet_id) { + handler(packet, world); + } + } + } + + /// Remove [ServerConnection] if it's disconnected. + pub fn remove_disconnected(mut commands: Commands, connection: Option>) { + if let Some(connection) = connection { + if connection.0.closed() { + commands.remove_resource::(); + } + } + } +} + +impl Plugin for ClientPlugin { + fn build(&self, app: &mut App) { + app.insert_resource(ClientHandlerManager(Arc::new(HashMap::new()))); + app.add_system(ClientPlugin::handle_packets); + app.add_system(ClientPlugin::remove_disconnected); + } +} + +/// An extension to add packet handlers. +pub trait ClientAppExt { + /// Add a new packet handler. + fn add_client_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(Vec, &mut World) + Send + Sync + 'static; +} + +impl ClientAppExt for App { + fn add_client_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(Vec, &mut World) + Send + Sync + 'static, + { + Arc::get_mut(&mut self.world.resource_mut::().0) + .unwrap() + .insert(P::packet_id(), Box::new(handler)); + self + } +} diff --git a/src/lib.rs b/src/lib.rs index 906f3fd..bbfe983 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,31 @@ -use bevy::{prelude::*, utils::HashMap}; +use serde::{de::DeserializeOwned, Serialize}; +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, +}; + +pub mod client; +pub mod server; +mod tcp; + +/// A packet that can be sent over a [Connection]. +pub trait Packet: DeserializeOwned + Serialize + Send + Sync { + /// Returns a unique identifier for this packet. + /// + /// This function uses [std::any::type_name] to get a string + /// representation of the type of the object and returns the + /// hash of that string. This is not perfect... but I didn't + /// find a better solution. + fn packet_id() -> u64 { + let mut hasher = DefaultHasher::new(); + std::any::type_name::().hash(&mut hasher); + hasher.finish() + } +} + +impl Packet for T {} + +/* use bevy::{prelude::*, utils::HashMap}; use std::{ io::{self, ErrorKind}, net::{TcpStream, ToSocketAddrs}, @@ -115,6 +142,7 @@ fn receive_packets(world: &mut World) { loop { match receiver.try_recv() { Ok(packet) => { + println!("YESSSS"); packets.push((entity, ClientConnection(Arc::clone(&connection.0)), packet)); } Err(TryRecvError::Empty) => break, @@ -191,3 +219,4 @@ impl NetworkAppExt for App { self } } + */ diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..2bced84 --- /dev/null +++ b/src/server.rs @@ -0,0 +1,127 @@ +use crate::{ + tcp::{Connection, Listener}, + Packet, +}; +use bevy::prelude::*; +use std::{collections::HashMap, io, net::ToSocketAddrs, sync::Arc}; + +/// A function that handle a received [Packet] on the server. +pub type ServerPacketHandler = + Box, &mut World) + Send + Sync>; + +/// A Bevy resource that store the packets handlers for the server. +#[derive(Resource)] +pub struct ServerHandlerManager(Arc>); + +/// A Bevy resource that listens for incoming [ClientConnection]s. +#[derive(Resource)] +pub struct ClientListener(Listener); + +impl ClientListener { + /// Creates a new listener on the given address. + pub fn bind(addr: A) -> io::Result { + Ok(Self(Listener::bind(addr)?)) + } +} + +/// A connection to a remote client. +#[derive(Component)] +pub struct ClientConnection(Arc); + +impl ClientConnection { + /// Sends a packet through this connection. + pub fn send(&self, packet: P) { + let mut data = bincode::serialize(&packet).expect("Failed to serialize packet"); + data.extend(P::packet_id().to_be_bytes()); + self.0.send(data); + } +} + +/// A plugin that manage the network connections for a server. +pub struct ServerPlugin; + +impl ServerPlugin { + /// Accept new [ClientConnection]s. + pub fn accept_connections(mut commands: Commands, listener: Option>) { + if let Some(listener) = listener { + if let Some(connection) = listener.0.accept() { + commands.spawn(ClientConnection(Arc::new(connection))); + } + } + } + + /// Handles a received [Packet] on the server. + pub fn handle_packets(world: &mut World) { + // Get all received packets + let mut packets = Vec::new(); + for (entity, connection) in world.query::<(Entity, &ClientConnection)>().iter(world) { + while let Some(mut packet) = connection.0.recv() { + if packet.len() < 8 { + println!("Invalid packet received: {:?}", packet); + } else { + let id_buffer = packet.split_off(packet.len() - 8); + let packet_id = u64::from_be_bytes(id_buffer.try_into().unwrap()); + packets.push(( + entity, + ClientConnection(Arc::clone(&connection.0)), + packet_id, + packet, + )); + } + } + } + + // Get the packet handlers + let handlers = Arc::clone(&world.resource_mut::().0); + + // Handle all received packets + for (entity, connection, packet_id, packet) in packets { + if let Some(handler) = handlers.get(&packet_id) { + handler(entity, connection, packet, world); + } + } + } + + /// Remove disconnected [ClientConnection]s. + pub fn remove_disconnected( + mut commands: Commands, + connections: Query<(Entity, &ClientConnection)>, + ) { + for (entity, connection) in connections.iter() { + if connection.0.closed() { + commands.entity(entity).remove::(); + } + } + } +} + +impl Plugin for ServerPlugin { + fn build(&self, app: &mut App) { + app.insert_resource(ServerHandlerManager(Arc::new(HashMap::new()))); + app.add_system(ServerPlugin::accept_connections); + app.add_system(ServerPlugin::handle_packets); + app.add_system(ServerPlugin::remove_disconnected); + } +} + +/// An extension to add packet handlers. +pub trait ServerAppExt { + /// Add a new packet handler. + fn add_server_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(Entity, ClientConnection, Vec, &mut World) + Send + Sync + 'static; +} + +impl ServerAppExt for App { + fn add_server_packet_handler(&mut self, handler: H) -> &mut Self + where + P: Packet, + H: Fn(Entity, ClientConnection, Vec, &mut World) + Send + Sync + 'static, + { + Arc::get_mut(&mut self.world.resource_mut::().0) + .unwrap() + .insert(P::packet_id(), Box::new(handler)); + self + } +} diff --git a/src/tcp.rs b/src/tcp.rs index 08af2a7..413f287 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -1,132 +1,130 @@ -use serde::{de::DeserializeOwned, Serialize}; use std::{ - collections::hash_map::DefaultHasher, - hash::{Hash, Hasher}, - io::{self, ErrorKind, Read, Write}, + io::{self, Read, Write}, net::{Shutdown, TcpListener, TcpStream, ToSocketAddrs}, sync::{ + atomic::{AtomicBool, Ordering}, mpsc::{channel, Receiver, Sender}, - Mutex, MutexGuard, + Arc, Mutex, }, thread, }; -/// A packet that can be sent over a [Connection]. -pub trait Packet: DeserializeOwned + Serialize + Send + Sync { - /// Returns a unique identifier for this packet. - /// - /// This function uses [std::any::type_name] to get a string - /// representation of the type of the object and returns the - /// hash of that string. This is not perfect... but I didn't - /// find a better solution. - fn packet_id() -> u64 { - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - hasher.finish() - } -} - -impl Packet for T {} - -/// A raw packet. -pub struct RawPacket { - /// The identifier for this packet. - packet_id: u64, - - /// The serialized packet. - data: Vec, -} - -impl RawPacket { - /// Returns the identifier for this packet. - pub fn packet_id(&self) -> u64 { - self.packet_id - } - - /// Deserializes this packet to the given [Packet] type. - pub fn deserialize(&self) -> Option

{ - bincode::deserialize(&self.data).ok() - } -} - -/// A TCP connection that can send and receive [Packet]s. +/// A non-blocking TCP connection. pub struct Connection { - /// The [TcpStream] of the connection. - /// - /// It is used to send [Packet]s and to stop the receive - /// thread when the [Connection] is dropped. - stream: Mutex, + /// Track if the connection has been closed. + closed: Arc, - /// The [Receiver] of the received [RawPacket]s. - receiver: Mutex>, + /// The underlying TCP stream. + stream: TcpStream, + + /// Used to receive packets from the receiving thread. + receiver: Mutex>>, + + /// Used to send packets to the sending thread. + sender: Mutex>>, } impl Connection { - /// Creates a new TCP connection. - pub fn new(stream: TcpStream) -> io::Result { - let (sender, receiver) = channel(); + /// Creates a new connection. + fn new(stream: TcpStream) -> io::Result { + stream.set_nonblocking(false)?; + let closed = Arc::new(AtomicBool::new(false)); + + // Spawn the receiving thread let thread_stream = stream.try_clone()?; - thread::spawn(move || Self::receive_loop(thread_stream, sender)); + let (thread_sender, receiver) = channel(); + let thread_closed = Arc::clone(&closed); + thread::spawn(move || Self::receiving_loop(thread_stream, thread_sender, thread_closed)); + + // Spawn the sending thread + let thread_stream = stream.try_clone()?; + let (sender, thread_receiver) = channel(); + let thread_closed = Arc::clone(&closed); + thread::spawn(move || Self::sending_loop(thread_stream, thread_receiver, thread_closed)); + + // Return the connection Ok(Self { - stream: Mutex::new(stream), + closed, + stream, receiver: Mutex::new(receiver), + sender: Mutex::new(sender), }) } - /// The [Packet] receiving loop. - fn receive_loop(mut stream: TcpStream, sender: Sender) { + /// Creates a new connection to the given address. + pub fn connect(addr: A) -> io::Result { + Self::new(TcpStream::connect(addr)?) + } + + /// The receiving loop for this connection. + fn receiving_loop(mut stream: TcpStream, sender: Sender>, closed: Arc) { let mut len_buffer = [0; 4]; - let mut id_buffer = [0; 8]; loop { - // Read the length of the packet + // Read the length of the next packet if stream.read_exact(&mut len_buffer).is_err() { - return; + break; } - let packet_len = u32::from_le_bytes(len_buffer); + let len = u32::from_be_bytes(len_buffer); - // Read the packet identifier - if stream.read_exact(&mut id_buffer).is_err() { - return; - } - let packet_id = u64::from_le_bytes(id_buffer); - - // Read the packet data - let mut data = vec![0; packet_len as usize]; - if stream.read_exact(&mut data).is_err() { - return; + // Read the packet + let mut packet = vec![0; len as usize]; + if stream.read_exact(&mut packet).is_err() { + break; } - // Store the packet - if sender.send(RawPacket { packet_id, data }).is_err() { - return; + // Send the packet + if sender.send(packet).is_err() { + break; } } + closed.store(true, Ordering::Relaxed); } - /// Sends a [Packet] over this connection. - pub fn send(&self, packet: P) -> io::Result<()> { - let data = bincode::serialize(&packet).map_err(|e| io::Error::new(ErrorKind::Other, e))?; - let mut stream = self - .stream + /// The sending loop for this connection. + fn sending_loop(mut stream: TcpStream, receiver: Receiver>, closed: Arc) { + loop { + // Get the next packet to send + let packet = match receiver.recv() { + Ok(packet) => packet, + Err(_) => break, + }; + + // Send the length of the packet + let len_buffer = u32::to_be_bytes(packet.len() as u32); + if stream.write_all(&len_buffer).is_err() { + break; + } + + // Send the packet + if stream.write_all(&packet).is_err() { + break; + } + } + closed.store(true, Ordering::Relaxed); + } + + /// Returns `true` if the connection has been closed. + pub fn closed(&self) -> bool { + self.closed.load(Ordering::Relaxed) + } + + /// Returns the next received packet. + pub fn recv(&self) -> Option> { + self.receiver .lock() - .map_err(|_| io::Error::new(ErrorKind::Other, "Failed to lock stream"))?; - stream.write_all(&data.len().to_le_bytes())?; - stream.write_all(&P::packet_id().to_le_bytes())?; - stream.write_all(&data) + .ok() + .and_then(|receiver| receiver.try_recv().ok()) } - /// Gets the [RawPacket] receiver of this connection. - pub fn recv(&self) -> Option>> { - self.receiver.lock().ok() + /// Sends a packet through this connection. + pub fn send(&self, packet: Vec) { + self.sender.lock().map(|sender| sender.send(packet)).ok(); } } impl Drop for Connection { fn drop(&mut self) { - self.stream - .lock() - .map(|stream| stream.shutdown(Shutdown::Both)) - .ok(); + self.stream.shutdown(Shutdown::Both).ok(); } } @@ -145,9 +143,10 @@ impl Listener { } /// Accepts a new [Connection]. - pub fn accept(&self) -> io::Result { + pub fn accept(&self) -> Option { self.listener .accept() .and_then(|(stream, _)| Connection::new(stream)) + .ok() } }