Compare commits

...

2 commits

13 changed files with 976 additions and 288 deletions

2
.gitignore vendored
View file

@ -1,2 +1,4 @@
/target /target
.env .env
*.db
*.db-*

982
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -17,10 +17,9 @@ passwords = { version = "3", default-features = false }
rand = { version = "0.8", default-features = false, features = ["getrandom"] } rand = { version = "0.8", default-features = false, features = ["getrandom"] }
serde = { version = "1", default-features = false, features = ["derive"] } serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = { version = "1", default-features = false } serde_json = { version = "1", default-features = false }
sqlx = { version = "0.7", default-features = false, features = ["runtime-tokio", "sqlite", "tls-none", "migrate", "macros"] }
thiserror = { version = "1" } thiserror = { version = "1" }
time = { version = "0.3", default-features = false } tokio = { version = "1", default-features = false, features = ["rt-multi-thread", "signal"] }
tokio = { version = "1", default-features = false, features = ["rt-multi-thread"] }
tower-http = { version = "0.5", default-features = false, features = ["fs"] } tower-http = { version = "0.5", default-features = false, features = ["fs"] }
tower-sessions = { version = "0.10", default-features = false, features = ["axum-core", "memory-store"] }
unicode-segmentation = { version = "1", default-features = false } unicode-segmentation = { version = "1", default-features = false }
ureq = { version = "2", default-features = false, features = ["json", "tls"] } ureq = { version = "2", default-features = false, features = ["json", "tls"] }

View file

@ -1,11 +1,10 @@
use std::ops::RangeInclusive; use std::ops::RangeInclusive;
use axum::{ use axum::{
extract::{Form, Path}, extract::{Form, Path, State},
response::{IntoResponse, Redirect}, response::IntoResponse,
}; };
use rand::random; use sqlx::SqlitePool;
use tower_sessions::Session;
use unicode_segmentation::UnicodeSegmentation; use unicode_segmentation::UnicodeSegmentation;
use super::{util, CreateUserError, CreateUserErrorKind, SignupForm}; use super::{util, CreateUserError, CreateUserErrorKind, SignupForm};
@ -21,94 +20,59 @@ const EMAIL_LEN: RangeInclusive<usize> = 4..=50;
const CHECKOUT_TIMEOUT: i64 = 12 * 3600; const CHECKOUT_TIMEOUT: i64 = 12 * 3600;
lazy_static! { lazy_static! {
static ref SIGNUP_KEY: String = format!("meow-{}", random::<u128>()); static ref ADMIN_TOKEN: String = std::env::var("ADMIN_TOKEN").unwrap();
static ref FORGEJO_URL: String = std::env::var("FORGEJO_URL").unwrap();
static ref STRIPE_TOKEN: String = std::env::var("STRIPE_TOKEN").unwrap();
static ref ANNUAL_LINK: String = std::env::var("ANNUAL_LINK").unwrap();
static ref MONTHLY_LINK: String = std::env::var("MONTHLY_LINK").unwrap();
} }
/// Displays the signup form. /// Displays the signup form.
pub async fn get_signup() -> impl IntoResponse { pub async fn get_signup(_db: State<SqlitePool>) -> impl IntoResponse {
SignupPage { SignupPage {
monthly_link: Some((*MONTHLY_LINK).to_string()),
..Default::default() ..Default::default()
} }
} }
/// Receives the form with the user signup fields filled out. /// Receives the form with the user signup fields filled out.
pub async fn post_signup( pub async fn post_signup(
session: Session, db: State<SqlitePool>,
Form(form): Form<SignupForm>, Form(form): Form<SignupForm>,
) -> Result<impl IntoResponse, CreateUserError> { ) -> Result<impl IntoResponse, CreateUserError> {
let user = validate_signup(&form).await?; let user = validate_signup(&form).await?;
match session.insert(&SIGNUP_KEY, user).await { if create_user(&user) {
Ok(_) => {} log::info!("Created user {user:?}");
Err(e) => { Ok(SignupSuccessPage(user))
log::error!( } else {
"Could not insert validated user form into session, got {}", Err(CreateUserError(CreateUserErrorKind::UnknownEorr))
e
);
return Err(CreateUserErrorKind::UnknownEorr.into());
}
}
match session.save().await {
// TODO: pass in as env var/into a state object that the handlers can read from
Ok(_) => Ok(Redirect::to(
"https://buy.stripe.com/test_eVa6rrb7ygjNbwk000",
)),
Err(e) => {
log::error!("Could not save session, got {}", e);
Err(CreateUserErrorKind::UnknownEorr.into())
}
} }
} }
/// Redirected from Stripe with the receipt of payment. /// Redirected from Stripe with the receipt of payment.
pub async fn payment_success(session: Session, receipt: Option<Path<String>>) -> impl IntoResponse { pub async fn payment_success(
session.load().await.unwrap_or_else(|e| { db: State<SqlitePool>,
log::error!("Could not load the session, got {}", e); receipt: Option<Path<String>>,
}); ) -> impl IntoResponse {
log::debug!("loaded the session");
let user = if let Some(user) = session.get::<User>(&SIGNUP_KEY).await.unwrap_or(None) {
user
} else {
log::warn!("Could not find user in session; got receipt {:?}", receipt);
return CreateUserError(CreateUserErrorKind::NoFormFound).into_response();
};
let receipt = if let Some(Path(receipt)) = receipt { let receipt = if let Some(Path(receipt)) = receipt {
receipt receipt
} else { } else {
log::info!("Got {:?} from the session, but no receipt.", &user);
return CreateUserError(CreateUserErrorKind::BadPayment).into_response(); return CreateUserError(CreateUserErrorKind::BadPayment).into_response();
}; };
if confirm_payment(&receipt) { UserFormPage {
log::info!("Confirmed payment from {}", &receipt); receipt,
} else { ..Default::default()
return CreateUserError(CreateUserErrorKind::BadPayment).into_response();
} }
.into_response()
if create_user(&user) {
log::info!("Created user {user:?}");
} else {
return CreateUserError(CreateUserErrorKind::UnknownEorr).into_response();
}
// TODO: store the receipt into a durable store to prevent re-use after creating
// an account
session.delete().await.unwrap_or_else(|e| {
log::error!("Got error deleting {:?} from session, got {}", &user, e);
});
log::info!("Added {:?}", &user);
SignupSuccessPage(user).into_response()
} }
//-************************************************************************ //-************************************************************************
// helpers // helpers
//-************************************************************************ //-************************************************************************
fn create_user(user: &User) -> bool { fn create_user(user: &User) -> bool {
let token = std::env::var("ADMIN_TOKEN").expect("Could not find $ADMIN_TOKEN in environment."); let token = &*ADMIN_TOKEN;
let url = std::env::var("ADD_USER_ENDPOINT") let url = &*FORGEJO_URL;
.expect("Could not find $ADD_USER_ENDPOINT in environment");
let auth_header = format!("token {token}"); let auth_header = format!("token {token}");
let user: ForgejoUser = user.into(); let user: ForgejoUser = user.into();
let resp = ureq::post(&format!("{url}/api/v1/admin/users")) let resp = ureq::post(&format!("{url}/api/v1/admin/users"))
@ -127,13 +91,13 @@ fn create_user(user: &User) -> bool {
} }
fn confirm_payment(stripe_checkout_session_id: &str) -> bool { fn confirm_payment(stripe_checkout_session_id: &str) -> bool {
let token = std::env::var("STRIPE_TOKEN").expect("Could not find $STRIPE_TOKEN in environment"); let token = &*STRIPE_TOKEN;
let url = format!("https://api.stripe.com/v1/checkout/sessions/{stripe_checkout_session_id}"); let url = format!("https://api.stripe.com/v1/checkout/sessions/{stripe_checkout_session_id}");
let json: serde_json::Value = ureq::get(&url) let json: serde_json::Value = ureq::get(&url)
.set("Authorization", &format!("Bearer {token}")) .set("Authorization", &format!("Bearer {token}"))
.call() .call()
.map_err(|e| { .map_err(|e| {
log::error!("Error confirming payment from Stripe, got {}", e); log::error!("Error confirming payment from Stripe, got {e}");
std::io::Error::new(std::io::ErrorKind::Other, e) std::io::Error::new(std::io::ErrorKind::Other, e)
}) })
.and_then(|resp| resp.into_json()) .and_then(|resp| resp.into_json())
@ -156,6 +120,13 @@ async fn validate_signup(form: &SignupForm) -> Result<User, CreateUserError> {
let username = form.username.trim(); let username = form.username.trim();
let password = form.password.trim(); let password = form.password.trim();
let verify = form.pw_verify.trim(); let verify = form.pw_verify.trim();
let receipt = form.receipt.trim();
if confirm_payment(receipt) {
log::info!("Confirmed payment from {receipt}");
} else {
return Err(CreateUserError(CreateUserErrorKind::BadPayment));
}
let name_len = username.graphemes(true).size_hint().1.unwrap_or(0); let name_len = username.graphemes(true).size_hint().1.unwrap_or(0);
// we are not ascii exclusivists around here // we are not ascii exclusivists around here

View file

@ -41,8 +41,6 @@ pub enum CreateUserErrorKind {
BadEmail, BadEmail,
#[error(desc = "We could not verify your payment")] #[error(desc = "We could not verify your payment")]
BadPayment, BadPayment,
#[error(desc = "We couldn't retrieve your info from this browser session")]
NoFormFound,
} }
#[derive(Debug, Default, Deserialize, PartialEq, Eq)] #[derive(Debug, Default, Deserialize, PartialEq, Eq)]
@ -53,4 +51,5 @@ pub struct SignupForm {
pub email: String, pub email: String,
pub password: String, pub password: String,
pub pw_verify: String, pub pw_verify: String,
pub receipt: String,
} }

View file

@ -8,9 +8,9 @@ use axum::{
routing::{get, MethodRouter}, routing::{get, MethodRouter},
Router, Router,
}; };
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tower_http::services::ServeDir; use tower_http::services::ServeDir;
use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer};
#[macro_use] #[macro_use]
extern crate justerror; extern crate justerror;
@ -28,15 +28,11 @@ async fn main() {
init(); init();
// for javascript and css // for javascript and css
// TODO: figure out how to intern these contents
let assets_dir = std::env::current_dir().unwrap().join("assets"); let assets_dir = std::env::current_dir().unwrap().join("assets");
let assets_svc = ServeDir::new(assets_dir.as_path()); let assets_svc = ServeDir::new(assets_dir.as_path());
// just for signups let pool = db().await;
let session_store = MemoryStore::default();
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(false)
.with_same_site(tower_sessions::cookie::SameSite::Lax)
.with_expiry(Expiry::OnInactivity(time::Duration::hours(2)));
// the core application, defining the routes and handlers // the core application, defining the routes and handlers
let app = Router::new() let app = Router::new()
@ -44,10 +40,15 @@ async fn main() {
.stripped_clone("/signup/", get(get_signup).post(post_signup)) .stripped_clone("/signup/", get(get_signup).post(post_signup))
.stripped_clone("/payment_success/", get(payment_success)) .stripped_clone("/payment_success/", get(payment_success))
.route("/payment_success/:receipt", get(payment_success)) .route("/payment_success/:receipt", get(payment_success))
.layer(session_layer) .with_state(pool.clone())
.into_make_service(); .into_make_service();
let listener = mklistener().await; let listener = mklistener().await;
axum::serve(listener, app).await.unwrap(); axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
pool.close().await;
} }
//-************************************************************************ //-************************************************************************
@ -63,6 +64,16 @@ fn init() {
.init(); .init();
} }
async fn db() -> SqlitePool {
//let dbfile = std::env::var("DATABASE_URL").unwrap();
let opts = SqliteConnectOptions::new()
.foreign_keys(true)
.create_if_missing(true)
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
.optimize_on_close(true, None);
SqlitePoolOptions::new().connect_with(opts).await.unwrap()
}
async fn mklistener() -> TcpListener { async fn mklistener() -> TcpListener {
let ip = let ip =
std::env::var("LISTENING_ADDR").expect("Could not find $LISTENING_ADDR in environment"); std::env::var("LISTENING_ADDR").expect("Could not find $LISTENING_ADDR in environment");
@ -96,3 +107,28 @@ where
.route(path.trim_end_matches('/'), method_router) .route(path.trim_end_matches('/'), method_router)
} }
} }
async fn shutdown_signal() {
use tokio::signal;
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {log::info!("shutting down")},
_ = terminate => {},
}
}

View file

@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize};
use crate::user::User; use crate::user::User;
#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)] #[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)]
#[template(path = "signup.html")] #[template(path = "user_form.html")]
pub struct SignupPage { pub struct UserFormPage {
pub username: String, pub username: String,
pub displayname: Option<String>, pub displayname: Option<String>,
pub email: Option<String>, pub email: Option<String>,
@ -21,3 +21,11 @@ pub struct SignupSuccessPage(pub User);
#[derive(Debug, Clone, Template, Default, Deserialize, Serialize, PartialEq, Eq)] #[derive(Debug, Clone, Template, Default, Deserialize, Serialize, PartialEq, Eq)]
#[template(path = "signup_error.html")] #[template(path = "signup_error.html")]
pub struct SignupErrorPage(pub String); pub struct SignupErrorPage(pub String);
#[derive(Debug, Clone, Template, Default, Deserialize, Serialize, PartialEq, Eq)]
#[template(path = "signup.html")]
pub struct SignupPage {
pub annual_link: Option<String>,
pub monthly_link: Option<String>,
pub invitation: Option<String>,
}

View file

@ -16,7 +16,6 @@
{% block header %}{% endblock %} {% block header %}{% endblock %}
</div> </div>
<div id="content"> <div id="content">
<hr />
{% block content %}{% endblock %} {% block content %}{% endblock %}
</div> </div>
<div id="footer"> <div id="footer">

View file

@ -1,26 +0,0 @@
{% extends "base.html" %}
{% block title %}Welcome to What 2 Watch, Bish{% endblock %}
{% block content %}
<h1>Welcome to What 2 Watch</h1>
{% match user %}
{% when Some with (usr) %}
<p>
Hello, {{ usr.username }}! It's nice to see you. <a href="watches">Let's get watchin'!</a>
</p>
</br>
<p>
<form action="/logout" enctype="application/x-www-form-urlencoded" method="post">
<input type="submit" value="sign out?">
</form>
</p>
{% when None %}
<p>
Heya, why don't you <a href="/login">log in</a> or <a href="/signup">sign up</a>?
</p>
{% endmatch %}
{% endblock %}

View file

@ -1,26 +1,35 @@
{% extends "base.html" %} {% extends "base.html" %}
{% block title %}Welcome, friend, to git.kittenclause.com{% endblock %} {% block title %}Welcome to the Kitten Collective!{% endblock %}
{% block header %} {% endblock %}
{% block content %} {% block content %}
<p> <!-- Show the monthly link if we have it -->
<form action="/signup" enctype="application/x-www-form-urlencoded" method="post">
<input type="hidden" value="{{ self.receipt }}" name="receipt"> {% match monthly_link %}
<label for="username">Username</label> {% when Some with (link) %}
<input type="text" name="username" id="username" minlength="1" maxlength="20" required></br> <div>
<label for="displayname">Displayname (optional)</label> <p><a href="{{ link }}">Just $3/month!</a></p>
<input type="text" name="displayname" id="displayname"></br> </div>
<label for="email">Email</label> {% else %}
<input type="text" name="email"></br> {% endmatch %}
<label for="password">Password</label>
<input type="password" name="password" id="password" required></br> <!-- Show the annual link if we have it -->
<label for="confirm_password">Confirm Password</label> {% match annual_link %}
<input type="password" name="pw_verify" id="pw_verify" required></br> {% when Some with (link) %}
<input type="submit" value="Signup"> <div>
</form> <p><a href="{{ link }}">Just $30/year!</a></p>
</p> </div>
{% else %}
{% endmatch %}
<!-- the invitation is also a URL -->
{% match invitation %}
{% when Some with (link) %}
<div>
<p><a href="{{ link }}">Free, limited account for collaborating with other kittens</a></p>
</div>
{% else %}
{% endmatch %}
{% endblock %} {% endblock %}

View file

@ -1,6 +1,6 @@
{% extends "base.html" %} {% extends "base.html" %}
{% block title %}Dang, Bish{% endblock %} {% block title %}Dang!{% endblock %}
{% block content %} {% block content %}
{% block header %}{% endblock %} {% block header %}{% endblock %}

View file

@ -1,6 +1,6 @@
{% extends "base.html" %} {% extends "base.html" %}
{% block title %}Thanks for Signing Up for What 2 Watch, Bish{% endblock %} {% block title %}Thanks for Signing up for the Kitten Collective!{% endblock %}
{% block content %} {% block content %}
{% block header %}{% endblock %} {% block header %}{% endblock %}
@ -13,7 +13,8 @@
</p> </p>
</div> </div>
<p>Now, head on over to <a href="/login?redirect_to=%2F{{ self.0.username|escape }}">the login page</a> and git going! <p>Now, head on over to <a href="/user/login?redirect_to=%2F{{ self.0.username|escape }}">the login page</a> and git
going!
</p> </p>
{% endblock %} {% endblock %}

26
templates/user_form.html Normal file
View file

@ -0,0 +1,26 @@
{% extends "base.html" %}
{% block title %}Welcome, friend, to git.kittenclause.com{% endblock %}
{% block header %} {% endblock %}
{% block content %}
<p>
<form action="/signup" enctype="application/x-www-form-urlencoded" method="post">
<input type="hidden" value="{{ self.receipt }}" name="receipt">
<label for="username">Username</label>
<input type="text" name="username" id="username" minlength="1" maxlength="20" required></br>
<label for="displayname">Displayname (optional)</label>
<input type="text" name="displayname" id="displayname"></br>
<label for="email">Email</label>
<input type="text" name="email"></br>
<label for="password">Password</label>
<input type="password" name="password" id="password" required></br>
<label for="confirm_password">Confirm Password</label>
<input type="password" name="pw_verify" id="pw_verify" required></br>
<input type="submit" value="Signup">
</form>
</p>
{% endblock %}