New network system

This commit is contained in:
Tipragot 2023-04-28 11:33:01 +02:00
parent 9f233c6482
commit 942c726b95
2 changed files with 346 additions and 0 deletions

View file

@ -0,0 +1,193 @@
use bevy::{prelude::*, utils::HashMap};
use std::{
io::{self, ErrorKind},
net::{TcpStream, ToSocketAddrs},
sync::{mpsc::TryRecvError, Arc},
};
use tcp::{Connection, Listener, Packet, RawPacket};
mod tcp;
/// A Bevy resource that store the packets handlers for the client.
#[derive(Resource)]
pub struct ClientPacketHandler(Arc<HashMap<u64, Box<dyn Fn(RawPacket, &mut World) + Send + Sync>>>);
/// A connection to a remote server.
#[derive(Resource)]
pub struct ServerConnection(Connection);
impl ServerConnection {
/// Connects to a remote server.
pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
Ok(Self(Connection::new(TcpStream::connect(addr)?)?))
}
/// Send a [Packet] to the remote server.
pub fn send<P: Packet>(&self, packet: P) {
self.0.send(packet).ok();
}
}
/// A Bevy resource that store the packets handlers for the server.
#[derive(Resource)]
struct ServerPacketHandler(
Arc<HashMap<u64, Box<dyn Fn(Entity, ClientConnection, RawPacket, &mut World) + Send + Sync>>>,
);
/// A [ClientConnection] listener.
#[derive(Resource)]
pub struct ClientListener(Listener);
impl ClientListener {
/// Creates a new TCP listener on the given address.
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
Ok(Self(Listener::bind(addr)?))
}
}
/// A connection to a remote client.
#[derive(Component)]
pub struct ClientConnection(Arc<Connection>);
impl ClientConnection {
/// Sends a [Packet] to the remote client.
pub fn send<P: Packet>(&self, packet: P) {
self.0.send(packet).ok();
}
}
/// A Bevy system to handle incoming [ClientConnection]s and remove the
/// [ClientListener] resource if it is no longer listening.
fn accept_connections(mut commands: Commands, listener: Option<Res<ClientListener>>) {
if let Some(listener) = listener {
match listener.0.accept() {
Ok(connection) => {
commands.spawn(ClientConnection(Arc::new(connection)));
}
Err(error) => match error.kind() {
ErrorKind::WouldBlock => {}
_ => commands.remove_resource::<ClientListener>(),
},
}
}
}
/// A Bevy system to handle incoming [Packet]s from
/// the [ClientConnection]s and the [ServerConnection].
/// It removes them if they are no longer connected.
fn receive_packets(world: &mut World) {
// Handle client packets
let mut packets = Vec::new();
let mut to_remove = false;
if let Some(connection) = world.get_resource::<ServerConnection>() {
if let Some(receiver) = connection.0.recv() {
loop {
match receiver.try_recv() {
Ok(packet) => packets.push(packet),
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Disconnected) => {
to_remove = true;
break;
}
}
}
}
}
if let Some(handlers) = world
.get_resource_mut::<ClientPacketHandler>()
.map(|handlers| Arc::clone(&handlers.0))
{
for packet in packets.into_iter() {
if let Some(handler) = handlers.get(&packet.packet_id()) {
(handler)(packet, world);
}
}
}
if to_remove {
world.remove_resource::<ServerConnection>();
}
// Handle server packets
let mut packets = Vec::new();
let mut to_remove = Vec::new();
for (entity, connection) in world.query::<(Entity, &ClientConnection)>().iter(world) {
if let Some(receiver) = connection.0.recv() {
loop {
match receiver.try_recv() {
Ok(packet) => {
packets.push((entity, ClientConnection(Arc::clone(&connection.0)), packet));
}
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Disconnected) => {
to_remove.push(entity);
break;
}
}
}
}
}
if let Some(handlers) = world
.get_resource_mut::<ServerPacketHandler>()
.map(|handlers| Arc::clone(&handlers.0))
{
for (entity, connection, packet) in packets.into_iter() {
if let Some(handler) = handlers.get(&packet.packet_id()) {
(handler)(entity, connection, packet, world);
}
}
}
for entity in to_remove {
world.despawn(entity);
}
}
/// A network plugin.
pub struct NetworkPlugin;
impl Plugin for NetworkPlugin {
fn build(&self, app: &mut App) {
app.insert_resource(ServerPacketHandler(Arc::new(HashMap::new())));
app.insert_resource(ClientPacketHandler(Arc::new(HashMap::new())));
app.add_system(accept_connections);
app.add_system(receive_packets);
}
}
/// A extension trait to add packet handler.
pub trait NetworkAppExt {
/// Add a client packet handler.
fn add_client_packet_handler<P, H>(&mut self, handler: H) -> &mut Self
where
P: Packet,
H: Fn(RawPacket, &mut World) + Send + Sync + 'static;
/// Add a server packet handler.
fn add_server_packet_handler<P, H>(&mut self, handler: H) -> &mut Self
where
P: Packet,
H: Fn(Entity, ClientConnection, RawPacket, &mut World) + Send + Sync + 'static;
}
impl NetworkAppExt for App {
fn add_client_packet_handler<P, H>(&mut self, handler: H) -> &mut Self
where
P: Packet,
H: Fn(RawPacket, &mut World) + Send + Sync + 'static,
{
Arc::get_mut(&mut self.world.resource_mut::<ClientPacketHandler>().0)
.unwrap()
.insert(P::packet_id(), Box::new(handler));
self
}
fn add_server_packet_handler<P, H>(&mut self, handler: H) -> &mut Self
where
P: Packet,
H: Fn(Entity, ClientConnection, RawPacket, &mut World) + Send + Sync + 'static,
{
Arc::get_mut(&mut self.world.resource_mut::<ServerPacketHandler>().0)
.unwrap()
.insert(P::packet_id(), Box::new(handler));
self
}
}

153
src/tcp.rs Normal file
View file

@ -0,0 +1,153 @@
use serde::{de::DeserializeOwned, Serialize};
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
io::{self, ErrorKind, Read, Write},
net::{Shutdown, TcpListener, TcpStream, ToSocketAddrs},
sync::{
mpsc::{channel, Receiver, Sender},
Mutex, MutexGuard,
},
thread,
};
/// A packet that can be sent over a [Connection].
pub trait Packet: DeserializeOwned + Serialize + Send + Sync {
/// Returns a unique identifier for this packet.
///
/// This function uses [std::any::type_name] to get a string
/// representation of the type of the object and returns the
/// hash of that string. This is not perfect... but I didn't
/// find a better solution.
fn packet_id() -> u64 {
let mut hasher = DefaultHasher::new();
std::any::type_name::<Self>().hash(&mut hasher);
hasher.finish()
}
}
impl<T: DeserializeOwned + Serialize + Send + Sync> Packet for T {}
/// A raw packet.
pub struct RawPacket {
/// The identifier for this packet.
packet_id: u64,
/// The serialized packet.
data: Vec<u8>,
}
impl RawPacket {
/// Returns the identifier for this packet.
pub fn packet_id(&self) -> u64 {
self.packet_id
}
/// Deserializes this packet to the given [Packet] type.
pub fn deserialize<P: Packet>(&self) -> Option<P> {
bincode::deserialize(&self.data).ok()
}
}
/// A TCP connection that can send and receive [Packet]s.
pub struct Connection {
/// The [TcpStream] of the connection.
///
/// It is used to send [Packet]s and to stop the receive
/// thread when the [Connection] is dropped.
stream: Mutex<TcpStream>,
/// The [Receiver] of the received [RawPacket]s.
receiver: Mutex<Receiver<RawPacket>>,
}
impl Connection {
/// Creates a new TCP connection.
pub fn new(stream: TcpStream) -> io::Result<Self> {
let (sender, receiver) = channel();
let thread_stream = stream.try_clone()?;
thread::spawn(move || Self::receive_loop(thread_stream, sender));
Ok(Self {
stream: Mutex::new(stream),
receiver: Mutex::new(receiver),
})
}
/// The [Packet] receiving loop.
fn receive_loop(mut stream: TcpStream, sender: Sender<RawPacket>) {
let mut len_buffer = [0; 4];
let mut id_buffer = [0; 8];
loop {
// Read the length of the packet
if stream.read_exact(&mut len_buffer).is_err() {
return;
}
let packet_len = u32::from_le_bytes(len_buffer);
// Read the packet identifier
if stream.read_exact(&mut id_buffer).is_err() {
return;
}
let packet_id = u64::from_le_bytes(id_buffer);
// Read the packet data
let mut data = vec![0; packet_len as usize];
if stream.read_exact(&mut data).is_err() {
return;
}
// Store the packet
if sender.send(RawPacket { packet_id, data }).is_err() {
return;
}
}
}
/// Sends a [Packet] over this connection.
pub fn send<P: Packet>(&self, packet: P) -> io::Result<()> {
let data = bincode::serialize(&packet).map_err(|e| io::Error::new(ErrorKind::Other, e))?;
let mut stream = self
.stream
.lock()
.map_err(|_| io::Error::new(ErrorKind::Other, "Failed to lock stream"))?;
stream.write_all(&data.len().to_le_bytes())?;
stream.write_all(&P::packet_id().to_le_bytes())?;
stream.write_all(&data)
}
/// Gets the [RawPacket] receiver of this connection.
pub fn recv(&self) -> Option<MutexGuard<Receiver<RawPacket>>> {
self.receiver.lock().ok()
}
}
impl Drop for Connection {
fn drop(&mut self) {
self.stream
.lock()
.map(|stream| stream.shutdown(Shutdown::Both))
.ok();
}
}
/// A [Connection] listener.
pub struct Listener {
/// The [TcpListener] of the listener.
listener: TcpListener,
}
impl Listener {
/// Creates a new TCP listener on the given address.
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
let listener = TcpListener::bind(addr)?;
listener.set_nonblocking(true)?;
Ok(Self { listener })
}
/// Accepts a new [Connection].
pub fn accept(&self) -> io::Result<Connection> {
self.listener
.accept()
.and_then(|(stream, _)| Connection::new(stream))
}
}