diff --git a/gg-echo/src/main.rs b/gg-echo/src/main.rs index 18c8ff9..de56082 100644 --- a/gg-echo/src/main.rs +++ b/gg-echo/src/main.rs @@ -1,49 +1,132 @@ use maelstrom_protocol as proto; use proto::{Body, Message}; -use std::io::{BufRead, Write}; +use std::{ + cell::OnceCell, + io::{BufRead, StdinLock, StdoutLock, Write}, + sync::Mutex, +}; fn main() { - let mut out = std::io::stdout().lock(); + let out = std::io::stdout().lock(); let input = std::io::stdin().lock(); - let mut mid = 1; - - let mut node_id: String = "".to_string(); - - for line in input.lines() { - match line { - Ok(line) => { - if let Ok(msg) = serde_json::from_str::(&line) { - let typ = &msg.body.typ; - match typ.as_str() { - "init" => { - let body = proto::init_ok(mid, msg.body.msg_id); - mid += 1; - let resp = message(&msg.src, &msg.dest, body); - node_id = msg.dest; - let resp = serde_json::to_string(&resp).unwrap(); - out.write_all(resp.as_bytes()).unwrap(); - out.write_all("\n".as_bytes()).unwrap(); - let _ = &mut out.flush().unwrap(); - } - "echo" => { - let body = Body::from_type("echo_ok") - .with_msg_id(mid) - .with_in_reply_to(msg.body.msg_id) - .with_payload(msg.body.payload); - mid += 1; - let resp = message(&msg.src, &node_id, body); - let resp = serde_json::to_string(&resp).unwrap(); - out.write_all(resp.as_bytes()).unwrap(); - out.write_all("\n".as_bytes()).unwrap(); - let _ = &mut out.flush().unwrap(); - } - _ => {} - } - } + let runner = &Runner::new(out); + let handler = |msg: &Message| { + let typ = &msg.body.typ; + match typ.as_str() { + "echo" => { + let body = Body::from_type("echo_ok").with_payload(msg.body.payload.clone()); + runner.reply(&msg, body); } _ => {} } + }; + runner.run(input, &handler); +} + +struct Runner<'io> { + msg_id: Mutex, + node_id: OnceCell, + nodes: OnceCell>, + output: Mutex>, +} + +impl<'io> Runner<'io> { + pub fn new(output: StdoutLock<'io>) -> Self { + Runner { + output: Mutex::new(output), + msg_id: Mutex::new(1), + nodes: OnceCell::new(), + node_id: OnceCell::new(), + } + } + + pub fn run(&self, input: StdinLock, handler: &dyn Fn(&Message)) { + for line in input.lines() { + match line { + Ok(line) => { + if let Ok(msg) = serde_json::from_str::(&line) { + let typ = &msg.body.typ; + match typ.as_str() { + "init" => { + self.init(&msg); + + let body = Body::from_type("init_ok"); + self.reply(&msg, body); + } + _ => { + handler(&msg); + } + } + } + } + _ => {} + } + } + } + + pub fn node_id(&self) -> String { + self.node_id.get().cloned().unwrap_or("".to_string()) + } + + pub fn msg_id(&self) -> u64 { + *self.msg_id.lock().unwrap() + } + + pub fn init(&self, msg: &Message) { + let node_id = msg + .body + .payload + .get("node_id") + .unwrap() + .as_str() + .unwrap() + .to_owned(); + let nodes = msg + .body + .payload + .get("node_ids") + .unwrap() + .as_array() + .unwrap() + .iter() + .map(|s| s.as_str().unwrap().to_string()) + .collect::>(); + + let _ = self.node_id.get_or_init(|| node_id.to_owned()); + let _ = self.nodes.get_or_init(|| nodes.to_vec()); + } + + pub fn reply(&self, req: &Message, body: Body) { + let mut body = body; + let src = self.node_id.get().unwrap().to_owned(); + let dest = req.src.clone(); + let in_reply_to = req.body.msg_id; + body.in_reply_to = in_reply_to; + let msg = Message { src, dest, body }; + self.send(msg); + } + + pub fn send(&self, msg: Message) { + let mut msg = msg; + if msg.body.msg_id == 0 { + let mid = { + let mut g = self.msg_id.lock().unwrap(); + let m = *g; + *g += 1; + m + }; + msg.body.msg_id = mid; + } + let msg = serde_json::to_string(&msg).unwrap(); + let msg = format!("{msg}\n"); + self.writeln(&msg); + } + + fn writeln(&self, msg: &str) { + let mut out = self.output.lock().unwrap(); + out.write_all(msg.as_bytes()).unwrap(); + out.flush().unwrap(); } }