really use remote IP for hashing
This commit is contained in:
parent
e96a993f9a
commit
f476d6a117
3 changed files with 88 additions and 55 deletions
28
Cargo.lock
generated
28
Cargo.lock
generated
|
@ -99,6 +99,17 @@ dependencies = [
|
|||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-client-ip"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e7c467bdcd2bd982ce5c8742a1a178aba7b03db399fd18f5d5d438f5aa91cb4"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"forwarded-header-value",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.4.3"
|
||||
|
@ -439,6 +450,16 @@ dependencies = [
|
|||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "forwarded-header-value"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9"
|
||||
dependencies = [
|
||||
"nonempty",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.30"
|
||||
|
@ -608,6 +629,7 @@ name = "hitman"
|
|||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"axum-client-ip",
|
||||
"chrono",
|
||||
"clap",
|
||||
"dotenvy",
|
||||
|
@ -891,6 +913,12 @@ dependencies = [
|
|||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nonempty"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7"
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint-dig"
|
||||
version = "0.8.4"
|
||||
|
|
|
@ -11,6 +11,7 @@ repository = "https://git.kittencollective.com/nebkor/hitman"
|
|||
|
||||
[dependencies]
|
||||
axum = { version = "0.7", default-features = false, features = ["tokio", "http1", "macros"] }
|
||||
axum-client-ip = "0.5"
|
||||
chrono = { version = "0.4", default-features = false, features = ["now"] }
|
||||
clap = { version = "4.5", default-features = false, features = ["std", "derive", "unicode", "help", "usage"] }
|
||||
dotenvy = { version = "0.15", default-features = false }
|
||||
|
|
114
src/main.rs
114
src/main.rs
|
@ -6,12 +6,12 @@ use std::{
|
|||
};
|
||||
|
||||
use axum::{
|
||||
debug_handler,
|
||||
extract::{Path, State},
|
||||
http::{method::Method, HeaderMap, HeaderValue},
|
||||
http::{method::Method, HeaderValue},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use clap::Parser;
|
||||
use lazy_static::lazy_static;
|
||||
use ring::digest::{Context, SHA256};
|
||||
|
@ -25,6 +25,18 @@ lazy_static! {
|
|||
static ref SESSION_SALT: u64 = rand::random();
|
||||
}
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
#[clap(version, about)]
|
||||
struct Cli {
|
||||
#[clap(
|
||||
long,
|
||||
short,
|
||||
help = "Path to environment file.",
|
||||
default_value = ".env"
|
||||
)]
|
||||
pub env: OsString,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
init();
|
||||
|
@ -37,12 +49,15 @@ async fn main() {
|
|||
|
||||
let app = Router::new()
|
||||
.route("/hit/:slug", get(register_hit))
|
||||
.route("/hits/:slug", get(get_hits))
|
||||
.route("/hits/:slug/:period", get(get_hits))
|
||||
.route("/hits/:slug", get(get_all_hits))
|
||||
.route("/hits/:slug/:period", get(get_period_hits))
|
||||
.layer(cors_layer)
|
||||
.with_state(pool.clone())
|
||||
.into_make_service();
|
||||
// we need the connect info in order to extract the IP address
|
||||
.into_make_service_with_connect_info::<SocketAddr>();
|
||||
|
||||
let listener = mklistener().await;
|
||||
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await
|
||||
|
@ -56,17 +71,14 @@ async fn main() {
|
|||
//-************************************************************************
|
||||
|
||||
/// This is the main handler. It counts the hit and returns the latest count.
|
||||
#[debug_handler]
|
||||
async fn register_hit(
|
||||
Path(slug): Path<String>,
|
||||
State(db): State<SqlitePool>,
|
||||
req: HeaderMap,
|
||||
InsecureClientIp(ip): InsecureClientIp,
|
||||
) -> String {
|
||||
let host = req
|
||||
.get("host")
|
||||
.cloned()
|
||||
.unwrap_or(HeaderValue::from_str("").unwrap());
|
||||
let host = host.to_str().unwrap_or("");
|
||||
let slug = &slug;
|
||||
|
||||
let host = ip.to_string();
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let now = now.to_rfc3339();
|
||||
|
@ -74,6 +86,9 @@ async fn register_hit(
|
|||
// duplicate views from the same host to one per hour:
|
||||
let now = now.split(':').take(1).next().unwrap();
|
||||
|
||||
// the salt here is regenerated every time the service restarts, and guarantees
|
||||
// we can't just enumerate all the possible hashes based on IP, page, and
|
||||
// time alone.
|
||||
let salt = *SESSION_SALT;
|
||||
let key = format!("{now}{host}{slug}{salt}").into_bytes();
|
||||
let key = hex::encode(shasum(&key));
|
||||
|
@ -81,64 +96,64 @@ async fn register_hit(
|
|||
let tx = db.begin().await;
|
||||
|
||||
if let Ok(mut tx) = tx {
|
||||
sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,)
|
||||
match sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
tx.commit().await.unwrap_or_default();
|
||||
{
|
||||
Ok(_) => tx.commit().await.unwrap_or_default(),
|
||||
_ => { /* whatevs, fine */ }
|
||||
}
|
||||
}
|
||||
|
||||
let hits = sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
|
||||
.fetch_one(&db)
|
||||
.await
|
||||
.unwrap_or(1);
|
||||
|
||||
let hits = all_hits_helper(&db, slug).await;
|
||||
format!("{hits}")
|
||||
}
|
||||
|
||||
/// fer fancy people what want to be a little finer
|
||||
#[axum::debug_handler]
|
||||
async fn get_hits(
|
||||
async fn get_period_hits(
|
||||
State(db): State<SqlitePool>,
|
||||
Path(slug): Path<String>,
|
||||
period: Option<Path<String>>,
|
||||
Path(period): Path<String>,
|
||||
) -> String {
|
||||
let now = chrono::Utc::now();
|
||||
let slug = &slug;
|
||||
|
||||
let count = match period.unwrap_or(Path("all".to_string())).as_str() {
|
||||
let when = match period.as_str() {
|
||||
"day" => {
|
||||
let then = now - chrono::Duration::try_hours(24).unwrap();
|
||||
let then = then.to_rfc3339();
|
||||
get_period_hits(&db, slug, &then).await
|
||||
then.to_rfc3339()
|
||||
}
|
||||
"week" => {
|
||||
let then = now - chrono::Duration::try_days(7).unwrap();
|
||||
let then = then.to_rfc3339();
|
||||
get_period_hits(&db, slug, &then).await
|
||||
then.to_rfc3339()
|
||||
}
|
||||
_ => sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
|
||||
.fetch_one(&db)
|
||||
.await
|
||||
.unwrap_or(1),
|
||||
_ => "all".to_string(),
|
||||
};
|
||||
|
||||
format!("{count}")
|
||||
let hits = sqlx::query_scalar!(
|
||||
"select count(*) from hits where page = ? and accessed > ?",
|
||||
slug,
|
||||
when
|
||||
)
|
||||
.fetch_one(&db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
format!("{hits}")
|
||||
}
|
||||
|
||||
async fn get_all_hits(State(db): State<SqlitePool>, Path(slug): Path<String>) -> String {
|
||||
let hits = all_hits_helper(&db, &slug).await;
|
||||
format!("{hits}")
|
||||
}
|
||||
|
||||
//-************************************************************************
|
||||
// li'l helpers
|
||||
//-************************************************************************
|
||||
#[derive(Debug, Parser)]
|
||||
#[clap(version, about)]
|
||||
struct Cli {
|
||||
#[clap(
|
||||
long,
|
||||
short,
|
||||
help = "Path to environment file.",
|
||||
default_value = ".env"
|
||||
)]
|
||||
pub env: OsString,
|
||||
async fn all_hits_helper(db: &SqlitePool, slug: &str) -> i32 {
|
||||
sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
|
||||
.fetch_one(db)
|
||||
.await
|
||||
.unwrap_or(1)
|
||||
}
|
||||
|
||||
fn init() {
|
||||
|
@ -221,14 +236,3 @@ fn shasum(input: &[u8]) -> Vec<u8> {
|
|||
context.update(input);
|
||||
context.finish().as_ref().to_vec()
|
||||
}
|
||||
|
||||
async fn get_period_hits(db: &SqlitePool, page: &str, when: &str) -> i32 {
|
||||
sqlx::query_scalar!(
|
||||
"select count(*) from hits where page = ? and accessed > ?",
|
||||
page,
|
||||
when
|
||||
)
|
||||
.fetch_one(db)
|
||||
.await
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue