Refactor RaftConnection struct to use RaftConnectionConfig for initialization

This commit is contained in:
CoCo_Sol 2024-04-08 10:40:51 +02:00
parent d8a81646a6
commit ad39446a49
2 changed files with 70 additions and 33 deletions

View file

@ -2,9 +2,7 @@
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::{BTreeSet, LinkedList}; use std::collections::{BTreeSet, LinkedList};
use std::io;
use prost::bytes::BufMut;
use prost::Message; use prost::Message;
use rand_chacha::ChaChaRng; use rand_chacha::ChaChaRng;
use rand_core::SeedableRng; use rand_core::SeedableRng;
@ -16,33 +14,62 @@ use uuid::Uuid;
/// A Raft node. /// A Raft node.
pub struct RaftConnection { pub struct RaftConnection {
/// The connection to the relay server.
connection: Connection, connection: Connection,
/// The Raft node.
node: RaftNode<RaftLogMemory, ChaChaRng, Uuid>, node: RaftNode<RaftLogMemory, ChaChaRng, Uuid>,
/// The peers of the Raft cluster.
peers: BTreeSet<Uuid>, peers: BTreeSet<Uuid>,
} }
pub struct RaftConnectionConfig {
/// The minimum number of timer ticks between leadership elections.
pub election_timeout_ticks: u32,
/// The number of timer ticks between sending heartbeats to peers.
pub heartbeat_interval_ticks: u32,
/// The maximum number of bytes to replicate to a peer at a time.
pub replication_chunk_size: usize,
}
impl From<RaftConnectionConfig> for RaftConfig {
fn from(val: RaftConnectionConfig) -> Self {
RaftConfig {
election_timeout_ticks: val.election_timeout_ticks,
heartbeat_interval_ticks: val.heartbeat_interval_ticks,
replication_chunk_size: val.replication_chunk_size,
}
}
}
impl RaftConnection { impl RaftConnection {
/// Creates a new Raft connection from a current connection. /// Creates a new Raft connection from a current connection.
/// Returns an error if the connection does not have an identifier. /// Returns an error if the connection does not have an identifier.
pub fn from(connection: Connection, peers: BTreeSet<Uuid>) -> Result<Self, Connection> { pub fn from(
connection: Connection,
peers: BTreeSet<Uuid>,
raft_config: RaftConnectionConfig,
) -> Result<Self, Connection> {
let Some(identifier) = connection.identifier() else { let Some(identifier) = connection.identifier() else {
return Err(connection); return Err(connection);
}; };
Ok(Self {
let raft_node = Self {
connection, connection,
node: RaftNode::new( node: RaftNode::new(
identifier, identifier,
peers.clone(), peers.clone(),
RaftLogMemory::new_unbounded(), RaftLogMemory::new_unbounded(),
ChaChaRng::seed_from_u64(identifier.as_u64_pair().0), ChaChaRng::seed_from_u64(identifier.as_u64_pair().0),
RaftConfig { raft_config.into(),
election_timeout_ticks: 10,
heartbeat_interval_ticks: 1,
replication_chunk_size: usize::max_value(),
},
), ),
peers, peers,
}) };
Ok(raft_node)
} }
/// Envoit un message à tous les noeuds du cluster. /// Envoit un message à tous les noeuds du cluster.
@ -50,8 +77,9 @@ impl RaftConnection {
let mut data = message.into().into_owned(); let mut data = message.into().into_owned();
if self.node.is_leader() { if self.node.is_leader() {
let Ok(messages) = self.node.append(data) else { let Ok(messages) = self.node.append(data) else {
panic!("OOOOOOOOH!"); panic!("Message just cancelled.");
}; };
Self::send_raft_messages(&self.connection, &self.peers, messages);
} else { } else {
data.push(1); data.push(1);
if let (Some(leader), _) = self.node.leader() { if let (Some(leader), _) = self.node.leader() {
@ -70,24 +98,28 @@ impl RaftConnection {
sendable.message.encode(&mut data).ok(); sendable.message.encode(&mut data).ok();
data.push(0); data.push(0);
match sendable.dest { // Send the message to the target node.
RaftMessageDestination::Broadcast => { if let RaftMessageDestination::To(target) = sendable.dest {
for peer in peers connection.send(target, data);
return;
}
// Broadcast the message to all peers.
peers
.iter() .iter()
.filter(|&peer| Some(*peer) != connection.identifier()) .filter(|&peer| Some(*peer) != connection.identifier())
{ .for_each(|peer| {
connection.send(*peer, &data); connection.send(*peer, &data);
} });
}
RaftMessageDestination::To(target) => connection.send(target, data),
}
} }
} }
pub fn update(&mut self) -> LinkedList<Vec<u8>> { pub fn update(&mut self) -> LinkedList<Vec<u8>> {
// Update the Raft node.
let messages = self.node.timer_tick(); let messages = self.node.timer_tick();
Self::send_raft_messages(&self.connection, &self.peers, messages); Self::send_raft_messages(&self.connection, &self.peers, messages);
// Update the connection.
let messages = self.connection.update(); let messages = self.connection.update();
for (sender_id, mut message) in messages { for (sender_id, mut message) in messages {
let message_type = message[message.len() - 1]; let message_type = message[message.len() - 1];
@ -100,18 +132,17 @@ impl RaftConnection {
} }
1 if self.node.is_leader() => { 1 if self.node.is_leader() => {
let Ok(messages) = self.node.append(message) else { let Ok(messages) = self.node.append(message) else {
panic!("OOOOOOOOH!"); panic!("Message just cancelled.");
}; };
Self::send_raft_messages(&self.connection, &self.peers, messages); Self::send_raft_messages(&self.connection, &self.peers, messages);
} }
_ => panic!("AAAAAH!"), _ => (),
} }
} }
let mut result = LinkedList::new(); self.node
for message in self.node.take_committed() { .take_committed()
result.push_back(message.data.to_vec()); .map(|v| v.data.to_vec())
} .collect()
result
} }
} }

View file

@ -4,7 +4,7 @@ use std::time::Duration;
use std::{io, thread}; use std::{io, thread};
use relay_client::Connection; use relay_client::Connection;
use relay_raft::RaftConnection; use relay_raft::{RaftConnection, RaftConnectionConfig};
use uuid::Uuid; use uuid::Uuid;
fn main() { fn main() {
@ -24,7 +24,15 @@ fn main() {
.map(|s| Uuid::parse_str(s).unwrap()) .map(|s| Uuid::parse_str(s).unwrap())
.collect(); .collect();
let Ok(mut connection) = RaftConnection::from(connection, peers) else { let Ok(mut connection) = RaftConnection::from(
connection,
peers,
RaftConnectionConfig {
election_timeout_ticks: 10,
heartbeat_interval_ticks: 1,
replication_chunk_size: usize::max_value(),
},
) else {
panic!("Failed to create raft connection"); panic!("Failed to create raft connection");
}; };
@ -33,9 +41,7 @@ fn main() {
loop { loop {
let mut message = String::new(); let mut message = String::new();
io::stdin().read_line(&mut message).unwrap(); io::stdin().read_line(&mut message).unwrap();
sender sender.send(message.replace(['\n', '\r'], "")).unwrap();
.send(message.replace('\n', "").replace('\r', ""))
.unwrap();
} }
}); });