New packet handeling system

This commit is contained in:
Tipragot 2023-04-28 17:51:41 +02:00
parent 942c726b95
commit 69a7f37518
5 changed files with 418 additions and 93 deletions

67
examples/ping_pong.rs Normal file
View file

@ -0,0 +1,67 @@
use bevnet::{
client::{ClientAppExt, ClientPlugin, ServerConnection},
server::{ClientListener, ServerAppExt, ServerPlugin},
};
use bevy::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
struct Ping;
#[derive(Serialize, Deserialize)]
struct Pong;
fn start_server(mut commands: Commands, keys: Res<Input<KeyCode>>) {
if keys.just_pressed(KeyCode::B) {
println!("Starting server...");
match ClientListener::bind("127.0.0.1:8000") {
Ok(listener) => {
commands.insert_resource(listener);
println!("Server started");
}
Err(e) => println!("Failed to start server: {}", e),
}
}
}
fn connect(mut commands: Commands, keys: Res<Input<KeyCode>>) {
if keys.just_pressed(KeyCode::C) {
println!("Connecting to server...");
match ServerConnection::connect("127.0.0.1:8000") {
Ok(connection) => {
commands.insert_resource(connection);
println!("Connected to server");
}
Err(e) => println!("Failed to connect: {}", e),
}
}
}
fn send_ping(connection: Option<Res<ServerConnection>>, keys: Res<Input<KeyCode>>) {
if keys.just_pressed(KeyCode::S) {
println!("Sending ping...");
if let Some(connection) = connection {
connection.send(Ping);
println!("Ping sent");
}
}
}
fn main() {
App::new()
.add_plugins(DefaultPlugins)
.add_plugin(ServerPlugin)
.add_system(start_server)
.add_server_packet_handler::<Ping, _>(|entity, connection, _, _| {
println!("Received ping from {:?}", entity);
connection.send(Pong);
println!("Sent pong to {:?}", entity);
})
.add_plugin(ClientPlugin)
.add_system(connect)
.add_client_packet_handler::<Pong, _>(|_, _| {
println!("Received pong");
})
.add_system(send_ping)
.run();
}

103
src/client.rs Normal file
View file

@ -0,0 +1,103 @@
use std::{collections::HashMap, io, net::ToSocketAddrs, sync::Arc};
use bevy::prelude::*;
use crate::{tcp::Connection, Packet};
/// A function that handle a received [Packet] on the client.
pub type ClientPacketHandler = Box<dyn Fn(Vec<u8>, &mut World) + Send + Sync>;
/// A Bevy resource that store the packets handlers for the client.
#[derive(Resource)]
pub struct ClientHandlerManager(Arc<HashMap<u64, ClientPacketHandler>>);
/// A connection to a remote server.
#[derive(Resource)]
pub struct ServerConnection(Connection);
impl ServerConnection {
/// Connects to a remote server.
pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
Ok(Self(Connection::connect(addr)?))
}
/// Sends a packet through this connection.
pub fn send<P: Packet>(&self, packet: P) {
let mut data = bincode::serialize(&packet).expect("Failed to serialize packet");
data.extend(P::packet_id().to_be_bytes());
self.0.send(data);
}
}
/// A plugin that manage the network connections for a server.
pub struct ClientPlugin;
impl ClientPlugin {
/// Handles a received [Packet] on the server.
pub fn handle_packets(world: &mut World) {
// Get all received packets
let mut packets = Vec::new();
if let Some(connection) = world.get_resource::<ServerConnection>() {
while let Some(mut packet) = connection.0.recv() {
if packet.len() < 8 {
println!("Invalid packet received: {:?}", packet);
} else {
let id_buffer = packet.split_off(packet.len() - 8);
let packet_id = u64::from_be_bytes(id_buffer.try_into().unwrap());
packets.push((packet_id, packet));
}
}
} else {
return;
}
// Get the packet handlers
let handlers = Arc::clone(&world.resource_mut::<ClientHandlerManager>().0);
// Handle all received packets
for (packet_id, packet) in packets {
if let Some(handler) = handlers.get(&packet_id) {
handler(packet, world);
}
}
}
/// Remove [ServerConnection] if it's disconnected.
pub fn remove_disconnected(mut commands: Commands, connection: Option<Res<ServerConnection>>) {
if let Some(connection) = connection {
if connection.0.closed() {
commands.remove_resource::<ServerConnection>();
}
}
}
}
impl Plugin for ClientPlugin {
fn build(&self, app: &mut App) {
app.insert_resource(ClientHandlerManager(Arc::new(HashMap::new())));
app.add_system(ClientPlugin::handle_packets);
app.add_system(ClientPlugin::remove_disconnected);
}
}
/// An extension to add packet handlers.
pub trait ClientAppExt {
/// Add a new packet handler.
fn add_client_packet_handler<P, H>(&mut self, handler: H) -> &mut Self
where
P: Packet,
H: Fn(Vec<u8>, &mut World) + Send + Sync + 'static;
}
impl ClientAppExt for App {
fn add_client_packet_handler<P, H>(&mut self, handler: H) -> &mut Self
where
P: Packet,
H: Fn(Vec<u8>, &mut World) + Send + Sync + 'static,
{
Arc::get_mut(&mut self.world.resource_mut::<ClientHandlerManager>().0)
.unwrap()
.insert(P::packet_id(), Box::new(handler));
self
}
}

View file

@ -1,4 +1,31 @@
use bevy::{prelude::*, utils::HashMap};
use serde::{de::DeserializeOwned, Serialize};
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};
pub mod client;
pub mod server;
mod tcp;
/// A packet that can be sent over a [Connection].
pub trait Packet: DeserializeOwned + Serialize + Send + Sync {
/// Returns a unique identifier for this packet.
///
/// This function uses [std::any::type_name] to get a string
/// representation of the type of the object and returns the
/// hash of that string. This is not perfect... but I didn't
/// find a better solution.
fn packet_id() -> u64 {
let mut hasher = DefaultHasher::new();
std::any::type_name::<Self>().hash(&mut hasher);
hasher.finish()
}
}
impl<T: DeserializeOwned + Serialize + Send + Sync> Packet for T {}
/* use bevy::{prelude::*, utils::HashMap};
use std::{
io::{self, ErrorKind},
net::{TcpStream, ToSocketAddrs},
@ -115,6 +142,7 @@ fn receive_packets(world: &mut World) {
loop {
match receiver.try_recv() {
Ok(packet) => {
println!("YESSSS");
packets.push((entity, ClientConnection(Arc::clone(&connection.0)), packet));
}
Err(TryRecvError::Empty) => break,
@ -191,3 +219,4 @@ impl NetworkAppExt for App {
self
}
}
*/

127
src/server.rs Normal file
View file

@ -0,0 +1,127 @@
use crate::{
tcp::{Connection, Listener},
Packet,
};
use bevy::prelude::*;
use std::{collections::HashMap, io, net::ToSocketAddrs, sync::Arc};
/// A function that handle a received [Packet] on the server.
pub type ServerPacketHandler =
Box<dyn Fn(Entity, ClientConnection, Vec<u8>, &mut World) + Send + Sync>;
/// A Bevy resource that store the packets handlers for the server.
#[derive(Resource)]
pub struct ServerHandlerManager(Arc<HashMap<u64, ServerPacketHandler>>);
/// A Bevy resource that listens for incoming [ClientConnection]s.
#[derive(Resource)]
pub struct ClientListener(Listener);
impl ClientListener {
/// Creates a new listener on the given address.
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
Ok(Self(Listener::bind(addr)?))
}
}
/// A connection to a remote client.
#[derive(Component)]
pub struct ClientConnection(Arc<Connection>);
impl ClientConnection {
/// Sends a packet through this connection.
pub fn send<P: Packet>(&self, packet: P) {
let mut data = bincode::serialize(&packet).expect("Failed to serialize packet");
data.extend(P::packet_id().to_be_bytes());
self.0.send(data);
}
}
/// A plugin that manage the network connections for a server.
pub struct ServerPlugin;
impl ServerPlugin {
/// Accept new [ClientConnection]s.
pub fn accept_connections(mut commands: Commands, listener: Option<Res<ClientListener>>) {
if let Some(listener) = listener {
if let Some(connection) = listener.0.accept() {
commands.spawn(ClientConnection(Arc::new(connection)));
}
}
}
/// Handles a received [Packet] on the server.
pub fn handle_packets(world: &mut World) {
// Get all received packets
let mut packets = Vec::new();
for (entity, connection) in world.query::<(Entity, &ClientConnection)>().iter(world) {
while let Some(mut packet) = connection.0.recv() {
if packet.len() < 8 {
println!("Invalid packet received: {:?}", packet);
} else {
let id_buffer = packet.split_off(packet.len() - 8);
let packet_id = u64::from_be_bytes(id_buffer.try_into().unwrap());
packets.push((
entity,
ClientConnection(Arc::clone(&connection.0)),
packet_id,
packet,
));
}
}
}
// Get the packet handlers
let handlers = Arc::clone(&world.resource_mut::<ServerHandlerManager>().0);
// Handle all received packets
for (entity, connection, packet_id, packet) in packets {
if let Some(handler) = handlers.get(&packet_id) {
handler(entity, connection, packet, world);
}
}
}
/// Remove disconnected [ClientConnection]s.
pub fn remove_disconnected(
mut commands: Commands,
connections: Query<(Entity, &ClientConnection)>,
) {
for (entity, connection) in connections.iter() {
if connection.0.closed() {
commands.entity(entity).remove::<ClientConnection>();
}
}
}
}
impl Plugin for ServerPlugin {
fn build(&self, app: &mut App) {
app.insert_resource(ServerHandlerManager(Arc::new(HashMap::new())));
app.add_system(ServerPlugin::accept_connections);
app.add_system(ServerPlugin::handle_packets);
app.add_system(ServerPlugin::remove_disconnected);
}
}
/// An extension to add packet handlers.
pub trait ServerAppExt {
/// Add a new packet handler.
fn add_server_packet_handler<P, H>(&mut self, handler: H) -> &mut Self
where
P: Packet,
H: Fn(Entity, ClientConnection, Vec<u8>, &mut World) + Send + Sync + 'static;
}
impl ServerAppExt for App {
fn add_server_packet_handler<P, H>(&mut self, handler: H) -> &mut Self
where
P: Packet,
H: Fn(Entity, ClientConnection, Vec<u8>, &mut World) + Send + Sync + 'static,
{
Arc::get_mut(&mut self.world.resource_mut::<ServerHandlerManager>().0)
.unwrap()
.insert(P::packet_id(), Box::new(handler));
self
}
}

View file

@ -1,132 +1,130 @@
use serde::{de::DeserializeOwned, Serialize};
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
io::{self, ErrorKind, Read, Write},
io::{self, Read, Write},
net::{Shutdown, TcpListener, TcpStream, ToSocketAddrs},
sync::{
atomic::{AtomicBool, Ordering},
mpsc::{channel, Receiver, Sender},
Mutex, MutexGuard,
Arc, Mutex,
},
thread,
};
/// A packet that can be sent over a [Connection].
pub trait Packet: DeserializeOwned + Serialize + Send + Sync {
/// Returns a unique identifier for this packet.
///
/// This function uses [std::any::type_name] to get a string
/// representation of the type of the object and returns the
/// hash of that string. This is not perfect... but I didn't
/// find a better solution.
fn packet_id() -> u64 {
let mut hasher = DefaultHasher::new();
std::any::type_name::<Self>().hash(&mut hasher);
hasher.finish()
}
}
impl<T: DeserializeOwned + Serialize + Send + Sync> Packet for T {}
/// A raw packet.
pub struct RawPacket {
/// The identifier for this packet.
packet_id: u64,
/// The serialized packet.
data: Vec<u8>,
}
impl RawPacket {
/// Returns the identifier for this packet.
pub fn packet_id(&self) -> u64 {
self.packet_id
}
/// Deserializes this packet to the given [Packet] type.
pub fn deserialize<P: Packet>(&self) -> Option<P> {
bincode::deserialize(&self.data).ok()
}
}
/// A TCP connection that can send and receive [Packet]s.
/// A non-blocking TCP connection.
pub struct Connection {
/// The [TcpStream] of the connection.
///
/// It is used to send [Packet]s and to stop the receive
/// thread when the [Connection] is dropped.
stream: Mutex<TcpStream>,
/// Track if the connection has been closed.
closed: Arc<AtomicBool>,
/// The [Receiver] of the received [RawPacket]s.
receiver: Mutex<Receiver<RawPacket>>,
/// The underlying TCP stream.
stream: TcpStream,
/// Used to receive packets from the receiving thread.
receiver: Mutex<Receiver<Vec<u8>>>,
/// Used to send packets to the sending thread.
sender: Mutex<Sender<Vec<u8>>>,
}
impl Connection {
/// Creates a new TCP connection.
pub fn new(stream: TcpStream) -> io::Result<Self> {
let (sender, receiver) = channel();
/// Creates a new connection.
fn new(stream: TcpStream) -> io::Result<Self> {
stream.set_nonblocking(false)?;
let closed = Arc::new(AtomicBool::new(false));
// Spawn the receiving thread
let thread_stream = stream.try_clone()?;
thread::spawn(move || Self::receive_loop(thread_stream, sender));
let (thread_sender, receiver) = channel();
let thread_closed = Arc::clone(&closed);
thread::spawn(move || Self::receiving_loop(thread_stream, thread_sender, thread_closed));
// Spawn the sending thread
let thread_stream = stream.try_clone()?;
let (sender, thread_receiver) = channel();
let thread_closed = Arc::clone(&closed);
thread::spawn(move || Self::sending_loop(thread_stream, thread_receiver, thread_closed));
// Return the connection
Ok(Self {
stream: Mutex::new(stream),
closed,
stream,
receiver: Mutex::new(receiver),
sender: Mutex::new(sender),
})
}
/// The [Packet] receiving loop.
fn receive_loop(mut stream: TcpStream, sender: Sender<RawPacket>) {
/// Creates a new connection to the given address.
pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
Self::new(TcpStream::connect(addr)?)
}
/// The receiving loop for this connection.
fn receiving_loop(mut stream: TcpStream, sender: Sender<Vec<u8>>, closed: Arc<AtomicBool>) {
let mut len_buffer = [0; 4];
let mut id_buffer = [0; 8];
loop {
// Read the length of the packet
// Read the length of the next packet
if stream.read_exact(&mut len_buffer).is_err() {
return;
break;
}
let packet_len = u32::from_le_bytes(len_buffer);
let len = u32::from_be_bytes(len_buffer);
// Read the packet identifier
if stream.read_exact(&mut id_buffer).is_err() {
return;
}
let packet_id = u64::from_le_bytes(id_buffer);
// Read the packet data
let mut data = vec![0; packet_len as usize];
if stream.read_exact(&mut data).is_err() {
return;
// Read the packet
let mut packet = vec![0; len as usize];
if stream.read_exact(&mut packet).is_err() {
break;
}
// Store the packet
if sender.send(RawPacket { packet_id, data }).is_err() {
return;
// Send the packet
if sender.send(packet).is_err() {
break;
}
}
closed.store(true, Ordering::Relaxed);
}
/// Sends a [Packet] over this connection.
pub fn send<P: Packet>(&self, packet: P) -> io::Result<()> {
let data = bincode::serialize(&packet).map_err(|e| io::Error::new(ErrorKind::Other, e))?;
let mut stream = self
.stream
/// The sending loop for this connection.
fn sending_loop(mut stream: TcpStream, receiver: Receiver<Vec<u8>>, closed: Arc<AtomicBool>) {
loop {
// Get the next packet to send
let packet = match receiver.recv() {
Ok(packet) => packet,
Err(_) => break,
};
// Send the length of the packet
let len_buffer = u32::to_be_bytes(packet.len() as u32);
if stream.write_all(&len_buffer).is_err() {
break;
}
// Send the packet
if stream.write_all(&packet).is_err() {
break;
}
}
closed.store(true, Ordering::Relaxed);
}
/// Returns `true` if the connection has been closed.
pub fn closed(&self) -> bool {
self.closed.load(Ordering::Relaxed)
}
/// Returns the next received packet.
pub fn recv(&self) -> Option<Vec<u8>> {
self.receiver
.lock()
.map_err(|_| io::Error::new(ErrorKind::Other, "Failed to lock stream"))?;
stream.write_all(&data.len().to_le_bytes())?;
stream.write_all(&P::packet_id().to_le_bytes())?;
stream.write_all(&data)
.ok()
.and_then(|receiver| receiver.try_recv().ok())
}
/// Gets the [RawPacket] receiver of this connection.
pub fn recv(&self) -> Option<MutexGuard<Receiver<RawPacket>>> {
self.receiver.lock().ok()
/// Sends a packet through this connection.
pub fn send(&self, packet: Vec<u8>) {
self.sender.lock().map(|sender| sender.send(packet)).ok();
}
}
impl Drop for Connection {
fn drop(&mut self) {
self.stream
.lock()
.map(|stream| stream.shutdown(Shutdown::Both))
.ok();
self.stream.shutdown(Shutdown::Both).ok();
}
}
@ -145,9 +143,10 @@ impl Listener {
}
/// Accepts a new [Connection].
pub fn accept(&self) -> io::Result<Connection> {
pub fn accept(&self) -> Option<Connection> {
self.listener
.accept()
.and_then(|(stream, _)| Connection::new(stream))
.ok()
}
}