use std::{
    env::VarError,
    ffi::OsString,
    io::Write,
    net::{Ipv4Addr, SocketAddr},
};

use axum::{
    extract::{Path, State},
    http::{method::Method, HeaderValue},
    routing::get,
    Router,
};
use axum_client_ip::InsecureClientIp;
use clap::Parser;
use lazy_static::lazy_static;
use ring::digest::{Context, SHA256};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use tokio::net::TcpListener;
use tower_http::cors::{self, CorsLayer};

lazy_static! {
    /// Random per-process salt mixed into every hit key. It is drawn once per
    /// process, so a restart produces a fresh value.
    static ref SESSION_SALT: u64 = rand::random::<u64>();
    /// Origin allowed by the CORS layer, read from `$HITMAN_ORIGIN` on first use.
    static ref HITMAN_ORIGIN: String = std::env::var("HITMAN_ORIGIN")
        .expect("could not get origin for service");
}

// Command-line options; `about` and `version` carry no explicit values, so clap
// fills them from the crate metadata.
#[derive(Debug, Parser)]
#[clap(about, version)]
struct Cli {
    // Env file loaded at startup (`-e` / `--env`), defaulting to `.env`.
    #[clap(short, long, default_value = ".env", help = "Path to environment file.")]
    pub env: OsString,
}

/// Entry point: loads configuration, opens the database, and serves the hit
/// counter routes until a shutdown signal (Ctrl+C, or SIGTERM on Unix) arrives.
///
/// # Panics
/// If `$HITMAN_ORIGIN` is missing or not a valid header value, or if the HTTP
/// server fails (plus any panic raised by `init`, `db` or `mklistener`).
#[tokio::main]
async fn main() {
    init();

    let pool = db().await;

    // Browsers may only call us from the configured origin, and only with GET.
    let origin = HeaderValue::from_str(&HITMAN_ORIGIN)
        .expect("HITMAN_ORIGIN is not a valid HTTP header value");
    let cors_layer = CorsLayer::new()
        .allow_origin(origin)
        .allow_methods(cors::AllowMethods::exact(Method::GET));

    let app = Router::new()
        .route("/hit/:slug", get(register_hit))
        .route("/hits/:slug", get(get_all_hits))
        .route("/hits/:slug/:period", get(get_period_hits))
        .layer(cors_layer)
        .with_state(pool.clone())
        // we need the connect info in order to extract the IP address
        .into_make_service_with_connect_info::<SocketAddr>();

    let listener = mklistener().await;

    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .expect("HTTP server failed");

    // The server has stopped; release the database connections.
    pool.close().await;
}

//-************************************************************************
// the three route handlers
//-************************************************************************

/// This is the main handler. It counts the hit and returns the latest count.
async fn register_hit(
    Path(slug): Path<String>,
    State(db): State<SqlitePool>,
    InsecureClientIp(ip): InsecureClientIp,
) -> String {
    let slug = &slug;

    let host = ip.to_string();

    let now = chrono::Utc::now();
    let now = now.to_rfc3339();
    // What we really want is just the date + hour in 24-hour format; this limits
    // duplicate views from the same host to one per hour:
    let now = now.split(':').take(1).next().unwrap();

    // the salt here is regenerated every time the service restarts, and guarantees
    // we can't just enumerate all the possible hashes based on IP, page, and
    // time alone.
    let salt = *SESSION_SALT;
    let key = format!("{now}{host}{slug}{salt}").into_bytes();
    let key = hex::encode(shasum(&key));

    let tx = db.begin().await;

    if let Ok(mut tx) = tx {
        match sqlx::query!("insert into hits (page, hit_key) values (?, ?)", slug, key,)
            .execute(&mut *tx)
            .await
        {
            Ok(_) => tx.commit().await.unwrap_or_default(),
            _ => { /* whatevs, fine */ }
        }
    }

    let hits = all_hits_helper(&db, slug).await;
    format!("{hits}")
}

/// Handler for `/hits/:slug/:period`. Returns the number of hits on `slug`
/// recorded within the trailing `period`: `day` (24 hours) or `week` (7 days).
/// Any other value counts all hits, using a ~1000-year window. Returns `0` if
/// the query fails.
async fn get_period_hits(
    State(db): State<SqlitePool>,
    // The route has two captures, so they must be extracted together as a
    // tuple. A `Path<String>` extractor rejects a path with two parameters,
    // which made this endpoint fail on every request.
    Path((slug, period)): Path<(String, String)>,
) -> String {
    let slug = &slug;

    let lookback = match period.as_str() {
        "day" => chrono::Duration::try_hours(24),
        "week" => chrono::Duration::try_days(7),
        _ => chrono::Duration::try_days(365_242), // ~1000 years, i.e. "all time"
    }
    .expect("lookback windows are small constants that fit in a Duration");

    // NOTE(review): `viewed` is compared as text against an RFC 3339 string
    // (e.g. `2024-05-01T12:00:00.123+00:00`). That only orders correctly if
    // `viewed` is stored in the same format. SQLite's CURRENT_TIMESTAMP format
    // (`2024-05-01 12:00:00`) would sort before the cutoff for any time on the
    // cutoff date — confirm against the migrations.
    let when = (chrono::Utc::now() - lookback).to_rfc3339();

    let hits = sqlx::query_scalar!(
        "select count(*) from hits where page = ? and viewed > ?",
        slug,
        when
    )
    .fetch_one(&db)
    .await
    .unwrap_or(0);

    hits.to_string()
}

// Handler for `/hits/:slug`. It only unpacks the request; the query itself
// lives in `all_hits_helper`, which `register_hit` reuses.
async fn get_all_hits(State(db): State<SqlitePool>, Path(slug): Path<String>) -> String {
    all_hits_helper(&db, &slug).await.to_string()
}

//-************************************************************************
// li'l helpers
//-************************************************************************
/// Counts every recorded hit on `slug`.
///
/// NOTE(review): a failed query yields `1`, not `0` as in `get_period_hits`.
/// This may be deliberate, so that `register_hit` never reports zero right
/// after a visit — confirm.
async fn all_hits_helper(db: &SqlitePool, slug: &str) -> i32 {
    let counted = sqlx::query_scalar!("select count(*) from hits where page = ?", slug)
        .fetch_one(db)
        .await;
    match counted {
        Ok(total) => total,
        Err(_) => 1,
    }
}

/// Parses the command line, loads the env file it names, and installs the
/// logger. Each log line is written as `<timestamp>: <message>`.
///
/// # Panics
/// If the env file cannot be read; the message names the file and the cause.
fn init() {
    let cli = Cli::parse();
    // `from_path_override`: values in the file take precedence over variables
    // already present in the process environment.
    dotenvy::from_path_override(&cli.env).unwrap_or_else(|e| {
        panic!(
            "Could not read env file {}: {e}",
            cli.env.to_string_lossy()
        )
    });
    // Built after the env file is loaded, so logger settings placed in that file
    // (e.g. RUST_LOG) are picked up.
    env_logger::builder()
        .format(|buf, record| {
            let ts = buf.timestamp();
            writeln!(buf, "{}: {}", ts, record.args())
        })
        .init();
}

/// Opens (creating it if missing) the SQLite database named by `$DATABASE_FILE`
/// in WAL mode with foreign keys on. It then runs pending migrations and logs
/// the total number of recorded hits.
///
/// # Panics
/// If `DATABASE_FILE` is unset, the database cannot be opened, migrations fail,
/// or the initial count query fails.
async fn db() -> SqlitePool {
    let dbfile =
        std::env::var("DATABASE_FILE").expect("Could not find $DATABASE_FILE in environment");
    let opts = SqliteConnectOptions::new()
        .foreign_keys(true)
        .create_if_missing(true)
        .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
        .filename(&dbfile)
        .optimize_on_close(true, None);
    // `opts` is not used again, so it is moved rather than cloned.
    let pool = SqlitePoolOptions::new()
        .connect_with(opts)
        .await
        .unwrap_or_else(|e| panic!("could not open database {dbfile}: {e}"));

    sqlx::migrate!()
        .run(&pool)
        .await
        .unwrap_or_else(|e| panic!("could not run migrations on {dbfile}: {e}"));

    let count = sqlx::query_scalar!("select count(*) from hits")
        .fetch_one(&pool)
        .await
        .expect("could not get hit count from DB");
    log::info!("Connected to DB, found {count} total hits.");

    pool
}

/// Binds a TCP listener on `$LISTENING_ADDR:$LISTENING_PORT`. The address must
/// be an IPv4 literal.
///
/// # Panics
/// If either variable is missing or unparseable, or if the socket cannot be
/// bound (e.g. the port is already in use).
async fn mklistener() -> TcpListener {
    let ip =
        std::env::var("LISTENING_ADDR").expect("Could not find $LISTENING_ADDR in environment");
    let ip: Ipv4Addr = ip
        .parse()
        .unwrap_or_else(|_| panic!("Could not parse {ip} as an IP address"));
    // An unparseable value is folded into the same error as a missing one, so
    // both end up in the single panic below.
    let port: u16 = std::env::var("LISTENING_PORT")
        .and_then(|p| p.parse().map_err(|_| VarError::NotPresent))
        .unwrap_or_else(|_| {
            panic!("Could not find LISTENING_PORT in env or parse if present");
        });
    let addr = SocketAddr::from((ip, port));
    TcpListener::bind(&addr)
        .await
        .unwrap_or_else(|e| panic!("Could not bind to {addr}: {e}"))
}

async fn shutdown_signal() {
    use tokio::signal;
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {log::info!("shutting down")},
        _ = terminate => {},
    }
}

/// Returns the raw (not hex-encoded) SHA-256 digest of `input`.
fn shasum(input: &[u8]) -> Vec<u8> {
    let mut hasher = Context::new(&SHA256);
    hasher.update(input);
    let digest = hasher.finish();
    Vec::from(digest.as_ref())
}