235 lines
6.5 KiB
Rust
235 lines
6.5 KiB
Rust
use std::{
|
|
env::VarError,
|
|
ffi::OsString,
|
|
io::Write,
|
|
net::{Ipv4Addr, SocketAddr},
|
|
};
|
|
|
|
use axum::{
|
|
debug_handler,
|
|
extract::{Path, State},
|
|
http::{method::Method, HeaderMap, HeaderValue},
|
|
routing::get,
|
|
Router,
|
|
};
|
|
use clap::Parser;
|
|
use lazy_static::lazy_static;
|
|
use ring::digest::{Context, SHA256};
|
|
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
|
|
use tokio::net::TcpListener;
|
|
use tower_http::cors::{self, CorsLayer};
|
|
|
|
lazy_static! {
    /// The single origin allowed by the CORS layer; read from `$HITMAN_ORIGIN` on
    /// first use, panicking if the variable cannot be read.
    static ref HITMAN_ORIGIN: String = std::env::var("HITMAN_ORIGIN")
        .unwrap_or_else(|e| panic!("could not get origin for service: {e:?}"));

    /// Random salt mixed into every hit key; a fresh value is drawn per process.
    static ref SESSION_SALT: u64 = rand::random::<u64>();
}
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
init();
|
|
|
|
let pool = db().await;
|
|
let origin = HeaderValue::from_str(&HITMAN_ORIGIN).unwrap();
|
|
let cors_layer = CorsLayer::new()
|
|
.allow_origin(origin)
|
|
.allow_methods(cors::AllowMethods::exact(Method::GET));
|
|
|
|
let app = Router::new()
|
|
.route("/hit/:slug", get(register_hit))
|
|
.route("/hits/:slug", get(get_hits))
|
|
.route("/hits/:slug/:period", get(get_hits))
|
|
.layer(cors_layer)
|
|
.with_state(pool.clone())
|
|
.into_make_service();
|
|
let listener = mklistener().await;
|
|
axum::serve(listener, app)
|
|
.with_graceful_shutdown(shutdown_signal())
|
|
.await
|
|
.unwrap();
|
|
|
|
pool.close().await;
|
|
}
|
|
|
|
//-************************************************************************
|
|
// the two route handlers
|
|
//-************************************************************************
|
|
|
|
/// This is the main handler. It counts the hit and returns the latest count.
|
|
#[debug_handler]
|
|
async fn register_hit(
|
|
Path(slug): Path<String>,
|
|
State(db): State<SqlitePool>,
|
|
req: HeaderMap,
|
|
) -> String {
|
|
let host = req
|
|
.get("host")
|
|
.cloned()
|
|
.unwrap_or(HeaderValue::from_str("").unwrap());
|
|
let host = host.to_str().unwrap_or("");
|
|
|
|
let now = chrono::Utc::now();
|
|
let now = now.to_rfc3339();
|
|
// What we really want is just the date + hour in 24-hour format; this limits
|
|
// duplicate views from the same host to one per hour:
|
|
let now = now.split(':').take(1).next().unwrap();
|
|
|
|
let salt = *SESSION_SALT;
|
|
let key = format!("{now}{host}{slug}{salt}").into_bytes();
|
|
let key = hex::encode(shasum(&key));
|
|
|
|
let tx = db.begin().await;
|
|
|
|
if let Ok(mut tx) = tx {
|
|
sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,)
|
|
.execute(&mut *tx)
|
|
.await
|
|
.unwrap_or_default();
|
|
tx.commit().await.unwrap_or_default();
|
|
}
|
|
|
|
let hits = sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
|
|
.fetch_one(&db)
|
|
.await
|
|
.unwrap_or(1);
|
|
|
|
format!("{hits}")
|
|
}
|
|
|
|
/// fer fancy people what want to be a little finer
|
|
#[axum::debug_handler]
|
|
async fn get_hits(
|
|
State(db): State<SqlitePool>,
|
|
Path(slug): Path<String>,
|
|
period: Option<Path<String>>,
|
|
) -> String {
|
|
let now = chrono::Utc::now();
|
|
let slug = &slug;
|
|
|
|
let count = match period.unwrap_or(Path("all".to_string())).as_str() {
|
|
"day" => {
|
|
let then = now - chrono::Duration::try_hours(24).unwrap();
|
|
let then = then.to_rfc3339();
|
|
get_period_hits(&db, slug, &then).await
|
|
}
|
|
"week" => {
|
|
let then = now - chrono::Duration::try_days(7).unwrap();
|
|
let then = then.to_rfc3339();
|
|
get_period_hits(&db, slug, &then).await
|
|
}
|
|
_ => sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
|
|
.fetch_one(&db)
|
|
.await
|
|
.unwrap_or(1),
|
|
};
|
|
|
|
format!("{count}")
|
|
}
|
|
|
|
//-************************************************************************
|
|
// li'l helpers
|
|
//-************************************************************************
|
|
#[derive(Debug, Parser)]
|
|
#[clap(version, about)]
|
|
struct Cli {
|
|
#[clap(
|
|
long,
|
|
short,
|
|
help = "Path to environment file.",
|
|
default_value = ".env"
|
|
)]
|
|
pub env: OsString,
|
|
}
|
|
|
|
fn init() {
|
|
let cli = Cli::parse();
|
|
dotenvy::from_path_override(cli.env).expect("Could not read .env file.");
|
|
env_logger::builder()
|
|
.format(|buf, record| {
|
|
let ts = buf.timestamp();
|
|
writeln!(buf, "{}: {}", ts, record.args())
|
|
})
|
|
.init();
|
|
}
|
|
|
|
async fn db() -> SqlitePool {
|
|
let dbfile = std::env::var("DATABASE_FILE").unwrap();
|
|
let opts = SqliteConnectOptions::new()
|
|
.foreign_keys(true)
|
|
.create_if_missing(true)
|
|
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
|
|
.filename(&dbfile)
|
|
.optimize_on_close(true, None);
|
|
let pool = SqlitePoolOptions::new()
|
|
.connect_with(opts.clone())
|
|
.await
|
|
.unwrap();
|
|
|
|
sqlx::migrate!().run(&pool).await.unwrap();
|
|
|
|
let count = sqlx::query_scalar!("select count(*) from hits")
|
|
.fetch_one(&pool)
|
|
.await
|
|
.expect("could not get hit count from DB");
|
|
log::info!("Connected to DB, found {count} total hits.");
|
|
|
|
pool
|
|
}
|
|
|
|
async fn mklistener() -> TcpListener {
|
|
let ip =
|
|
std::env::var("LISTENING_ADDR").expect("Could not find $LISTENING_ADDR in environment");
|
|
let ip: Ipv4Addr = ip
|
|
.parse()
|
|
.unwrap_or_else(|_| panic!("Could not parse {ip} as an IP address"));
|
|
let port: u16 = std::env::var("LISTENING_PORT")
|
|
.and_then(|p| p.parse().map_err(|_| VarError::NotPresent))
|
|
.unwrap_or_else(|_| {
|
|
panic!("Could not find LISTENING_PORT in env or parse if present");
|
|
});
|
|
let addr = SocketAddr::from((ip, port));
|
|
TcpListener::bind(&addr).await.unwrap()
|
|
}
|
|
|
|
async fn shutdown_signal() {
|
|
use tokio::signal;
|
|
let ctrl_c = async {
|
|
signal::ctrl_c()
|
|
.await
|
|
.expect("failed to install Ctrl+C handler");
|
|
};
|
|
|
|
#[cfg(unix)]
|
|
let terminate = async {
|
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
|
.expect("failed to install signal handler")
|
|
.recv()
|
|
.await;
|
|
};
|
|
|
|
#[cfg(not(unix))]
|
|
let terminate = std::future::pending::<()>();
|
|
|
|
tokio::select! {
|
|
_ = ctrl_c => {log::info!("shutting down")},
|
|
_ = terminate => {},
|
|
}
|
|
}
|
|
|
|
fn shasum(input: &[u8]) -> Vec<u8> {
|
|
let mut context = Context::new(&SHA256);
|
|
context.update(input);
|
|
context.finish().as_ref().to_vec()
|
|
}
|
|
|
|
async fn get_period_hits(db: &SqlitePool, page: &str, when: &str) -> i32 {
|
|
sqlx::query_scalar!(
|
|
"select count(*) from hits where page = ? and accessed > ?",
|
|
page,
|
|
when
|
|
)
|
|
.fetch_one(db)
|
|
.await
|
|
.unwrap_or(0)
|
|
}
|