diff --git a/gg-g_counter/src/main.rs b/gg-g_counter/src/main.rs index 0505f8b..2d9922b 100644 --- a/gg-g_counter/src/main.rs +++ b/gg-g_counter/src/main.rs @@ -21,14 +21,14 @@ fn main() { } impl Node for Counter { - fn handle<'slf>(&'slf mut self, runner: &'slf Runner, req: Message) { + fn handle(&mut self, runner: &Runner, req: Message) { let typ = req.body.typ.as_str(); let kv = Kv::seq(); match typ { "add" => { let delta = req.body.payload.get("delta").unwrap().as_i64().unwrap(); loop { - let cur = kv.read(runner, KEY).unwrap().as_i64().unwrap(); + let cur = kv.read(runner, KEY).unwrap().unwrap().as_i64().unwrap(); match kv.cas(runner, KEY, cur.into(), (cur + delta).into(), false) { Err(_) => {} Ok(_) => break, @@ -37,7 +37,7 @@ impl Node for Counter { runner.reply(&req, Body::from_type("add_ok")); } "read" => { - let val = kv.read(runner, KEY).unwrap(); + let val = kv.read(runner, KEY).unwrap().unwrap(); let body = Body::from_type("read_ok").with_payload(mk_payload(&[("value", val)])); runner.reply(&req, body); } diff --git a/nebkor-maelstrom/src/kv.rs b/nebkor-maelstrom/src/kv.rs index 4ad7bed..dd841e6 100644 --- a/nebkor-maelstrom/src/kv.rs +++ b/nebkor-maelstrom/src/kv.rs @@ -1,8 +1,6 @@ use serde_json::Value; -use crate::{check_err, mk_payload, protocol::ErrorCode, Body, RpcResult, Runner}; - -pub type ReadResult = std::result::Result; +use crate::{check_err, mk_payload, Body, RpcResult, Runner}; #[derive(Debug, Default, Clone)] pub struct Kv { @@ -22,13 +20,13 @@ impl Kv { Kv { service: "lww-kv" } } - pub fn read(&self, runner: &Runner, key: &str) -> ReadResult { + pub fn read(&self, runner: &Runner, key: &str) -> RpcResult { let payload = mk_payload(&[("key", key.into())]); let body = Body::from_type("read").with_payload(payload); let rx = runner.rpc(self.service, body); let msg = rx.recv().unwrap(); check_err(&msg)?; - Ok(msg.body.payload.get("value").unwrap().to_owned()) + Ok(Some(msg.body.payload.get("value").unwrap().to_owned())) } pub fn write(&self, runner: &Runner, key: &str, val: Value) -> RpcResult { @@ -36,7 +34,7 @@ impl Kv { let body = Body::from_type("write").with_payload(payload); let msg = runner.rpc(self.service, body).recv().unwrap(); check_err(&msg)?; - Ok(()) + Ok(None) } pub fn cas( @@ -56,6 +54,6 @@ impl Kv { let body = Body::from_type("cas").with_payload(payload); let msg = runner.rpc(self.service, body).recv().unwrap(); check_err(&msg)?; - Ok(()) + Ok(None) } } diff --git a/nebkor-maelstrom/src/lib.rs b/nebkor-maelstrom/src/lib.rs index 0c7c286..7f048b4 100644 --- a/nebkor-maelstrom/src/lib.rs +++ b/nebkor-maelstrom/src/lib.rs @@ -9,17 +9,17 @@ use std::{ thread::{self}, }; -pub mod protocol; -use protocol::ErrorCode; -pub use protocol::{Body, Message, Payload}; use serde_json::Value; +pub mod protocol; +pub use protocol::{Body, ErrorCode, Message, Payload}; + pub mod kv; pub type NodeyNodeFace = Arc>; pub type OnInit = Box; pub type RpcPromise = Receiver; -pub type RpcResult = std::result::Result<(), ErrorCode>; +pub type RpcResult = std::result::Result, ErrorCode>; pub trait Node { fn handle(&mut self, runner: &Runner, msg: Message); @@ -71,46 +71,6 @@ impl Runner { self.run_input(stdin_rx, on_init); } - fn run_input(&self, stdin_rx: Receiver, on_init: Option) { - let (json_tx, json_rx) = channel(); - let _ = self.backdoor.get_or_init(|| json_tx.clone()); - let proms = self.promises.clone(); - thread::spawn(move || { - for line in stdin_rx { - let msg: Message = serde_json::from_str(&line).unwrap(); - let irt = msg.body.in_reply_to; - if let Some(tx) = proms.lock().unwrap().remove(&irt) { - tx.send(msg).unwrap(); - } else { - json_tx.send(msg).unwrap(); - } - } - }); - - for msg in json_rx { - if msg.body.typ.as_str() == "init" { - self.init(&msg); - let body = Body::from_type("init_ok"); - self.reply(&msg, body); - if let Some(ref on_init) = on_init { - on_init(self); - } - } else { - let mut node = self.node.lock().unwrap(); - node.handle(self, msg); - } - } - } - - fn run_output(&self, stdout_tx: Sender, node_output_rx: Receiver) { - thread::spawn(move || { - while let Ok(msg) = node_output_rx.recv() { - let msg = serde_json::to_string(&msg).unwrap(); - stdout_tx.send(msg).unwrap(); - } - }); - } - pub fn get_backdoor(&self) -> Sender { self.backdoor.get().unwrap().clone() } @@ -131,31 +91,6 @@ impl Runner { self.nodes.get().unwrap() } - pub fn init(&self, msg: &Message) { - let node_id = msg - .body - .payload - .get("node_id") - .unwrap() - .as_str() - .unwrap() - .to_owned(); - - let nodes = msg - .body - .payload - .get("node_ids") - .unwrap() - .as_array() - .unwrap() - .iter() - .map(|s| s.as_str().unwrap().to_string()) - .collect(); - - let _ = self.node_id.get_or_init(|| node_id); - let _ = self.nodes.get_or_init(|| nodes); - } - pub fn reply(&self, req: &Message, body: Body) { let mut body = body; let dest = req.src.as_str(); @@ -196,6 +131,72 @@ impl Runner { self.outbound_tx.get().unwrap().send(msg).unwrap(); rx } + + // internal methods + fn init(&self, msg: &Message) { + let node_id = msg + .body + .payload + .get("node_id") + .unwrap() + .as_str() + .unwrap() + .to_owned(); + + let nodes = msg + .body + .payload + .get("node_ids") + .unwrap() + .as_array() + .unwrap() + .iter() + .map(|s| s.as_str().unwrap().to_string()) + .collect(); + + let _ = self.node_id.get_or_init(|| node_id); + let _ = self.nodes.get_or_init(|| nodes); + } + + fn run_input(&self, stdin_rx: Receiver, on_init: Option) { + let (json_tx, json_rx) = channel(); + let _ = self.backdoor.get_or_init(|| json_tx.clone()); + let proms = self.promises.clone(); + thread::spawn(move || { + for line in stdin_rx { + let msg: Message = serde_json::from_str(&line).unwrap(); + let irt = msg.body.in_reply_to; + if let Some(tx) = proms.lock().unwrap().remove(&irt) { + tx.send(msg).unwrap(); + } else { + json_tx.send(msg).unwrap(); + } + } + }); + + for msg in json_rx { + if msg.body.typ.as_str() == "init" { + self.init(&msg); + let body = Body::from_type("init_ok"); + self.reply(&msg, body); + if let Some(ref on_init) = on_init { + on_init(self); + } + } else { + let mut node = self.node.lock().unwrap(); + node.handle(self, msg); + } + } + } + + fn run_output(&self, stdout_tx: Sender, node_output_rx: Receiver) { + thread::spawn(move || { + while let Ok(msg) = node_output_rx.recv() { + let msg = serde_json::to_string(&msg).unwrap(); + stdout_tx.send(msg).unwrap(); + } + }); + } } pub fn check_err(msg: &Message) -> RpcResult { @@ -203,7 +204,7 @@ pub fn check_err(msg: &Message) -> RpcResult { let error = msg.body.code.unwrap(); return Err(error); } - Ok(()) + Ok(None) } pub fn mk_payload(payload: &[(&str, Value)]) -> Payload {