hitman/src/main.rs

244 lines
6.9 KiB
Rust

use std::{
env::VarError,
ffi::OsString,
io::Write,
net::{Ipv4Addr, SocketAddr},
};
use axum::{
extract::{Path, State},
http::{method::Method, HeaderValue},
routing::get,
Router,
};
use axum_client_ip::InsecureClientIp;
use clap::Parser;
use lazy_static::lazy_static;
use ring::digest::{Context, SHA256};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use tokio::net::TcpListener;
use tower_http::cors::{self, CorsLayer};
// Process-wide values, each computed once on first access.
lazy_static! {
    // Origin allowed by the CORS layer, read from the `HITMAN_ORIGIN`
    // environment variable. The first access panics if the variable is unset.
    static ref HITMAN_ORIGIN: String =
        std::env::var("HITMAN_ORIGIN").expect("could not get origin for service");
    // Random value mixed into every hit key. It is generated afresh in each
    // process, so hit keys from one run cannot be recomputed in another.
    static ref SESSION_SALT: u64 = rand::random();
}
// Command-line arguments for the service, parsed with clap's derive API.
// NOTE: plain `//` comments are used on purpose; `///` doc comments on a clap
// struct are picked up as help text and would alter the generated `--help`.
#[derive(Debug, Parser)]
#[clap(version, about)]
struct Cli {
    // `--env` / `-e`: location of the dotenv file loaded at startup.
    // Defaults to `.env`.
    #[clap(
        long,
        short,
        help = "Path to environment file.",
        default_value = ".env"
    )]
    pub env: OsString,
}
/// Service entry point.
///
/// Loads configuration and logging, opens the database, builds the router
/// with its CORS policy, then serves requests until a shutdown signal
/// arrives. The connection pool is closed once the server has stopped.
#[tokio::main]
async fn main() {
    // Must run first: it loads the env file that everything below reads.
    init();
    let pool = db().await;

    // Only GET requests from the configured origin are allowed cross-site.
    let allowed_origin = HeaderValue::from_str(&HITMAN_ORIGIN).unwrap();
    let cors_layer = CorsLayer::new()
        .allow_origin(allowed_origin)
        .allow_methods(cors::AllowMethods::exact(Method::GET));

    let routes = Router::new()
        .route("/hit/:slug", get(register_hit))
        .route("/hits/:slug", get(get_all_hits))
        .route("/hits/:slug/:period", get(get_period_hits));

    // The connect info is what lets handlers extract the client's IP address.
    let service = routes
        .layer(cors_layer)
        .with_state(pool.clone())
        .into_make_service_with_connect_info::<SocketAddr>();

    let listener = mklistener().await;
    let outcome = axum::serve(listener, service)
        .with_graceful_shutdown(shutdown_signal())
        .await;
    outcome.unwrap();

    pool.close().await;
}
//-************************************************************************
// the three route handlers
//-************************************************************************
/// The main handler: records a hit for `slug` and replies with the page's
/// current total hit count as plain text.
///
/// Each hit is stored under a key hashed from the current hour, the client
/// IP, the page slug and the per-process salt. A failure to open the
/// transaction, insert the row or commit is ignored on purpose; the count is
/// returned regardless.
async fn register_hit(
    Path(slug): Path<String>,
    State(db): State<SqlitePool>,
    InsecureClientIp(ip): InsecureClientIp,
) -> String {
    let page: &str = &slug;

    // Everything before the first ':' of the RFC 3339 timestamp is the date
    // plus the hour in 24-hour format. Keying on that limits duplicate views
    // from the same host to one per hour (presumably enforced by a uniqueness
    // constraint on `hit_key`; the schema is not visible here).
    let stamp = chrono::Utc::now().to_rfc3339();
    let hour_bucket = stamp.split(':').next().unwrap_or(stamp.as_str());

    // The salt changes on every service restart, so the hashes cannot simply
    // be enumerated from IP, page and time alone.
    let material = format!("{hour_bucket}{ip}{page}{}", *SESSION_SALT);
    let key = hex::encode(shasum(material.as_bytes()));

    if let Ok(mut tx) = db.begin().await {
        let inserted = sqlx::query!("insert into hits (page, hit_key) values (?, ?)", page, key)
            .execute(&mut *tx)
            .await;
        // A failed insert is fine: the transaction is simply dropped.
        if inserted.is_ok() {
            tx.commit().await.unwrap_or_default();
        }
    }

    all_hits_helper(&db, page).await.to_string()
}
/// Replies with the number of hits recorded for `slug` within a trailing
/// time window, as plain text.
///
/// `period` selects the window: `"day"` is the last 24 hours, `"week"` the
/// last 7 days, and any other value falls back to roughly 1000 years, which
/// is effectively "all time". A failed query is reported as `0`.
///
/// Both path parameters are extracted together as a tuple. A route with
/// several parameters cannot be read through separate `Path<String>`
/// extractors: each one expects the route to carry exactly one parameter and
/// rejects the request otherwise.
async fn get_period_hits(
    State(db): State<SqlitePool>,
    Path((slug, period)): Path<(String, String)>,
) -> String {
    let now = chrono::Utc::now();
    let slug = &slug;

    let window = match period.as_str() {
        "day" => chrono::Duration::try_hours(24),
        "week" => chrono::Duration::try_days(7),
        _ => chrono::Duration::try_days(365_242), // 1000 years
    }
    .expect("window lengths are small constants within chrono's Duration range");

    // NOTE(review): the cutoff is compared against `viewed` as an RFC 3339
    // string. This assumes `viewed` is stored in a format that orders
    // correctly against it; confirm against the table schema.
    let when = (now - window).to_rfc3339();

    let hits = sqlx::query_scalar!(
        "select count(*) from hits where page = ? and viewed > ?",
        slug,
        when
    )
    .fetch_one(&db)
    .await
    .unwrap_or(0);

    hits.to_string()
}
// The request-facing half of the all-time count: this handler only unpacks
// the request parameters, while `all_hits_helper` runs the actual query.
/// Replies with the total number of hits recorded for `slug`, as plain text.
async fn get_all_hits(State(db): State<SqlitePool>, Path(slug): Path<String>) -> String {
    all_hits_helper(&db, slug.as_str()).await.to_string()
}
//-************************************************************************
// li'l helpers
//-************************************************************************
/// Counts every hit stored for the page `slug`.
///
/// If the query fails for any reason the count falls back to `1`.
async fn all_hits_helper(db: &SqlitePool, slug: &str) -> i32 {
    let counted = sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
        .fetch_one(db)
        .await;
    match counted {
        Ok(total) => total,
        Err(_) => 1,
    }
}
fn init() {
let cli = Cli::parse();
dotenvy::from_path_override(cli.env).expect("Could not read .env file.");
env_logger::builder()
.format(|buf, record| {
let ts = buf.timestamp();
writeln!(buf, "{}: {}", ts, record.args())
})
.init();
}
/// Opens the SQLite connection pool and brings the schema up to date.
///
/// The database path comes from the `DATABASE_FILE` environment variable.
/// The file is created if missing, opened in WAL journal mode with foreign
/// keys enforced, and optimized when connections close. The migrations
/// bundled through `sqlx::migrate!()` are applied before the pool is
/// returned, and the total hit count is logged.
///
/// # Panics
///
/// Panics, naming the step that failed, if `DATABASE_FILE` is unset, the
/// database cannot be opened, the migrations fail, or the initial count
/// query fails.
async fn db() -> SqlitePool {
    let dbfile =
        std::env::var("DATABASE_FILE").expect("Could not find $DATABASE_FILE in environment");
    let opts = SqliteConnectOptions::new()
        .foreign_keys(true)
        .create_if_missing(true)
        .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
        .filename(&dbfile)
        .optimize_on_close(true, None);
    // `opts` is not needed afterwards, so it is moved rather than cloned.
    let pool = SqlitePoolOptions::new()
        .connect_with(opts)
        .await
        .unwrap_or_else(|e| panic!("could not open database {dbfile}: {e}"));
    sqlx::migrate!()
        .run(&pool)
        .await
        .expect("could not run database migrations");
    let count = sqlx::query_scalar!("select count(*) from hits")
        .fetch_one(&pool)
        .await
        .expect("could not get hit count from DB");
    log::info!("Connected to DB, found {count} total hits.");
    pool
}
/// Binds the TCP listener the server accepts connections on.
///
/// The address is taken from `LISTENING_ADDR` (an IPv4 address) and
/// `LISTENING_PORT`. Panics if either is missing or unparsable, or if the
/// socket cannot be bound.
async fn mklistener() -> TcpListener {
    let host =
        std::env::var("LISTENING_ADDR").expect("Could not find $LISTENING_ADDR in environment");
    let ip: Ipv4Addr = match host.parse() {
        Ok(parsed) => parsed,
        Err(_) => panic!("Could not parse {host} as an IP address"),
    };

    // A port that does not parse is folded into the same error as a missing
    // variable, so both cases share a single failure path.
    let parsed_port: Result<u16, VarError> = match std::env::var("LISTENING_PORT") {
        Ok(text) => text.parse().map_err(|_| VarError::NotPresent),
        Err(missing) => Err(missing),
    };
    let port = parsed_port.unwrap_or_else(|_| {
        panic!("Could not find LISTENING_PORT in env or parse if present");
    });

    let addr = SocketAddr::new(ip.into(), port);
    TcpListener::bind(addr).await.unwrap()
}
/// Resolves once the process has been asked to stop, either by Ctrl+C or, on
/// Unix, by the terminate signal. Used to trigger graceful shutdown.
///
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    use tokio::signal;

    let interrupt = async {
        let installed = signal::ctrl_c().await;
        installed.expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        stream.recv().await;
    };

    // Platforms without Unix signals only ever stop via Ctrl+C.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = interrupt => { log::info!("shutting down") },
        _ = terminate => {},
    }
}
/// Returns the SHA-256 digest of `input` as raw bytes.
fn shasum(input: &[u8]) -> Vec<u8> {
    let mut hasher = Context::new(&SHA256);
    hasher.update(input);
    let digest = hasher.finish();
    Vec::from(digest.as_ref())
}