use std::{ env::VarError, ffi::OsString, io::Write, net::{Ipv4Addr, SocketAddr}, }; use axum::{ extract::{Path, State}, http::{method::Method, HeaderValue}, routing::get, Router, }; use axum_client_ip::InsecureClientIp; use clap::Parser; use lazy_static::lazy_static; use ring::digest::{Context, SHA256}; use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; use tokio::net::TcpListener; use tower_http::cors::{self, CorsLayer}; lazy_static! { static ref HITMAN_ORIGIN: String = std::env::var("HITMAN_ORIGIN").expect("could not get origin for service"); static ref SESSION_SALT: u64 = rand::random(); } #[derive(Debug, Parser)] #[clap(version, about)] struct Cli { #[clap( long, short, help = "Path to environment file.", default_value = ".env" )] pub env: OsString, } #[tokio::main] async fn main() { init(); let pool = db().await; let origin = HeaderValue::from_str(&HITMAN_ORIGIN).unwrap(); let cors_layer = CorsLayer::new() .allow_origin(origin) .allow_methods(cors::AllowMethods::exact(Method::GET)); let app = Router::new() .route("/hit/:slug", get(register_hit)) .route("/hits/:slug", get(get_all_hits)) .route("/hits/:slug/:period", get(get_period_hits)) .layer(cors_layer) .with_state(pool.clone()) // we need the connect info in order to extract the IP address .into_make_service_with_connect_info::(); let listener = mklistener().await; axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) .await .unwrap(); pool.close().await; } //-************************************************************************ // the three route hanlders //-************************************************************************ /// This is the main handler. It counts the hit and returns the latest count. async fn register_hit( Path(slug): Path, State(db): State, InsecureClientIp(ip): InsecureClientIp, ) -> String { let slug = &slug; let host = ip.to_string(); let now = chrono::Utc::now(); let now = now.to_rfc3339(); // What we really want is just the date + hour in 24-hour format; this limits // duplicate views from the same host to one per hour: let now = now.split(':').take(1).next().unwrap(); // the salt here is regenerated every time the service restarts, and guarantees // we can't just enumerate all the possible hashes based on IP, page, and // time alone. let salt = *SESSION_SALT; let key = format!("{now}{host}{slug}{salt}").into_bytes(); let key = hex::encode(shasum(&key)); let tx = db.begin().await; if let Ok(mut tx) = tx { match sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,) .execute(&mut *tx) .await { Ok(_) => tx.commit().await.unwrap_or_default(), _ => { /* whatevs, fine */ } } } let hits = all_hits_helper(&db, slug).await; format!("{hits}") } async fn get_period_hits( State(db): State, Path(slug): Path, Path(period): Path, ) -> String { let now = chrono::Utc::now(); let slug = &slug; let when = match period.as_str() { "day" => { let then = now - chrono::Duration::try_hours(24).unwrap(); then.to_rfc3339() } "week" => { let then = now - chrono::Duration::try_days(7).unwrap(); then.to_rfc3339() } _ => { let then = now - chrono::Duration::try_days(365_242).unwrap(); // 1000 years then.to_rfc3339() } }; let hits = sqlx::query_scalar!( "select count(*) from hits where page = ? and viewed > ?", slug, when ) .fetch_one(&db) .await .unwrap_or(0); format!("{hits}") } // it's easier to split this into something that handles the request parameters // and something that does the query async fn get_all_hits(State(db): State, Path(slug): Path) -> String { let hits = all_hits_helper(&db, &slug).await; format!("{hits}") } //-************************************************************************ // li'l helpers //-************************************************************************ async fn all_hits_helper(db: &SqlitePool, slug: &str) -> i32 { sqlx::query_scalar!("select count(*) from hits where page = ?", slug) .fetch_one(db) .await .unwrap_or(1) } fn init() { let cli = Cli::parse(); dotenvy::from_path_override(cli.env).expect("Could not read .env file."); env_logger::builder() .format(|buf, record| { let ts = buf.timestamp(); writeln!(buf, "{}: {}", ts, record.args()) }) .init(); } async fn db() -> SqlitePool { let dbfile = std::env::var("DATABASE_FILE").unwrap(); let opts = SqliteConnectOptions::new() .foreign_keys(true) .create_if_missing(true) .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal) .filename(&dbfile) .optimize_on_close(true, None); let pool = SqlitePoolOptions::new() .connect_with(opts.clone()) .await .unwrap(); sqlx::migrate!().run(&pool).await.unwrap(); let count = sqlx::query_scalar!("select count(*) from hits") .fetch_one(&pool) .await .expect("could not get hit count from DB"); log::info!("Connected to DB, found {count} total hits."); pool } async fn mklistener() -> TcpListener { let ip = std::env::var("LISTENING_ADDR").expect("Could not find $LISTENING_ADDR in environment"); let ip: Ipv4Addr = ip .parse() .unwrap_or_else(|_| panic!("Could not parse {ip} as an IP address")); let port: u16 = std::env::var("LISTENING_PORT") .and_then(|p| p.parse().map_err(|_| VarError::NotPresent)) .unwrap_or_else(|_| { panic!("Could not find LISTENING_PORT in env or parse if present"); }); let addr = SocketAddr::from((ip, port)); TcpListener::bind(&addr).await.unwrap() } async fn shutdown_signal() { use tokio::signal; let ctrl_c = async { signal::ctrl_c() .await .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => {log::info!("shutting down")}, _ = terminate => {}, } } fn shasum(input: &[u8]) -> Vec { let mut context = Context::new(&SHA256); context.update(input); context.finish().as_ref().to_vec() }