diff --git a/.gitignore b/.gitignore index fedaa2b..1b3cb3c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ /target .env +*.db +*.db-* diff --git a/Cargo.lock b/Cargo.lock index ba28629..cffec90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -330,15 +330,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "deranged" -version = "0.3.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" -dependencies = [ - "powerfmt", -] - [[package]] name = "digest" version = "0.10.7" @@ -939,12 +930,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "num-conv" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" - [[package]] name = "num-integer" version = "0.1.46" @@ -1112,12 +1097,6 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" -[[package]] -name = "powerfmt" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" - [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1158,7 +1137,6 @@ dependencies = [ "serde_json", "sqlx", "thiserror", - "time", "tokio", "tower-http", "unicode-segmentation", @@ -1763,24 +1741,6 @@ dependencies = [ "syn 2.0.52", ] -[[package]] -name = "time" -version = "0.3.34" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" -dependencies = [ - "deranged", - "num-conv", - "powerfmt", - "time-core", -] - -[[package]] -name = "time-core" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" - [[package]] name = "tinyvec" version = "1.6.0" diff --git a/Cargo.toml b/Cargo.toml index c2820b3..ed7dc19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,6 @@ serde = { version = "1", default-features = false, features = ["derive"] } serde_json = { version = "1", default-features = false } sqlx = { version = "0.7", default-features = false, features = ["runtime-tokio", "sqlite", "tls-none", "migrate", "macros"] } thiserror = { version = "1" } -time = { version = "0.3", default-features = false } tokio = { version = "1", default-features = false, features = ["rt-multi-thread", "signal"] } tower-http = { version = "0.5", default-features = false, features = ["fs"] } unicode-segmentation = { version = "1", default-features = false } diff --git a/src/handlers/handlers.rs b/src/handlers/handlers.rs index cb6d2a9..91eab4a 100644 --- a/src/handlers/handlers.rs +++ b/src/handlers/handlers.rs @@ -1,11 +1,10 @@ use std::ops::RangeInclusive; use axum::{ - extract::{Form, Path}, - response::{IntoResponse, Redirect}, + extract::{Form, Path, State}, + response::IntoResponse, }; -use rand::random; -use tower_sessions::Session; +use sqlx::SqlitePool; use unicode_segmentation::UnicodeSegmentation; use super::{util, CreateUserError, CreateUserErrorKind, SignupForm}; @@ -21,94 +20,59 @@ const EMAIL_LEN: RangeInclusive = 4..=50; const CHECKOUT_TIMEOUT: i64 = 12 * 3600; lazy_static! { - static ref SIGNUP_KEY: String = format!("meow-{}", random::()); + static ref ADMIN_TOKEN: String = std::env::var("ADMIN_TOKEN").unwrap(); + static ref FORGEJO_URL: String = std::env::var("FORGEJO_URL").unwrap(); + static ref STRIPE_TOKEN: String = std::env::var("STRIPE_TOKEN").unwrap(); + static ref ANNUAL_LINK: String = std::env::var("ANNUAL_LINK").unwrap(); + static ref MONTHLY_LINK: String = std::env::var("MONTHLY_LINK").unwrap(); } /// Displays the signup form. -pub async fn get_signup() -> impl IntoResponse { +pub async fn get_signup(_db: State) -> impl IntoResponse { SignupPage { + monthly_link: Some((*MONTHLY_LINK).to_string()), ..Default::default() } } /// Receives the form with the user signup fields filled out. pub async fn post_signup( - session: Session, + db: State, Form(form): Form, ) -> Result { let user = validate_signup(&form).await?; - match session.insert(&SIGNUP_KEY, user).await { - Ok(_) => {} - Err(e) => { - log::error!( - "Could not insert validated user form into session, got {}", - e - ); - return Err(CreateUserErrorKind::UnknownEorr.into()); - } - } - - match session.save().await { - // TODO: pass in as env var/into a state object that the handlers can read from - Ok(_) => Ok(Redirect::to( - "https://buy.stripe.com/test_eVa6rrb7ygjNbwk000", - )), - Err(e) => { - log::error!("Could not save session, got {}", e); - Err(CreateUserErrorKind::UnknownEorr.into()) - } + if create_user(&user) { + log::info!("Created user {user:?}"); + Ok(SignupSuccessPage(user)) + } else { + Err(CreateUserError(CreateUserErrorKind::UnknownEorr)) } } /// Redirected from Stripe with the receipt of payment. -pub async fn payment_success(session: Session, receipt: Option>) -> impl IntoResponse { - session.load().await.unwrap_or_else(|e| { - log::error!("Could not load the session, got {}", e); - }); - log::debug!("loaded the session"); - let user = if let Some(user) = session.get::(&SIGNUP_KEY).await.unwrap_or(None) { - user - } else { - log::warn!("Could not find user in session; got receipt {:?}", receipt); - return CreateUserError(CreateUserErrorKind::NoFormFound).into_response(); - }; - +pub async fn payment_success( + db: State, + receipt: Option>, +) -> impl IntoResponse { let receipt = if let Some(Path(receipt)) = receipt { receipt } else { - log::info!("Got {:?} from the session, but no receipt.", &user); return CreateUserError(CreateUserErrorKind::BadPayment).into_response(); }; - if confirm_payment(&receipt) { - log::info!("Confirmed payment from {}", &receipt); - } else { - return CreateUserError(CreateUserErrorKind::BadPayment).into_response(); + UserFormPage { + receipt, + ..Default::default() } - - if create_user(&user) { - log::info!("Created user {user:?}"); - } else { - return CreateUserError(CreateUserErrorKind::UnknownEorr).into_response(); - } - // TODO: store the receipt into a durable store to prevent re-use after creating - // an account - - session.delete().await.unwrap_or_else(|e| { - log::error!("Got error deleting {:?} from session, got {}", &user, e); - }); - - log::info!("Added {:?}", &user); - SignupSuccessPage(user).into_response() + .into_response() } //-************************************************************************ // helpers //-************************************************************************ fn create_user(user: &User) -> bool { - let token = std::env::var("ADMIN_TOKEN").expect("Could not find $ADMIN_TOKEN in environment."); - let url = std::env::var("ADD_USER_ENDPOINT") - .expect("Could not find $ADD_USER_ENDPOINT in environment"); + let token = &*ADMIN_TOKEN; + let url = &*FORGEJO_URL; let auth_header = format!("token {token}"); let user: ForgejoUser = user.into(); let resp = ureq::post(&format!("{url}/api/v1/admin/users")) @@ -127,13 +91,13 @@ fn create_user(user: &User) -> bool { } fn confirm_payment(stripe_checkout_session_id: &str) -> bool { - let token = std::env::var("STRIPE_TOKEN").expect("Could not find $STRIPE_TOKEN in environment"); + let token = &*STRIPE_TOKEN; let url = format!("https://api.stripe.com/v1/checkout/sessions/{stripe_checkout_session_id}"); let json: serde_json::Value = ureq::get(&url) .set("Authorization", &format!("Bearer {token}")) .call() .map_err(|e| { - log::error!("Error confirming payment from Stripe, got {}", e); + log::error!("Error confirming payment from Stripe, got {e}"); std::io::Error::new(std::io::ErrorKind::Other, e) }) .and_then(|resp| resp.into_json()) @@ -156,6 +120,13 @@ async fn validate_signup(form: &SignupForm) -> Result { let username = form.username.trim(); let password = form.password.trim(); let verify = form.pw_verify.trim(); + let receipt = form.receipt.trim(); + + if confirm_payment(receipt) { + log::info!("Confirmed payment from {receipt}"); + } else { + return Err(CreateUserError(CreateUserErrorKind::BadPayment)); + } let name_len = username.graphemes(true).size_hint().1.unwrap_or(0); // we are not ascii exclusivists around here diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 6752ee1..b1e6411 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -41,8 +41,6 @@ pub enum CreateUserErrorKind { BadEmail, #[error(desc = "We could not verify your payment")] BadPayment, - #[error(desc = "We couldn't retrieve your info from this browser session")] - NoFormFound, } #[derive(Debug, Default, Deserialize, PartialEq, Eq)] @@ -53,4 +51,5 @@ pub struct SignupForm { pub email: String, pub password: String, pub pw_verify: String, + pub receipt: String, } diff --git a/src/main.rs b/src/main.rs index 7021ca1..db7f4dd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,9 +8,9 @@ use axum::{ routing::{get, MethodRouter}, Router, }; +use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; use tokio::net::TcpListener; use tower_http::services::ServeDir; -use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; #[macro_use] extern crate justerror; @@ -28,15 +28,11 @@ async fn main() { init(); // for javascript and css + // TODO: figure out how to intern these contents let assets_dir = std::env::current_dir().unwrap().join("assets"); let assets_svc = ServeDir::new(assets_dir.as_path()); - // just for signups - let session_store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(session_store) - .with_secure(false) - .with_same_site(tower_sessions::cookie::SameSite::Lax) - .with_expiry(Expiry::OnInactivity(time::Duration::hours(2))); + let pool = db().await; // the core application, defining the routes and handlers let app = Router::new() @@ -44,10 +40,15 @@ async fn main() { .stripped_clone("/signup/", get(get_signup).post(post_signup)) .stripped_clone("/payment_success/", get(payment_success)) .route("/payment_success/:receipt", get(payment_success)) - .layer(session_layer) + .with_state(pool.clone()) .into_make_service(); let listener = mklistener().await; - axum::serve(listener, app).await.unwrap(); + axum::serve(listener, app) + .with_graceful_shutdown(shutdown_signal()) + .await + .unwrap(); + + pool.close().await; } //-************************************************************************ @@ -63,6 +64,16 @@ fn init() { .init(); } +async fn db() -> SqlitePool { + //let dbfile = std::env::var("DATABASE_URL").unwrap(); + let opts = SqliteConnectOptions::new() + .foreign_keys(true) + .create_if_missing(true) + .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal) + .optimize_on_close(true, None); + SqlitePoolOptions::new().connect_with(opts).await.unwrap() +} + async fn mklistener() -> TcpListener { let ip = std::env::var("LISTENING_ADDR").expect("Could not find $LISTENING_ADDR in environment"); @@ -96,3 +107,28 @@ where .route(path.trim_end_matches('/'), method_router) } } + +async fn shutdown_signal() { + use tokio::signal; + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {log::info!("shutting down")}, + _ = terminate => {}, + } +} diff --git a/src/templates.rs b/src/templates.rs index c830250..25305d7 100644 --- a/src/templates.rs +++ b/src/templates.rs @@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize}; use crate::user::User; #[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)] -#[template(path = "signup.html")] -pub struct SignupPage { +#[template(path = "user_form.html")] +pub struct UserFormPage { pub username: String, pub displayname: Option, pub email: Option, @@ -21,3 +21,11 @@ pub struct SignupSuccessPage(pub User); #[derive(Debug, Clone, Template, Default, Deserialize, Serialize, PartialEq, Eq)] #[template(path = "signup_error.html")] pub struct SignupErrorPage(pub String); + +#[derive(Debug, Clone, Template, Default, Deserialize, Serialize, PartialEq, Eq)] +#[template(path = "signup.html")] +pub struct SignupPage { + pub annual_link: Option, + pub monthly_link: Option, + pub invitation: Option, +} diff --git a/templates/base.html b/templates/base.html index 49a8bc4..5807ef1 100644 --- a/templates/base.html +++ b/templates/base.html @@ -16,7 +16,6 @@ {% block header %}{% endblock %}
-
{% block content %}{% endblock %}
-

Now, head on over to the login page and git going! +

Now, head on over to the login page and git + going!

{% endblock %} diff --git a/templates/user_form.html b/templates/user_form.html new file mode 100644 index 0000000..2a11cae --- /dev/null +++ b/templates/user_form.html @@ -0,0 +1,26 @@ +{% extends "base.html" %} + +{% block title %}Welcome, friend, to git.kittenclause.com{% endblock %} + +{% block header %} {% endblock %} + +{% block content %} + +

+

+ + +
+ +
+ +
+ +
+ +
+ +
+

+ +{% endblock %}