diff --git a/Cargo.lock b/Cargo.lock index c7c0e3b..65cc617 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,6 +99,17 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-client-ip" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e7c467bdcd2bd982ce5c8742a1a178aba7b03db399fd18f5d5d438f5aa91cb4" +dependencies = [ + "axum", + "forwarded-header-value", + "serde", +] + [[package]] name = "axum-core" version = "0.4.3" @@ -439,6 +450,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "forwarded-header-value" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9" +dependencies = [ + "nonempty", + "thiserror", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -608,6 +629,7 @@ name = "hitman" version = "0.0.1" dependencies = [ "axum", + "axum-client-ip", "chrono", "clap", "dotenvy", @@ -891,6 +913,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonempty" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7" + [[package]] name = "num-bigint-dig" version = "0.8.4" diff --git a/Cargo.toml b/Cargo.toml index eb2acb9..516c50b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://git.kittencollective.com/nebkor/hitman" [dependencies] axum = { version = "0.7", default-features = false, features = ["tokio", "http1", "macros"] } +axum-client-ip = "0.5" chrono = { version = "0.4", default-features = false, features = ["now"] } clap = { version = "4.5", default-features = false, features = ["std", "derive", "unicode", "help", "usage"] } dotenvy = { version = "0.15", default-features = false } diff --git a/src/main.rs b/src/main.rs index bab6a79..5214c93 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,12 +6,12 @@ use std::{ }; use axum::{ - debug_handler, extract::{Path, State}, - http::{method::Method, HeaderMap, HeaderValue}, + http::{method::Method, HeaderValue}, routing::get, Router, }; +use axum_client_ip::InsecureClientIp; use clap::Parser; use lazy_static::lazy_static; use ring::digest::{Context, SHA256}; @@ -25,6 +25,18 @@ lazy_static! { static ref SESSION_SALT: u64 = rand::random(); } +#[derive(Debug, Parser)] +#[clap(version, about)] +struct Cli { + #[clap( + long, + short, + help = "Path to environment file.", + default_value = ".env" + )] + pub env: OsString, +} + #[tokio::main] async fn main() { init(); @@ -37,12 +49,15 @@ async fn main() { let app = Router::new() .route("/hit/:slug", get(register_hit)) - .route("/hits/:slug", get(get_hits)) - .route("/hits/:slug/:period", get(get_hits)) + .route("/hits/:slug", get(get_all_hits)) + .route("/hits/:slug/:period", get(get_period_hits)) .layer(cors_layer) .with_state(pool.clone()) - .into_make_service(); + // we need the connect info in order to extract the IP address + .into_make_service_with_connect_info::(); + let listener = mklistener().await; + axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) .await @@ -56,17 +71,14 @@ async fn main() { //-************************************************************************ /// This is the main handler. It counts the hit and returns the latest count. -#[debug_handler] async fn register_hit( Path(slug): Path, State(db): State, - req: HeaderMap, + InsecureClientIp(ip): InsecureClientIp, ) -> String { - let host = req - .get("host") - .cloned() - .unwrap_or(HeaderValue::from_str("").unwrap()); - let host = host.to_str().unwrap_or(""); + let slug = &slug; + + let host = ip.to_string(); let now = chrono::Utc::now(); let now = now.to_rfc3339(); @@ -74,6 +86,9 @@ async fn register_hit( // duplicate views from the same host to one per hour: let now = now.split(':').take(1).next().unwrap(); + // the salt here is regenerated every time the service restarts, and guarantees + // we can't just enumerate all the possible hashes based on IP, page, and + // time alone. let salt = *SESSION_SALT; let key = format!("{now}{host}{slug}{salt}").into_bytes(); let key = hex::encode(shasum(&key)); @@ -81,64 +96,64 @@ async fn register_hit( let tx = db.begin().await; if let Ok(mut tx) = tx { - sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,) + match sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,) .execute(&mut *tx) .await - .unwrap_or_default(); - tx.commit().await.unwrap_or_default(); + { + Ok(_) => tx.commit().await.unwrap_or_default(), + _ => { /* whatevs, fine */ } + } } - let hits = sqlx::query_scalar!("select count(*) from hits where page = ?", slug) - .fetch_one(&db) - .await - .unwrap_or(1); - + let hits = all_hits_helper(&db, slug).await; format!("{hits}") } -/// fer fancy people what want to be a little finer -#[axum::debug_handler] -async fn get_hits( +async fn get_period_hits( State(db): State, Path(slug): Path, - period: Option>, + Path(period): Path, ) -> String { let now = chrono::Utc::now(); let slug = &slug; - let count = match period.unwrap_or(Path("all".to_string())).as_str() { + let when = match period.as_str() { "day" => { let then = now - chrono::Duration::try_hours(24).unwrap(); - let then = then.to_rfc3339(); - get_period_hits(&db, slug, &then).await + then.to_rfc3339() } "week" => { let then = now - chrono::Duration::try_days(7).unwrap(); - let then = then.to_rfc3339(); - get_period_hits(&db, slug, &then).await + then.to_rfc3339() } - _ => sqlx::query_scalar!("select count(*) from hits where page = ?", slug) - .fetch_one(&db) - .await - .unwrap_or(1), + _ => "all".to_string(), }; - format!("{count}") + let hits = sqlx::query_scalar!( + "select count(*) from hits where page = ? and accessed > ?", + slug, + when + ) + .fetch_one(&db) + .await + .unwrap_or(0); + + format!("{hits}") +} + +async fn get_all_hits(State(db): State, Path(slug): Path) -> String { + let hits = all_hits_helper(&db, &slug).await; + format!("{hits}") } //-************************************************************************ // li'l helpers //-************************************************************************ -#[derive(Debug, Parser)] -#[clap(version, about)] -struct Cli { - #[clap( - long, - short, - help = "Path to environment file.", - default_value = ".env" - )] - pub env: OsString, +async fn all_hits_helper(db: &SqlitePool, slug: &str) -> i32 { + sqlx::query_scalar!("select count(*) from hits where page = ?", slug) + .fetch_one(db) + .await + .unwrap_or(1) } fn init() { @@ -221,14 +236,3 @@ fn shasum(input: &[u8]) -> Vec { context.update(input); context.finish().as_ref().to_vec() } - -async fn get_period_hits(db: &SqlitePool, page: &str, when: &str) -> i32 { - sqlx::query_scalar!( - "select count(*) from hits where page = ? and accessed > ?", - page, - when - ) - .fetch_one(db) - .await - .unwrap_or(0) -}