diff --git a/gg-broadcast/src/main.rs b/gg-broadcast/src/main.rs index f0085d7..b12e282 100644 --- a/gg-broadcast/src/main.rs +++ b/gg-broadcast/src/main.rs @@ -1,7 +1,6 @@ use std::{ collections::{HashMap, HashSet}, - io::BufRead, - sync::{mpsc::channel, Arc, Mutex}, + sync::{Arc, Mutex}, thread, time::Duration, }; @@ -10,28 +9,18 @@ use nebkor_maelstrom::{protocol::Payload, Body, Message, Node, Runner}; use rand::Rng; fn main() { - let out = std::io::stdout(); - let std_in = Arc::new(std::io::stdin()); - - let (tx, rx) = channel(); - let input = tx.clone(); - - let i = thread::spawn(move || { - let g = std_in.lock(); - for line in g.lines().map_while(Result::ok) { - if let Ok(msg) = serde_json::from_str::(&line) { - input.send(msg).unwrap(); - } - } - }); - let node = BCaster::default(); let node = Arc::new(Mutex::new(node)); - let runner = Runner::new(node, out); + let runner = Runner::new(node); let runner = &runner; - let g = thread::spawn(move || loop { + runner.run(Some(Box::new(on_init))); +} + +fn on_init(runner: &Runner) { + let tx = runner.get_input(); + thread::spawn(move || loop { let millis = rand::thread_rng().gen_range(400..=800); thread::sleep(Duration::from_millis(millis)); let body = Body::from_type("do_gossip"); @@ -41,10 +30,6 @@ fn main() { }; tx.send(msg).unwrap(); }); - - runner.run(rx, None); - let _ = i.join(); - let _ = g.join(); } #[derive(Clone, Default)] diff --git a/gg-echo/src/main.rs b/gg-echo/src/main.rs index 4105455..f781cfe 100644 --- a/gg-echo/src/main.rs +++ b/gg-echo/src/main.rs @@ -1,8 +1,4 @@ -use std::{ - io::BufRead, - sync::{mpsc::channel, Arc, Mutex}, - thread, -}; +use std::sync::{Arc, Mutex}; use nebkor_maelstrom::{Body, Message, Node, Runner}; @@ -19,26 +15,11 @@ impl Node for Echo { } fn main() { - let out = std::io::stdout(); - let std_in = Arc::new(std::io::stdin()); - - let (tx, rx) = channel(); - - let i = thread::spawn(move || { - let g = std_in.lock(); - for line in g.lines().map_while(Result::ok) { - if let Ok(msg) = serde_json::from_str::(&line) { - tx.send(msg).unwrap(); - } - } - }); - let node = Echo; let node = Arc::new(Mutex::new(node)); - let runner = Runner::new(node, out); + let runner = Runner::new(node); - runner.run(rx, None); - i.join().unwrap(); + runner.run(None); } diff --git a/gg-g_counter/src/main.rs b/gg-g_counter/src/main.rs index acd6d2d..ca4bbc0 100644 --- a/gg-g_counter/src/main.rs +++ b/gg-g_counter/src/main.rs @@ -1,17 +1,12 @@ -use std::{ - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::sync::{Arc, Mutex}; -use nebkor_maelstrom::{mk_payload, mk_stdin, Body, Message, Node, Runner}; +use nebkor_maelstrom::{mk_payload, Body, Message, Node, Runner}; fn main() { let node = Counter; let node = Arc::new(Mutex::new(node)); - let runner = Rc::new(Runner::new(node)); - - let (i, _, rx) = mk_stdin(); + let runner = Runner::new(node); let on_init = |rnr: &Runner| { let payload = mk_payload(&[ @@ -26,8 +21,7 @@ fn main() { let on_init = Box::new(on_init); - runner.run(rx, Some(on_init)); - i.join().unwrap(); + runner.run(Some(on_init)); } const KEY: &str = "COUNTER"; @@ -36,52 +30,28 @@ const KEY: &str = "COUNTER"; struct Counter; impl Node for Counter { - fn handle<'slf>(&'slf mut self, runner: &'slf Rc, req: Message) { + fn handle<'slf>(&'slf mut self, runner: &'slf Runner, req: Message) { let typ = req.body.typ.as_str(); let frm = req.src.clone(); let msg_id = req.body.msg_id.to_owned(); match typ { "add" => { - let read_runner = runner.clone(); - let cas_runner = runner.clone(); - let add_req = req.clone(); - - let cas_handler = move |_msg: Message| { - let req = add_req; - cas_runner.reply(&req, Body::from_type("add_ok")); - }; - let delta = req.body.payload.get("delta").unwrap().as_i64().unwrap(); - let read_handler = move |msg: Message| { - let value = msg.body.payload.get("value").unwrap().as_i64().unwrap(); - let payload = mk_payload(&[ - ("from", value.into()), - ("to", (value + delta).into()), - ("key", KEY.into()), - ]); - let body = Body::from_type("cas").with_payload(payload); - read_runner.rpc("seq-kv", body, Box::new(cas_handler)); - }; - // kick it off by calling "read" on seq-kv: - let payload = mk_payload(&[("key", KEY.into())]); - let body = Body::from_type("read").with_payload(payload); - runner.rpc("seq-kv", body, Box::new(read_handler)); } "read" => { - let rn = runner.clone(); - let h = move |msg: Message| { - let src = frm.clone(); - let value = msg.body.payload.get("value").unwrap().as_i64().unwrap(); - let irt = msg_id; - let payload = mk_payload(&[("value", value.into())]); - let body = Body::from_type("read_ok") - .with_in_reply_to(irt) - .with_payload(payload); - rn.send(&src, body); - }; let payload = mk_payload(&[("key", KEY.into())]); let body = Body::from_type("read").with_payload(payload); - runner.rpc("seq-kv", body, Box::new(h)); + let val = runner + .rpc("seq-kv", body) + .recv() + .unwrap() + .body + .payload + .get("value") + .cloned() + .unwrap(); + let body = Body::from_type("read_ok").with_payload(mk_payload(&[("value", val)])); + runner.reply(&req, body); } _ => { eprintln!("unknown type: {req:?}"); diff --git a/gg-uid/src/main.rs b/gg-uid/src/main.rs index fb8979a..fa795a4 100644 --- a/gg-uid/src/main.rs +++ b/gg-uid/src/main.rs @@ -1,33 +1,14 @@ -use std::{ - io::BufRead, - sync::{mpsc::channel, Arc, Mutex}, - thread, -}; +use std::sync::{Arc, Mutex}; use nebkor_maelstrom::{protocol::Payload, Body, Message, Node, Runner}; fn main() { - let out = std::io::stdout(); - let std_in = Arc::new(std::io::stdin()); - - let (tx, rx) = channel(); - - let i = thread::spawn(move || { - let g = std_in.lock(); - for line in g.lines().map_while(Result::ok) { - if let Ok(msg) = serde_json::from_str::(&line) { - tx.send(msg).unwrap(); - } - } - }); - let node = GenUid; let node = Arc::new(Mutex::new(node)); - let runner = Runner::new(node, out); + let runner = Runner::new(node); - runner.run(rx, None); - i.join().unwrap(); + runner.run(None); } #[derive(Clone, Default)] diff --git a/nebkor-maelstrom/src/lib.rs b/nebkor-maelstrom/src/lib.rs index 2b45ebc..06acdc6 100644 --- a/nebkor-maelstrom/src/lib.rs +++ b/nebkor-maelstrom/src/lib.rs @@ -1,13 +1,12 @@ use std::{ collections::HashMap, - io::{BufRead, Stdout, Write}, - rc::Rc, + io::{BufRead, Write}, sync::{ atomic::{AtomicU64, AtomicUsize, Ordering}, mpsc::{channel, Receiver, Sender}, - Arc, LazyLock, Mutex, OnceLock, + Arc, Mutex, OnceLock, }, - thread::{self, JoinHandle}, + thread::{self}, }; pub mod protocol; @@ -18,16 +17,16 @@ use serde_json::Value; pub mod kv; pub type DynNode = Arc>; -pub type Handler = Box; // -> Result>>; pub type OnInit = Box; pub type Result = std::result::Result; +pub type RpcPromise = Receiver; + static MSG_ID: AtomicU64 = AtomicU64::new(0); -static OUTPUT: LazyLock = LazyLock::new(std::io::stdout); pub trait Node { - fn handle(&mut self, runner: &Rc, msg: Message); + fn handle(&mut self, runner: &Runner, msg: Message); } #[derive(Clone)] @@ -36,7 +35,9 @@ pub struct Runner { node_id: OnceLock, nodes: OnceLock>, steps: Arc, - handlers: Arc>>, + promises: Arc>>>, + input: OnceLock>, + output: OnceLock>, } impl Runner { @@ -46,26 +47,36 @@ impl Runner { nodes: OnceLock::new(), node_id: OnceLock::new(), steps: Arc::new(AtomicUsize::new(0)), - handlers: Default::default(), + promises: Default::default(), + input: OnceLock::new(), + output: OnceLock::new(), } } - pub fn run(self: &Rc, input: Receiver, on_init: Option) { - for msg in input.iter() { + pub fn run(&self, on_init: Option) { + let (stdin_tx, stdin_rx) = run_stdin(); + let _ = self.input.get_or_init(|| stdin_tx); + + let (stdout_tx, stdout_rx) = channel(); + let _ = self.output.get_or_init(|| stdout_tx); + run_stdout(stdout_rx); + + for msg in stdin_rx { let typ = &msg.body.typ; if let "init" = typ.as_str() { self.init(&msg); let body = Body::from_type("init_ok"); self.reply(&msg, body); - if let Some(ref h) = on_init { - h(self); + if let Some(ref on_init) = on_init { + on_init(self); } } else { let irt = msg.body.in_reply_to; { - let mut g = self.handlers.lock().unwrap(); + let mut g = self.promises.lock().unwrap(); if let Some(h) = g.remove(&irt) { - h(msg.clone()); + h.send(msg.clone()).unwrap(); + continue; } } let mut n = self.node.lock().unwrap(); @@ -74,15 +85,21 @@ impl Runner { } } - pub fn rpc(&self, dest: &str, body: Body, handler: Handler) { + pub fn rpc(&self, dest: &str, body: Body) -> RpcPromise { let mut body = body; let msg_id = self.next_msg_id(); body.msg_id = msg_id; + let (tx, rx) = channel(); { - let mut g = self.handlers.lock().unwrap(); - g.insert(msg_id, handler); + let mut g = self.promises.lock().unwrap(); + g.insert(msg_id, tx); } self.send(dest, body); + rx + } + + pub fn get_input(&self) -> Sender { + self.input.get().cloned().unwrap() } pub fn node_id(&self) -> String { @@ -124,10 +141,10 @@ impl Runner { .unwrap() .iter() .map(|s| s.as_str().unwrap().to_string()) - .collect::>(); + .collect(); - let _ = self.node_id.get_or_init(|| node_id.to_owned()); - let _ = self.nodes.get_or_init(|| nodes.to_vec()); + let _ = self.node_id.get_or_init(|| node_id); + let _ = self.nodes.get_or_init(|| nodes); } pub fn reply(&self, req: &Message, body: Body) { @@ -148,27 +165,19 @@ impl Runner { dest: dest.to_string(), body, }; - let msg = serde_json::to_string(&msg).unwrap(); - self.writeln(&msg); - } - - fn writeln(&self, msg: &str) { - let mut out = OUTPUT.lock(); - let msg = format!("{msg}\n"); - out.write_all(msg.as_bytes()).unwrap(); - out.flush().unwrap(); + self.output.get().unwrap().send(msg).unwrap(); } } /// Feeds lines from stdin to the MPSC Sender, so that the Receiver can be used /// in the Runner::run() method. Clone the Sender if you want to inject messages /// into the Runner. Join the handle after `run()`. -pub fn mk_stdin() -> (JoinHandle<()>, Sender, Receiver) { +pub fn run_stdin() -> (Sender, Receiver) { let stdin = std::io::stdin(); let (tx, rx) = channel(); let xtra_input = tx.clone(); - let join = thread::spawn(move || { + thread::spawn(move || { let g = stdin.lock(); for line in g.lines().map_while(std::result::Result::ok) { if let Ok(msg) = serde_json::from_str(&line) { @@ -177,7 +186,16 @@ pub fn mk_stdin() -> (JoinHandle<()>, Sender, Receiver) { } }); - (join, xtra_input, rx) + (xtra_input, rx) +} + +fn run_stdout(rx: Receiver) { + thread::spawn(move || { + let mut stdout = std::io::stdout().lock(); + while let Ok(msg) = rx.recv() { + writeln!(stdout, "{}", serde_json::to_string(&msg).unwrap()).unwrap(); + } + }); } pub fn mk_payload(payload: &[(&str, Value)]) -> Payload {