// File listing metadata: 229 lines, 7 KiB, Rust.
use std::{
|
|
collections::HashMap,
|
|
io::{BufRead, Write},
|
|
sync::{
|
|
atomic::{AtomicU64, Ordering},
|
|
mpsc::{channel, Receiver, Sender},
|
|
Arc, Mutex, OnceLock,
|
|
},
|
|
thread::{self},
|
|
};
|
|
|
|
pub use serde_json::Value;
|
|
|
|
pub mod protocol;
|
|
pub use protocol::{Body, ErrorCode, Message, Payload};
|
|
|
|
pub mod kv;
|
|
|
|
pub type NodeyNodeFace = Arc<Mutex<dyn Node>>;
|
|
pub type OnInit = Box<dyn Fn(&Runner)>;
|
|
pub type RpcPromise = Receiver<Message>;
|
|
pub type RpcResult = std::result::Result<Option<Value>, ErrorCode>;
|
|
|
|
/// Implemented by user-defined Maelstrom nodes and driven by a [`Runner`].
pub trait Node {
    /// Process one incoming Message.
    ///
    /// [`Runner`] calls this for every Message except the initial `init` and stdin replies that
    /// are claimed by a pending [`Runner::rpc`] promise; Messages injected through
    /// [`Runner::get_backdoor`] arrive here too. It runs on the thread that called
    /// [`Runner::run`], while the runner holds the node's mutex. `runner` can be used to
    /// `reply`, `send`, or make `rpc` calls.
    fn handle(&mut self, runner: &Runner, msg: Message);
}
|
|
|
|
/// Drives a [`Node`]: reads Maelstrom Messages from stdin, performs the `init` handshake, routes
/// replies to outstanding RPCs, and writes outgoing Messages to stdout.
pub struct Runner {
    // the user's node; locked by `run()` for as long as it dispatches input to `handle()`
    node: NodeyNodeFace,
    // this node's id, filled in from the `init` message
    node_id: OnceLock<String>,
    // ids of every node in the cluster (including this one), filled in from `init`
    nodes: OnceLock<Vec<String>>,
    // sender feeding the node's input queue; set when `run()` starts, cloned by `get_backdoor()`
    backdoor: OnceLock<Sender<Message>>,
    // pending `rpc()` calls keyed by request msg_id; shared with the stdin reader thread
    promises: Arc<Mutex<HashMap<u64, Sender<Message>>>>,
    // sender to the stdout writer thread; set when `run()` starts
    outbound_tx: OnceLock<Sender<Message>>,
    // next msg_id to hand out to outgoing Messages (`new()` starts it at 1)
    msg_id: AtomicU64,
}
|
|
|
|
impl Runner {
|
|
pub fn new<N: Node + 'static>(node: N) -> Self {
|
|
let node = Arc::new(Mutex::new(node));
|
|
Runner {
|
|
node,
|
|
nodes: OnceLock::new(),
|
|
node_id: OnceLock::new(),
|
|
backdoor: OnceLock::new(),
|
|
outbound_tx: OnceLock::new(),
|
|
promises: Default::default(),
|
|
msg_id: AtomicU64::new(1),
|
|
}
|
|
}
|
|
|
|
/// Start processing messages from stdin and sending them to your node. The `on_init` argument
|
|
/// is an optional callback that will be called with `&self` after the `init` message from
|
|
/// Maelstrom has been processed.
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```no_run
|
|
/// use nebkor_maelstrom::{Body, Message, Node, Runner};
|
|
/// struct Foo;
|
|
/// impl Node for Foo {fn handle(&mut self, _runner: &Runner, _msg: Message) { /* empty impl */ }}
|
|
///
|
|
/// let runner = Runner::new(Foo);
|
|
///
|
|
/// let on_init = |rnr: &Runner| {
|
|
/// eprintln!("received the `init` message!");
|
|
/// let msg = Message { body: Body::from_type("yo_yo_yo"), ..Default::default() };
|
|
/// // send `msg` to the node to be processed by its `handle()` method
|
|
/// rnr.get_backdoor().send(msg).unwrap();
|
|
/// };
|
|
/// let on_init = Box::new(on_init);
|
|
///
|
|
/// runner.run(Some(on_init));
|
|
/// ```
|
|
pub fn run(&self, on_init: Option<OnInit>) {
|
|
let (outbound_tx, outbound_rx) = channel();
|
|
let _ = self.outbound_tx.get_or_init(|| outbound_tx);
|
|
|
|
// decouple processing output from handling messages
|
|
thread::spawn(move || {
|
|
let mut stdout = std::io::stdout().lock();
|
|
while let Ok(msg) = outbound_rx.recv() {
|
|
let msg = serde_json::to_string(&msg).unwrap();
|
|
writeln!(&mut stdout, "{msg}").unwrap();
|
|
}
|
|
});
|
|
|
|
self.process_input(on_init);
|
|
}
|
|
|
|
/// Get a Sender that will send Messages to the node as input. Useful for triggering periodic
|
|
/// behavior from a separate thread, or for sending a Message to the node from `on_init`. See
|
|
/// the `broadcast` example for a use of it.
|
|
pub fn get_backdoor(&self) -> Sender<Message> {
|
|
self.backdoor.get().unwrap().clone()
|
|
}
|
|
|
|
pub fn node_id(&self) -> &str {
|
|
self.node_id.get().unwrap()
|
|
}
|
|
|
|
pub fn next_msg_id(&self) -> u64 {
|
|
self.msg_id.fetch_add(1, Ordering::SeqCst)
|
|
}
|
|
|
|
/// A list of all nodes in the network, including this one.
|
|
pub fn nodes(&self) -> &[String] {
|
|
self.nodes.get().unwrap()
|
|
}
|
|
|
|
/// Construct a new `Message` from `body` and send it to `req.src`.
|
|
pub fn reply(&self, req: &Message, body: Body) {
|
|
let mut body = body;
|
|
let dest = req.src.as_str();
|
|
let in_reply_to = req.body.msg_id;
|
|
body.in_reply_to = in_reply_to;
|
|
self.send(dest, body);
|
|
}
|
|
|
|
/// Construct a new `Message` from `body` and send it to `dest`.
|
|
pub fn send(&self, dest: &str, body: Body) {
|
|
let msg = self.mk_msg(dest, body);
|
|
self.outbound_tx.get().unwrap().send(msg).unwrap();
|
|
}
|
|
|
|
/// Returns a Receiver<Message> that will receive the reply from the request.
|
|
pub fn rpc(&self, dest: &str, body: Body) -> RpcPromise {
|
|
let msg = self.mk_msg(dest, body);
|
|
let (tx, rx) = channel();
|
|
{
|
|
let msg_id = msg.body.msg_id;
|
|
let mut g = self.promises.lock().unwrap();
|
|
g.insert(msg_id, tx);
|
|
}
|
|
self.outbound_tx.get().unwrap().send(msg).unwrap();
|
|
rx
|
|
}
|
|
|
|
// internal methods
|
|
fn init(&self, msg: &Message) {
|
|
let node_id = msg
|
|
.body
|
|
.payload
|
|
.get("node_id")
|
|
.unwrap()
|
|
.as_str()
|
|
.unwrap()
|
|
.to_owned();
|
|
|
|
let nodes = msg
|
|
.body
|
|
.payload
|
|
.get("node_ids")
|
|
.unwrap()
|
|
.as_array()
|
|
.unwrap()
|
|
.iter()
|
|
.map(|s| s.as_str().unwrap().to_string())
|
|
.collect();
|
|
|
|
let _ = self.node_id.get_or_init(|| node_id);
|
|
let _ = self.nodes.get_or_init(|| nodes);
|
|
}
|
|
|
|
fn process_input(&self, on_init: Option<OnInit>) {
|
|
// for sending Messages to the node's inputs
|
|
let (encoded_input_tx, encoded_input_rx) = channel();
|
|
let _ = self.backdoor.get_or_init(|| encoded_input_tx.clone());
|
|
|
|
// decouple stdin from processing
|
|
let proms = self.promises.clone();
|
|
thread::spawn(move || {
|
|
let stdin = std::io::stdin().lock();
|
|
for line in stdin.lines().map_while(std::result::Result::ok) {
|
|
let msg: Message = serde_json::from_str(&line).unwrap();
|
|
let irt = msg.body.in_reply_to;
|
|
if let Some(promise) = proms.lock().unwrap().remove(&irt) {
|
|
// this is the result of an RPC call
|
|
promise.send(msg).unwrap();
|
|
} else {
|
|
// just let the node's `handle()` method handle it
|
|
encoded_input_tx.send(msg).unwrap();
|
|
}
|
|
}
|
|
});
|
|
|
|
// first Message is always `init`:
|
|
let msg = encoded_input_rx.recv().unwrap();
|
|
{
|
|
self.init(&msg);
|
|
let body = Body::from_type("init_ok");
|
|
self.reply(&msg, body);
|
|
if let Some(on_init) = on_init {
|
|
on_init(self);
|
|
}
|
|
}
|
|
|
|
// every other message is for the node's handle() method
|
|
let mut node = self.node.lock().unwrap();
|
|
for msg in encoded_input_rx {
|
|
node.handle(self, msg);
|
|
}
|
|
}
|
|
|
|
fn mk_msg(&self, dest: &str, body: Body) -> Message {
|
|
let mut body = body;
|
|
if body.msg_id == 0 {
|
|
body.msg_id = self.next_msg_id();
|
|
}
|
|
Message::from_dest(dest)
|
|
.with_body(body)
|
|
.with_src(self.node_id())
|
|
}
|
|
}
|
|
|
|
/// Inspect an RPC reply for a Maelstrom `error` response.
///
/// Returns `Err(code)` when `msg` has type `"error"`, and `Ok(None)` otherwise. The success value
/// carries no payload; callers read any reply data from `msg` themselves.
///
/// # Panics
///
/// Panics if `msg` is an `error` message without a `code`.
pub fn check_err(msg: &Message) -> RpcResult {
    if msg.body.typ.as_str() == "error" {
        let code = msg
            .body
            .code
            .expect("Maelstrom `error` message is missing its `code`");
        return Err(code);
    }
    Ok(None)
}
|
|
|
|
/// Build a [`Payload`] from `(key, value)` pairs, copying each key into an owned `String` and
/// cloning each value.
pub fn mk_payload(payload: &[(&str, Value)]) -> Payload {
    let owned_pairs = payload
        .iter()
        .map(|(key, value)| (String::from(*key), value.clone()));
    owned_pairs.collect()
}
|