really use remote IP for hashing

This commit is contained in:
Joe Ardent 2024-03-28 23:32:07 -07:00
parent e96a993f9a
commit f476d6a117
3 changed files with 88 additions and 55 deletions

28
Cargo.lock generated
View file

@ -99,6 +99,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "axum-client-ip"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e7c467bdcd2bd982ce5c8742a1a178aba7b03db399fd18f5d5d438f5aa91cb4"
dependencies = [
"axum",
"forwarded-header-value",
"serde",
]
[[package]]
name = "axum-core"
version = "0.4.3"
@ -439,6 +450,16 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "forwarded-header-value"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9"
dependencies = [
"nonempty",
"thiserror",
]
[[package]]
name = "futures-channel"
version = "0.3.30"
@ -608,6 +629,7 @@ name = "hitman"
version = "0.0.1"
dependencies = [
"axum",
"axum-client-ip",
"chrono",
"clap",
"dotenvy",
@ -891,6 +913,12 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nonempty"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7"
[[package]]
name = "num-bigint-dig"
version = "0.8.4"

View file

@ -11,6 +11,7 @@ repository = "https://git.kittencollective.com/nebkor/hitman"
[dependencies]
axum = { version = "0.7", default-features = false, features = ["tokio", "http1", "macros"] }
axum-client-ip = "0.5"
chrono = { version = "0.4", default-features = false, features = ["now"] }
clap = { version = "4.5", default-features = false, features = ["std", "derive", "unicode", "help", "usage"] }
dotenvy = { version = "0.15", default-features = false }

View file

@ -6,12 +6,12 @@ use std::{
};
use axum::{
debug_handler,
extract::{Path, State},
http::{method::Method, HeaderMap, HeaderValue},
http::{method::Method, HeaderValue},
routing::get,
Router,
};
use axum_client_ip::InsecureClientIp;
use clap::Parser;
use lazy_static::lazy_static;
use ring::digest::{Context, SHA256};
@ -25,6 +25,18 @@ lazy_static! {
static ref SESSION_SALT: u64 = rand::random();
}
#[derive(Debug, Parser)]
#[clap(version, about)]
struct Cli {
#[clap(
long,
short,
help = "Path to environment file.",
default_value = ".env"
)]
pub env: OsString,
}
#[tokio::main]
async fn main() {
init();
@ -37,12 +49,15 @@ async fn main() {
let app = Router::new()
.route("/hit/:slug", get(register_hit))
.route("/hits/:slug", get(get_hits))
.route("/hits/:slug/:period", get(get_hits))
.route("/hits/:slug", get(get_all_hits))
.route("/hits/:slug/:period", get(get_period_hits))
.layer(cors_layer)
.with_state(pool.clone())
.into_make_service();
// we need the connect info in order to extract the IP address
.into_make_service_with_connect_info::<SocketAddr>();
let listener = mklistener().await;
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
@ -56,17 +71,14 @@ async fn main() {
//-************************************************************************
/// This is the main handler. It counts the hit and returns the latest count.
#[debug_handler]
async fn register_hit(
Path(slug): Path<String>,
State(db): State<SqlitePool>,
req: HeaderMap,
InsecureClientIp(ip): InsecureClientIp,
) -> String {
let host = req
.get("host")
.cloned()
.unwrap_or(HeaderValue::from_str("").unwrap());
let host = host.to_str().unwrap_or("");
let slug = &slug;
let host = ip.to_string();
let now = chrono::Utc::now();
let now = now.to_rfc3339();
@ -74,6 +86,9 @@ async fn register_hit(
// duplicate views from the same host to one per hour:
let now = now.split(':').take(1).next().unwrap();
// the salt here is regenerated every time the service restarts, and guarantees
// we can't just enumerate all the possible hashes based on IP, page, and
// time alone.
let salt = *SESSION_SALT;
let key = format!("{now}{host}{slug}{salt}").into_bytes();
let key = hex::encode(shasum(&key));
@ -81,64 +96,64 @@ async fn register_hit(
let tx = db.begin().await;
if let Ok(mut tx) = tx {
sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,)
match sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,)
.execute(&mut *tx)
.await
.unwrap_or_default();
tx.commit().await.unwrap_or_default();
{
Ok(_) => tx.commit().await.unwrap_or_default(),
_ => { /* whatevs, fine */ }
}
}
let hits = sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
.fetch_one(&db)
.await
.unwrap_or(1);
let hits = all_hits_helper(&db, slug).await;
format!("{hits}")
}
/// fer fancy people what want to be a little finer
#[axum::debug_handler]
async fn get_hits(
async fn get_period_hits(
State(db): State<SqlitePool>,
Path(slug): Path<String>,
period: Option<Path<String>>,
Path(period): Path<String>,
) -> String {
let now = chrono::Utc::now();
let slug = &slug;
let count = match period.unwrap_or(Path("all".to_string())).as_str() {
let when = match period.as_str() {
"day" => {
let then = now - chrono::Duration::try_hours(24).unwrap();
let then = then.to_rfc3339();
get_period_hits(&db, slug, &then).await
then.to_rfc3339()
}
"week" => {
let then = now - chrono::Duration::try_days(7).unwrap();
let then = then.to_rfc3339();
get_period_hits(&db, slug, &then).await
then.to_rfc3339()
}
_ => sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
.fetch_one(&db)
.await
.unwrap_or(1),
_ => "all".to_string(),
};
format!("{count}")
let hits = sqlx::query_scalar!(
"select count(*) from hits where page = ? and accessed > ?",
slug,
when
)
.fetch_one(&db)
.await
.unwrap_or(0);
format!("{hits}")
}
async fn get_all_hits(State(db): State<SqlitePool>, Path(slug): Path<String>) -> String {
let hits = all_hits_helper(&db, &slug).await;
format!("{hits}")
}
//-************************************************************************
// li'l helpers
//-************************************************************************
#[derive(Debug, Parser)]
#[clap(version, about)]
struct Cli {
#[clap(
long,
short,
help = "Path to environment file.",
default_value = ".env"
)]
pub env: OsString,
async fn all_hits_helper(db: &SqlitePool, slug: &str) -> i32 {
sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
.fetch_one(db)
.await
.unwrap_or(1)
}
fn init() {
@ -221,14 +236,3 @@ fn shasum(input: &[u8]) -> Vec<u8> {
context.update(input);
context.finish().as_ref().to_vec()
}
async fn get_period_hits(db: &SqlitePool, page: &str, when: &str) -> i32 {
sqlx::query_scalar!(
"select count(*) from hits where page = ? and accessed > ?",
page,
when
)
.fetch_one(db)
.await
.unwrap_or(0)
}