diff --git a/day03/src/main.rs b/day03/src/main.rs index 0ecf189..b565993 100644 --- a/day03/src/main.rs +++ b/day03/src/main.rs @@ -8,24 +8,65 @@ use winnow::{ fn main() { let input = std::fs::read_to_string("input").unwrap(); println!("{}", pt1(&input)); + println!("{}", pt2(&input)); } #[derive(Debug, PartialEq, Eq)] enum Inst { - Mul(i32, i32), + Mul(i64, i64), + Do, + Dont, } -fn pt1(input: &str) -> i32 { +fn pt1(input: &str) -> i64 { let mut input = input; let v = parse(&mut input).unwrap(); - v.iter().map(|Inst::Mul(a, b)| (a * b)).sum() + v.iter() + .map(|inst| match inst { + Inst::Mul(a, b) => a * b, + _ => 0, + }) + .sum() +} + +fn pt2(input: &str) -> i64 { + let mut input = input; + let instructions = parse(&mut input).unwrap(); + let mut doo = true; + let mut total = 0; + for i in &instructions { + match i { + Inst::Do => doo = true, + Inst::Dont => doo = false, + Inst::Mul(a, b) => { + if doo { + total += a * b; + } + } + } + } + total +} + +fn parse_mul(input: &mut &str) -> PResult { + let _ = "mul".parse_next(input)?; + let pair: (i64, i64) = + delimited('(', separated_pair(dec_int, ',', dec_int), ')').parse_next(input)?; + Ok(Inst::Mul(pair.0, pair.1)) +} + +fn parse_do(input: &mut &str) -> PResult { + let _ = "do()".parse_next(input)?; + Ok(Inst::Do) +} + +fn parse_dont(input: &mut &str) -> PResult { + let _ = "don't()".parse_next(input)?; + Ok(Inst::Dont) } fn parse_inst(input: &mut &str) -> PResult { - let _ = "mul".parse_next(input)?; - let pair: (i32, i32) = - delimited('(', separated_pair(dec_int, ',', dec_int), ')').parse_next(input)?; - Ok(Inst::Mul(pair.0, pair.1)) + alt((parse_mul, parse_do, parse_dont)).parse_next(input) } fn parse(input: &mut &str) -> PResult> { @@ -50,11 +91,18 @@ fn parse(input: &mut &str) -> PResult> { mod test { use super::*; - static INPUT: &str = "xmul(2,4)%&mul[3,7]!@^do_not_mul(5,5)+mul(32,64]then(mul(11,8)mul(8,5))"; + static INPUT1: &str = "xmul(2,4)%&mul[3,7]!@^do_not_mul(5,5)+mul(32,64]then(mul(11,8)mul(8,5))"; + static INPUT2: &str = + "xmul(2,4)&mul[3,7]!^don't()_mul(5,5)+mul(32,64](mul(11,8)undo()?mul(8,5))"; #[test] fn p1() { - let v = pt1(INPUT); + let v = pt1(INPUT1); assert_eq!(v, 161) } + + #[test] + fn p2() { + assert_eq!(48, pt2(INPUT2)); + } }