244 lines
6.9 KiB
Rust
244 lines
6.9 KiB
Rust
use std::{
|
|
env::VarError,
|
|
ffi::OsString,
|
|
io::Write,
|
|
net::{Ipv4Addr, SocketAddr},
|
|
};
|
|
|
|
use axum::{
|
|
extract::{Path, State},
|
|
http::{method::Method, HeaderValue},
|
|
routing::get,
|
|
Router,
|
|
};
|
|
use axum_client_ip::InsecureClientIp;
|
|
use clap::Parser;
|
|
use lazy_static::lazy_static;
|
|
use ring::digest::{Context, SHA256};
|
|
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
|
|
use tokio::net::TcpListener;
|
|
use tower_http::cors::{self, CorsLayer};
|
|
|
|
lazy_static! {
    // Random per-process salt mixed into every hit key. It is generated on
    // first use, so it changes each time the service restarts.
    static ref SESSION_SALT: u64 = rand::random();
    // Origin handed to the CORS layer, read from $HITMAN_ORIGIN on first use.
    // The first access panics if the variable is not set.
    static ref HITMAN_ORIGIN: String = std::env::var("HITMAN_ORIGIN")
        .expect("could not get origin for service");
}
|
|
|
|
#[derive(Debug, Parser)]
|
|
#[clap(version, about)]
|
|
struct Cli {
|
|
#[clap(
|
|
long,
|
|
short,
|
|
help = "Path to environment file.",
|
|
default_value = ".env"
|
|
)]
|
|
pub env: OsString,
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
init();
|
|
|
|
let pool = db().await;
|
|
let origin = HeaderValue::from_str(&HITMAN_ORIGIN).unwrap();
|
|
let cors_layer = CorsLayer::new()
|
|
.allow_origin(origin)
|
|
.allow_methods(cors::AllowMethods::exact(Method::GET));
|
|
|
|
let app = Router::new()
|
|
.route("/hit/:slug", get(register_hit))
|
|
.route("/hits/:slug", get(get_all_hits))
|
|
.route("/hits/:slug/:period", get(get_period_hits))
|
|
.layer(cors_layer)
|
|
.with_state(pool.clone())
|
|
// we need the connect info in order to extract the IP address
|
|
.into_make_service_with_connect_info::<SocketAddr>();
|
|
|
|
let listener = mklistener().await;
|
|
|
|
axum::serve(listener, app)
|
|
.with_graceful_shutdown(shutdown_signal())
|
|
.await
|
|
.unwrap();
|
|
|
|
pool.close().await;
|
|
}
|
|
|
|
//-************************************************************************
|
|
// the three route handlers
|
|
//-************************************************************************
|
|
|
|
/// This is the main handler. It counts the hit and returns the latest count.
|
|
async fn register_hit(
|
|
Path(slug): Path<String>,
|
|
State(db): State<SqlitePool>,
|
|
InsecureClientIp(ip): InsecureClientIp,
|
|
) -> String {
|
|
let slug = &slug;
|
|
|
|
let host = ip.to_string();
|
|
|
|
let now = chrono::Utc::now();
|
|
let now = now.to_rfc3339();
|
|
// What we really want is just the date + hour in 24-hour format; this limits
|
|
// duplicate views from the same host to one per hour:
|
|
let now = now.split(':').take(1).next().unwrap();
|
|
|
|
// the salt here is regenerated every time the service restarts, and guarantees
|
|
// we can't just enumerate all the possible hashes based on IP, page, and
|
|
// time alone.
|
|
let salt = *SESSION_SALT;
|
|
let key = format!("{now}{host}{slug}{salt}").into_bytes();
|
|
let key = hex::encode(shasum(&key));
|
|
|
|
let tx = db.begin().await;
|
|
|
|
if let Ok(mut tx) = tx {
|
|
match sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,)
|
|
.execute(&mut *tx)
|
|
.await
|
|
{
|
|
Ok(_) => tx.commit().await.unwrap_or_default(),
|
|
_ => { /* whatevs, fine */ }
|
|
}
|
|
}
|
|
|
|
let hits = all_hits_helper(&db, slug).await;
|
|
format!("{hits}")
|
|
}
|
|
|
|
async fn get_period_hits(
|
|
State(db): State<SqlitePool>,
|
|
Path(slug): Path<String>,
|
|
Path(period): Path<String>,
|
|
) -> String {
|
|
let now = chrono::Utc::now();
|
|
let slug = &slug;
|
|
|
|
let when = match period.as_str() {
|
|
"day" => {
|
|
let then = now - chrono::Duration::try_hours(24).unwrap();
|
|
then.to_rfc3339()
|
|
}
|
|
"week" => {
|
|
let then = now - chrono::Duration::try_days(7).unwrap();
|
|
then.to_rfc3339()
|
|
}
|
|
_ => {
|
|
let then = now - chrono::Duration::try_days(365_242).unwrap(); // 1000 years
|
|
then.to_rfc3339()
|
|
}
|
|
};
|
|
|
|
let hits = sqlx::query_scalar!(
|
|
"select count(*) from hits where page = ? and viewed > ?",
|
|
slug,
|
|
when
|
|
)
|
|
.fetch_one(&db)
|
|
.await
|
|
.unwrap_or(0);
|
|
|
|
format!("{hits}")
|
|
}
|
|
|
|
// it's easier to split this into something that handles the request parameters
|
|
// and something that does the query
|
|
async fn get_all_hits(State(db): State<SqlitePool>, Path(slug): Path<String>) -> String {
|
|
let hits = all_hits_helper(&db, &slug).await;
|
|
format!("{hits}")
|
|
}
|
|
|
|
//-************************************************************************
|
|
// li'l helpers
|
|
//-************************************************************************
|
|
async fn all_hits_helper(db: &SqlitePool, slug: &str) -> i32 {
|
|
sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
|
|
.fetch_one(db)
|
|
.await
|
|
.unwrap_or(1)
|
|
}
|
|
|
|
fn init() {
|
|
let cli = Cli::parse();
|
|
dotenvy::from_path_override(cli.env).expect("Could not read .env file.");
|
|
env_logger::builder()
|
|
.format(|buf, record| {
|
|
let ts = buf.timestamp();
|
|
writeln!(buf, "{}: {}", ts, record.args())
|
|
})
|
|
.init();
|
|
}
|
|
|
|
async fn db() -> SqlitePool {
|
|
let dbfile = std::env::var("DATABASE_FILE").unwrap();
|
|
let opts = SqliteConnectOptions::new()
|
|
.foreign_keys(true)
|
|
.create_if_missing(true)
|
|
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
|
|
.filename(&dbfile)
|
|
.optimize_on_close(true, None);
|
|
let pool = SqlitePoolOptions::new()
|
|
.connect_with(opts.clone())
|
|
.await
|
|
.unwrap();
|
|
|
|
sqlx::migrate!().run(&pool).await.unwrap();
|
|
|
|
let count = sqlx::query_scalar!("select count(*) from hits")
|
|
.fetch_one(&pool)
|
|
.await
|
|
.expect("could not get hit count from DB");
|
|
log::info!("Connected to DB, found {count} total hits.");
|
|
|
|
pool
|
|
}
|
|
|
|
async fn mklistener() -> TcpListener {
|
|
let ip =
|
|
std::env::var("LISTENING_ADDR").expect("Could not find $LISTENING_ADDR in environment");
|
|
let ip: Ipv4Addr = ip
|
|
.parse()
|
|
.unwrap_or_else(|_| panic!("Could not parse {ip} as an IP address"));
|
|
let port: u16 = std::env::var("LISTENING_PORT")
|
|
.and_then(|p| p.parse().map_err(|_| VarError::NotPresent))
|
|
.unwrap_or_else(|_| {
|
|
panic!("Could not find LISTENING_PORT in env or parse if present");
|
|
});
|
|
let addr = SocketAddr::from((ip, port));
|
|
TcpListener::bind(&addr).await.unwrap()
|
|
}
|
|
|
|
async fn shutdown_signal() {
|
|
use tokio::signal;
|
|
let ctrl_c = async {
|
|
signal::ctrl_c()
|
|
.await
|
|
.expect("failed to install Ctrl+C handler");
|
|
};
|
|
|
|
#[cfg(unix)]
|
|
let terminate = async {
|
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
|
.expect("failed to install signal handler")
|
|
.recv()
|
|
.await;
|
|
};
|
|
|
|
#[cfg(not(unix))]
|
|
let terminate = std::future::pending::<()>();
|
|
|
|
tokio::select! {
|
|
_ = ctrl_c => {log::info!("shutting down")},
|
|
_ = terminate => {},
|
|
}
|
|
}
|
|
|
|
fn shasum(input: &[u8]) -> Vec<u8> {
|
|
let mut context = Context::new(&SHA256);
|
|
context.update(input);
|
|
context.finish().as_ref().to_vec()
|
|
}
|