really use remote IP for hashing

This commit is contained in:
Joe Ardent 2024-03-28 23:32:07 -07:00
parent e96a993f9a
commit f476d6a117
3 changed files with 88 additions and 55 deletions

28
Cargo.lock generated
View file

@ -99,6 +99,17 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "axum-client-ip"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e7c467bdcd2bd982ce5c8742a1a178aba7b03db399fd18f5d5d438f5aa91cb4"
dependencies = [
"axum",
"forwarded-header-value",
"serde",
]
[[package]] [[package]]
name = "axum-core" name = "axum-core"
version = "0.4.3" version = "0.4.3"
@ -439,6 +450,16 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "forwarded-header-value"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9"
dependencies = [
"nonempty",
"thiserror",
]
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.30" version = "0.3.30"
@ -608,6 +629,7 @@ name = "hitman"
version = "0.0.1" version = "0.0.1"
dependencies = [ dependencies = [
"axum", "axum",
"axum-client-ip",
"chrono", "chrono",
"clap", "clap",
"dotenvy", "dotenvy",
@ -891,6 +913,12 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "nonempty"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7"
[[package]] [[package]]
name = "num-bigint-dig" name = "num-bigint-dig"
version = "0.8.4" version = "0.8.4"

View file

@ -11,6 +11,7 @@ repository = "https://git.kittencollective.com/nebkor/hitman"
[dependencies] [dependencies]
axum = { version = "0.7", default-features = false, features = ["tokio", "http1", "macros"] } axum = { version = "0.7", default-features = false, features = ["tokio", "http1", "macros"] }
axum-client-ip = "0.5"
chrono = { version = "0.4", default-features = false, features = ["now"] } chrono = { version = "0.4", default-features = false, features = ["now"] }
clap = { version = "4.5", default-features = false, features = ["std", "derive", "unicode", "help", "usage"] } clap = { version = "4.5", default-features = false, features = ["std", "derive", "unicode", "help", "usage"] }
dotenvy = { version = "0.15", default-features = false } dotenvy = { version = "0.15", default-features = false }

View file

@ -6,12 +6,12 @@ use std::{
}; };
use axum::{ use axum::{
debug_handler,
extract::{Path, State}, extract::{Path, State},
http::{method::Method, HeaderMap, HeaderValue}, http::{method::Method, HeaderValue},
routing::get, routing::get,
Router, Router,
}; };
use axum_client_ip::InsecureClientIp;
use clap::Parser; use clap::Parser;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use ring::digest::{Context, SHA256}; use ring::digest::{Context, SHA256};
@ -25,6 +25,18 @@ lazy_static! {
static ref SESSION_SALT: u64 = rand::random(); static ref SESSION_SALT: u64 = rand::random();
} }
#[derive(Debug, Parser)]
#[clap(version, about)]
struct Cli {
#[clap(
long,
short,
help = "Path to environment file.",
default_value = ".env"
)]
pub env: OsString,
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
init(); init();
@ -37,12 +49,15 @@ async fn main() {
let app = Router::new() let app = Router::new()
.route("/hit/:slug", get(register_hit)) .route("/hit/:slug", get(register_hit))
.route("/hits/:slug", get(get_hits)) .route("/hits/:slug", get(get_all_hits))
.route("/hits/:slug/:period", get(get_hits)) .route("/hits/:slug/:period", get(get_period_hits))
.layer(cors_layer) .layer(cors_layer)
.with_state(pool.clone()) .with_state(pool.clone())
.into_make_service(); // we need the connect info in order to extract the IP address
.into_make_service_with_connect_info::<SocketAddr>();
let listener = mklistener().await; let listener = mklistener().await;
axum::serve(listener, app) axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal()) .with_graceful_shutdown(shutdown_signal())
.await .await
@ -56,17 +71,14 @@ async fn main() {
//-************************************************************************ //-************************************************************************
/// This is the main handler. It counts the hit and returns the latest count. /// This is the main handler. It counts the hit and returns the latest count.
#[debug_handler]
async fn register_hit( async fn register_hit(
Path(slug): Path<String>, Path(slug): Path<String>,
State(db): State<SqlitePool>, State(db): State<SqlitePool>,
req: HeaderMap, InsecureClientIp(ip): InsecureClientIp,
) -> String { ) -> String {
let host = req let slug = &slug;
.get("host")
.cloned() let host = ip.to_string();
.unwrap_or(HeaderValue::from_str("").unwrap());
let host = host.to_str().unwrap_or("");
let now = chrono::Utc::now(); let now = chrono::Utc::now();
let now = now.to_rfc3339(); let now = now.to_rfc3339();
@ -74,6 +86,9 @@ async fn register_hit(
// duplicate views from the same host to one per hour: // duplicate views from the same host to one per hour:
let now = now.split(':').take(1).next().unwrap(); let now = now.split(':').take(1).next().unwrap();
// the salt here is regenerated every time the service restarts, and guarantees
// we can't just enumerate all the possible hashes based on IP, page, and
// time alone.
let salt = *SESSION_SALT; let salt = *SESSION_SALT;
let key = format!("{now}{host}{slug}{salt}").into_bytes(); let key = format!("{now}{host}{slug}{salt}").into_bytes();
let key = hex::encode(shasum(&key)); let key = hex::encode(shasum(&key));
@ -81,64 +96,64 @@ async fn register_hit(
let tx = db.begin().await; let tx = db.begin().await;
if let Ok(mut tx) = tx { if let Ok(mut tx) = tx {
sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,) match sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,)
.execute(&mut *tx) .execute(&mut *tx)
.await .await
.unwrap_or_default(); {
tx.commit().await.unwrap_or_default(); Ok(_) => tx.commit().await.unwrap_or_default(),
_ => { /* whatevs, fine */ }
}
} }
let hits = sqlx::query_scalar!("select count(*) from hits where page = ?", slug) let hits = all_hits_helper(&db, slug).await;
.fetch_one(&db)
.await
.unwrap_or(1);
format!("{hits}") format!("{hits}")
} }
/// fer fancy people what want to be a little finer async fn get_period_hits(
#[axum::debug_handler]
async fn get_hits(
State(db): State<SqlitePool>, State(db): State<SqlitePool>,
Path(slug): Path<String>, Path(slug): Path<String>,
period: Option<Path<String>>, Path(period): Path<String>,
) -> String { ) -> String {
let now = chrono::Utc::now(); let now = chrono::Utc::now();
let slug = &slug; let slug = &slug;
let count = match period.unwrap_or(Path("all".to_string())).as_str() { let when = match period.as_str() {
"day" => { "day" => {
let then = now - chrono::Duration::try_hours(24).unwrap(); let then = now - chrono::Duration::try_hours(24).unwrap();
let then = then.to_rfc3339(); then.to_rfc3339()
get_period_hits(&db, slug, &then).await
} }
"week" => { "week" => {
let then = now - chrono::Duration::try_days(7).unwrap(); let then = now - chrono::Duration::try_days(7).unwrap();
let then = then.to_rfc3339(); then.to_rfc3339()
get_period_hits(&db, slug, &then).await
} }
_ => sqlx::query_scalar!("select count(*) from hits where page = ?", slug) _ => "all".to_string(),
.fetch_one(&db)
.await
.unwrap_or(1),
}; };
format!("{count}") let hits = sqlx::query_scalar!(
"select count(*) from hits where page = ? and accessed > ?",
slug,
when
)
.fetch_one(&db)
.await
.unwrap_or(0);
format!("{hits}")
}
async fn get_all_hits(State(db): State<SqlitePool>, Path(slug): Path<String>) -> String {
let hits = all_hits_helper(&db, &slug).await;
format!("{hits}")
} }
//-************************************************************************ //-************************************************************************
// li'l helpers // li'l helpers
//-************************************************************************ //-************************************************************************
#[derive(Debug, Parser)] async fn all_hits_helper(db: &SqlitePool, slug: &str) -> i32 {
#[clap(version, about)] sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
struct Cli { .fetch_one(db)
#[clap( .await
long, .unwrap_or(1)
short,
help = "Path to environment file.",
default_value = ".env"
)]
pub env: OsString,
} }
fn init() { fn init() {
@ -221,14 +236,3 @@ fn shasum(input: &[u8]) -> Vec<u8> {
context.update(input); context.update(input);
context.finish().as_ref().to_vec() context.finish().as_ref().to_vec()
} }
async fn get_period_hits(db: &SqlitePool, page: &str, when: &str) -> i32 {
sqlx::query_scalar!(
"select count(*) from hits where page = ? and accessed > ?",
page,
when
)
.fetch_one(db)
.await
.unwrap_or(0)
}