use std::{ env::VarError, ffi::OsString, io::Write, net::{Ipv4Addr, SocketAddr}, }; use axum::{ extract::{Path, State}, http::{header::REFERER, HeaderMap}, routing::get, Router, }; use clap::Parser; use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; use tokio::net::TcpListener; use url::Url; #[tokio::main] async fn main() { init(); let pool = db().await; let app = Router::new() .route("/hit", get(register_hit)) .route("/hits", get(get_hits)) .route("/hits/:period", get(get_hits)) .with_state(pool.clone()) .into_make_service(); let listener = mklistener().await; axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) .await .unwrap(); pool.close().await; } async fn register_hit(State(db): State, headers: HeaderMap) -> String { let now = chrono::Utc::now(); let referer = headers.get(REFERER); let page = if let Some(referer) = referer { let p = referer.to_str().unwrap_or("/").to_string(); if let Ok(path) = Url::parse(&p) { path.path().to_string() } else { return "".to_string(); } } else { return "".to_string(); }; let page = &page; let now = now.to_rfc3339(); let now = &now; sqlx::query!("insert into hits (page, accessed) values (?, ?)", page, now) .execute(&db) .await .unwrap_or_default(); "".to_string() } #[axum::debug_handler] async fn get_hits( State(db): State, period: Option>, headers: HeaderMap, ) -> String { let now = chrono::Utc::now(); let referer = headers.get(REFERER); let page = if let Some(referer) = referer { let p = referer.to_str().unwrap_or("/").to_string(); if let Ok(path) = Url::parse(&p) { path.path().to_string() } else { return "".to_string(); } } else { return "".to_string(); }; let page = &page; let count = match period.unwrap_or(Path("all".to_string())).as_str() { "day" => { let then = now - chrono::Duration::try_hours(24).unwrap(); let then = then.to_rfc3339(); get_period_hits(&db, page, &then).await } "week" => { let then = now - chrono::Duration::try_days(7).unwrap(); let then = then.to_rfc3339(); get_period_hits(&db, page, &then).await } _ => sqlx::query_scalar!("select count(*) from hits where page = ?", page) .fetch_one(&db) .await .unwrap_or(1), }; format!("{count}") } async fn get_period_hits(db: &SqlitePool, page: &str, when: &str) -> i32 { sqlx::query_scalar!( "select count(*) from hits where page = ? and accessed > ?", page, when ) .fetch_one(db) .await .unwrap_or(0) } //-************************************************************************ // li'l helpers //-************************************************************************ #[derive(Debug, Parser)] #[clap(version, about)] struct Cli { #[clap( long, short, help = "Path to environment file.", default_value = ".env" )] pub env: OsString, } fn init() { let cli = Cli::parse(); dotenvy::from_path_override(cli.env).expect("Could not read .env file."); env_logger::builder() .format(|buf, record| { let ts = buf.timestamp(); writeln!(buf, "{}: {}", ts, record.args()) }) .init(); } async fn db() -> SqlitePool { let dbfile = std::env::var("DATABASE_FILE").unwrap(); let opts = SqliteConnectOptions::new() .foreign_keys(true) .create_if_missing(true) .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal) .filename(&dbfile) .optimize_on_close(true, None); let pool = SqlitePoolOptions::new() .connect_with(opts.clone()) .await .unwrap(); sqlx::migrate!().run(&pool).await.unwrap(); let count = sqlx::query_scalar!("select count(*) from hits") .fetch_one(&pool) .await .expect("could not get hit count from DB"); log::info!("Connected to DB, found {count} total hits."); pool } async fn mklistener() -> TcpListener { let ip = std::env::var("LISTENING_ADDR").expect("Could not find $LISTENING_ADDR in environment"); let ip: Ipv4Addr = ip .parse() .unwrap_or_else(|_| panic!("Could not parse {ip} as an IP address")); let port: u16 = std::env::var("LISTENING_PORT") .and_then(|p| p.parse().map_err(|_| VarError::NotPresent)) .unwrap_or_else(|_| { panic!("Could not find LISTENING_PORT in env or parse if present"); }); let addr = SocketAddr::from((ip, port)); TcpListener::bind(&addr).await.unwrap() } async fn shutdown_signal() { use tokio::signal; let ctrl_c = async { signal::ctrl_c() .await .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => {log::info!("shutting down")}, _ = terminate => {}, } }