Compare commits

..

No commits in common. "5aea86d8af3a31f3d443a3394a6e60d0f87ded77" and "8e25651cfab2db51fda408abf9b56abfc06ccc31" have entirely different histories.

15 changed files with 286 additions and 1026 deletions

2
.gitignore vendored
View file

@ -1,4 +1,2 @@
/target /target
.env .env
*.db
*.db-*

978
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -17,9 +17,10 @@ passwords = { version = "3", default-features = false }
rand = { version = "0.8", default-features = false, features = ["getrandom"] } rand = { version = "0.8", default-features = false, features = ["getrandom"] }
serde = { version = "1", default-features = false, features = ["derive"] } serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = { version = "1", default-features = false } serde_json = { version = "1", default-features = false }
sqlx = { version = "0.7", default-features = false, features = ["runtime-tokio", "sqlite", "tls-none", "migrate", "macros"] }
thiserror = { version = "1" } thiserror = { version = "1" }
tokio = { version = "1", default-features = false, features = ["rt-multi-thread", "signal"] } time = { version = "0.3", default-features = false }
tokio = { version = "1", default-features = false, features = ["rt-multi-thread"] }
tower-http = { version = "0.5", default-features = false, features = ["fs"] } tower-http = { version = "0.5", default-features = false, features = ["fs"] }
tower-sessions = { version = "0.10", default-features = false, features = ["axum-core", "memory-store"] }
unicode-segmentation = { version = "1", default-features = false } unicode-segmentation = { version = "1", default-features = false }
ureq = { version = "2", default-features = false, features = ["json", "tls"] } ureq = { version = "2", default-features = false, features = ["json", "tls"] }

View file

@ -1,10 +0,0 @@
drop table if exists customers;
drop index if exists customers_username_dex;
drop index if exists customers_email_dex;
drop index if exists customers_invitation_dex;
drop trigger if exists update_last_updated_customers;
drop table if exists invitations;
drop index if exists invitations_owner_dex;
drop trigger if exists update_updated_at_invitations;

View file

@ -1,40 +0,0 @@
create table if not exists customers (
id integer primary key,
username text not null unique,
receipt text not null unique,
billing_email text,
invitation id,
created_at int not null default (unixepoch()),
updated_at int not null default (unixepoch()),
foreign key (invitation) references invitations (id)
);
create index if not exists customers_username_dex on customers (lower(username));
create index if not exists customers_email_dex on customers (lower(billing_email));
create index if not exists customers_invitation_dex on customers (invitation); -- does this need to be created? it's already a foreign key
create trigger if not exists update_last_updated_customers
after update on customers
when OLD.updated_at = NEW.updated_at or OLD.updated_at is null
BEGIN
update customers set updated_at = (select unixepoch()) where id=NEW.id;
END;
create table if not exists invitations (
id integer primary key,
owner integer not null,
remaining integer not null default 1,
expires_at integer,
created_at integer not null default (unixepoch()),
updated_at integer not null default (unixepoch()),
foreign key (owner) references customers (id)
);
create index if not exists invitations_owner_dex on invitations (owner);
create trigger if not exists update_updated_at_invitations
after update on invitations
when OLD.updated_at = NEW.updated_at or OLD.updated_at is null
BEGIN
update invitations set updated_at = (select unixepoch()) where id=NEW.id;
END;

View file

@ -1,10 +1,11 @@
use std::ops::RangeInclusive; use std::ops::RangeInclusive;
use axum::{ use axum::{
extract::{Form, Path, State}, extract::{Form, Path},
response::IntoResponse, response::{IntoResponse, Redirect},
}; };
use sqlx::SqlitePool; use rand::random;
use tower_sessions::Session;
use unicode_segmentation::UnicodeSegmentation; use unicode_segmentation::UnicodeSegmentation;
use super::{util, CreateUserError, CreateUserErrorKind, SignupForm}; use super::{util, CreateUserError, CreateUserErrorKind, SignupForm};
@ -20,59 +21,94 @@ const EMAIL_LEN: RangeInclusive<usize> = 4..=50;
const CHECKOUT_TIMEOUT: i64 = 12 * 3600; const CHECKOUT_TIMEOUT: i64 = 12 * 3600;
lazy_static! { lazy_static! {
static ref ADMIN_TOKEN: String = std::env::var("ADMIN_TOKEN").unwrap(); static ref SIGNUP_KEY: String = format!("meow-{}", random::<u128>());
static ref FORGEJO_URL: String = std::env::var("FORGEJO_URL").unwrap();
static ref STRIPE_TOKEN: String = std::env::var("STRIPE_TOKEN").unwrap();
static ref ANNUAL_LINK: String = std::env::var("ANNUAL_LINK").unwrap();
static ref MONTHLY_LINK: String = std::env::var("MONTHLY_LINK").unwrap();
} }
/// Displays the signup form. /// Displays the signup form.
pub async fn get_signup(_db: State<SqlitePool>) -> impl IntoResponse { pub async fn get_signup() -> impl IntoResponse {
SignupPage { SignupPage {
monthly_link: Some((*MONTHLY_LINK).to_string()),
..Default::default() ..Default::default()
} }
} }
/// Receives the form with the user signup fields filled out. /// Receives the form with the user signup fields filled out.
pub async fn post_signup( pub async fn post_signup(
db: State<SqlitePool>, session: Session,
Form(form): Form<SignupForm>, Form(form): Form<SignupForm>,
) -> Result<impl IntoResponse, CreateUserError> { ) -> Result<impl IntoResponse, CreateUserError> {
let user = validate_signup(&form).await?; let user = validate_signup(&form).await?;
if create_user(&user) { match session.insert(&SIGNUP_KEY, user).await {
log::info!("Created user {user:?}"); Ok(_) => {}
Ok(SignupSuccessPage(user)) Err(e) => {
} else { log::error!(
Err(CreateUserError(CreateUserErrorKind::UnknownEorr)) "Could not insert validated user form into session, got {}",
e
);
return Err(CreateUserErrorKind::UnknownEorr.into());
}
}
match session.save().await {
// TODO: pass in as env var/into a state object that the handlers can read from
Ok(_) => Ok(Redirect::to(
"https://buy.stripe.com/test_eVa6rrb7ygjNbwk000",
)),
Err(e) => {
log::error!("Could not save session, got {}", e);
Err(CreateUserErrorKind::UnknownEorr.into())
}
} }
} }
/// Redirected from Stripe with the receipt of payment. /// Redirected from Stripe with the receipt of payment.
pub async fn payment_success( pub async fn payment_success(session: Session, receipt: Option<Path<String>>) -> impl IntoResponse {
db: State<SqlitePool>, session.load().await.unwrap_or_else(|e| {
receipt: Option<Path<String>>, log::error!("Could not load the session, got {}", e);
) -> impl IntoResponse { });
log::debug!("loaded the session");
let user = if let Some(user) = session.get::<User>(&SIGNUP_KEY).await.unwrap_or(None) {
user
} else {
log::warn!("Could not find user in session; got receipt {:?}", receipt);
return CreateUserError(CreateUserErrorKind::NoFormFound).into_response();
};
let receipt = if let Some(Path(receipt)) = receipt { let receipt = if let Some(Path(receipt)) = receipt {
receipt receipt
} else { } else {
log::info!("Got {:?} from the session, but no receipt.", &user);
return CreateUserError(CreateUserErrorKind::BadPayment).into_response(); return CreateUserError(CreateUserErrorKind::BadPayment).into_response();
}; };
UserFormPage { if confirm_payment(&receipt) {
receipt, log::info!("Confirmed payment from {}", &receipt);
..Default::default() } else {
return CreateUserError(CreateUserErrorKind::BadPayment).into_response();
} }
.into_response()
if create_user(&user) {
log::info!("Created user {user:?}");
} else {
return CreateUserError(CreateUserErrorKind::UnknownEorr).into_response();
}
// TODO: store the receipt into a durable store to prevent re-use after creating
// an account
session.delete().await.unwrap_or_else(|e| {
log::error!("Got error deleting {:?} from session, got {}", &user, e);
});
log::info!("Added {:?}", &user);
SignupSuccessPage(user).into_response()
} }
//-************************************************************************ //-************************************************************************
// helpers // helpers
//-************************************************************************ //-************************************************************************
fn create_user(user: &User) -> bool { fn create_user(user: &User) -> bool {
let token = &*ADMIN_TOKEN; let token = std::env::var("ADMIN_TOKEN").expect("Could not find $ADMIN_TOKEN in environment.");
let url = &*FORGEJO_URL; let url = std::env::var("ADD_USER_ENDPOINT")
.expect("Could not find $ADD_USER_ENDPOINT in environment");
let auth_header = format!("token {token}"); let auth_header = format!("token {token}");
let user: ForgejoUser = user.into(); let user: ForgejoUser = user.into();
let resp = ureq::post(&format!("{url}/api/v1/admin/users")) let resp = ureq::post(&format!("{url}/api/v1/admin/users"))
@ -91,13 +127,13 @@ fn create_user(user: &User) -> bool {
} }
fn confirm_payment(stripe_checkout_session_id: &str) -> bool { fn confirm_payment(stripe_checkout_session_id: &str) -> bool {
let token = &*STRIPE_TOKEN; let token = std::env::var("STRIPE_TOKEN").expect("Could not find $STRIPE_TOKEN in environment");
let url = format!("https://api.stripe.com/v1/checkout/sessions/{stripe_checkout_session_id}"); let url = format!("https://api.stripe.com/v1/checkout/sessions/{stripe_checkout_session_id}");
let json: serde_json::Value = ureq::get(&url) let json: serde_json::Value = ureq::get(&url)
.set("Authorization", &format!("Bearer {token}")) .set("Authorization", &format!("Bearer {token}"))
.call() .call()
.map_err(|e| { .map_err(|e| {
log::error!("Error confirming payment from Stripe, got {e}"); log::error!("Error confirming payment from Stripe, got {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e) std::io::Error::new(std::io::ErrorKind::Other, e)
}) })
.and_then(|resp| resp.into_json()) .and_then(|resp| resp.into_json())
@ -120,13 +156,6 @@ async fn validate_signup(form: &SignupForm) -> Result<User, CreateUserError> {
let username = form.username.trim(); let username = form.username.trim();
let password = form.password.trim(); let password = form.password.trim();
let verify = form.pw_verify.trim(); let verify = form.pw_verify.trim();
let receipt = form.receipt.trim();
if confirm_payment(receipt) {
log::info!("Confirmed payment from {receipt}");
} else {
return Err(CreateUserError(CreateUserErrorKind::BadPayment));
}
let name_len = username.graphemes(true).size_hint().1.unwrap_or(0); let name_len = username.graphemes(true).size_hint().1.unwrap_or(0);
// we are not ascii exclusivists around here // we are not ascii exclusivists around here

View file

@ -41,6 +41,8 @@ pub enum CreateUserErrorKind {
BadEmail, BadEmail,
#[error(desc = "We could not verify your payment")] #[error(desc = "We could not verify your payment")]
BadPayment, BadPayment,
#[error(desc = "We couldn't retrieve your info from this browser session")]
NoFormFound,
} }
#[derive(Debug, Default, Deserialize, PartialEq, Eq)] #[derive(Debug, Default, Deserialize, PartialEq, Eq)]
@ -51,5 +53,4 @@ pub struct SignupForm {
pub email: String, pub email: String,
pub password: String, pub password: String,
pub pw_verify: String, pub pw_verify: String,
pub receipt: String,
} }

View file

@ -8,9 +8,9 @@ use axum::{
routing::{get, MethodRouter}, routing::{get, MethodRouter},
Router, Router,
}; };
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tower_http::services::ServeDir; use tower_http::services::ServeDir;
use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer};
#[macro_use] #[macro_use]
extern crate justerror; extern crate justerror;
@ -28,13 +28,15 @@ async fn main() {
init(); init();
// for javascript and css // for javascript and css
// TODO: figure out how to intern these contents
let assets_dir = std::env::current_dir().unwrap().join("assets"); let assets_dir = std::env::current_dir().unwrap().join("assets");
let assets_svc = ServeDir::new(assets_dir.as_path()); let assets_svc = ServeDir::new(assets_dir.as_path());
let pool = db().await; // just for signups
let session_store = MemoryStore::default();
sqlx::migrate!().run(&pool).await.unwrap(); let session_layer = SessionManagerLayer::new(session_store)
.with_secure(false)
.with_same_site(tower_sessions::cookie::SameSite::Lax)
.with_expiry(Expiry::OnInactivity(time::Duration::hours(2)));
// the core application, defining the routes and handlers // the core application, defining the routes and handlers
let app = Router::new() let app = Router::new()
@ -42,15 +44,10 @@ async fn main() {
.stripped_clone("/signup/", get(get_signup).post(post_signup)) .stripped_clone("/signup/", get(get_signup).post(post_signup))
.stripped_clone("/payment_success/", get(payment_success)) .stripped_clone("/payment_success/", get(payment_success))
.route("/payment_success/:receipt", get(payment_success)) .route("/payment_success/:receipt", get(payment_success))
.with_state(pool.clone()) .layer(session_layer)
.into_make_service(); .into_make_service();
let listener = mklistener().await; let listener = mklistener().await;
axum::serve(listener, app) axum::serve(listener, app).await.unwrap();
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
pool.close().await;
} }
//-************************************************************************ //-************************************************************************
@ -66,16 +63,6 @@ fn init() {
.init(); .init();
} }
async fn db() -> SqlitePool {
//let dbfile = std::env::var("DATABASE_URL").unwrap();
let opts = SqliteConnectOptions::new()
.foreign_keys(true)
.create_if_missing(true)
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
.optimize_on_close(true, None);
SqlitePoolOptions::new().connect_with(opts).await.unwrap()
}
async fn mklistener() -> TcpListener { async fn mklistener() -> TcpListener {
let ip = let ip =
std::env::var("LISTENING_ADDR").expect("Could not find $LISTENING_ADDR in environment"); std::env::var("LISTENING_ADDR").expect("Could not find $LISTENING_ADDR in environment");
@ -109,28 +96,3 @@ where
.route(path.trim_end_matches('/'), method_router) .route(path.trim_end_matches('/'), method_router)
} }
} }
async fn shutdown_signal() {
use tokio::signal;
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {log::info!("shutting down")},
_ = terminate => {},
}
}

View file

@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize};
use crate::user::User; use crate::user::User;
#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)] #[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)]
#[template(path = "user_form.html")] #[template(path = "signup.html")]
pub struct UserFormPage { pub struct SignupPage {
pub username: String, pub username: String,
pub displayname: Option<String>, pub displayname: Option<String>,
pub email: Option<String>, pub email: Option<String>,
@ -21,11 +21,3 @@ pub struct SignupSuccessPage(pub User);
#[derive(Debug, Clone, Template, Default, Deserialize, Serialize, PartialEq, Eq)] #[derive(Debug, Clone, Template, Default, Deserialize, Serialize, PartialEq, Eq)]
#[template(path = "signup_error.html")] #[template(path = "signup_error.html")]
pub struct SignupErrorPage(pub String); pub struct SignupErrorPage(pub String);
#[derive(Debug, Clone, Template, Default, Deserialize, Serialize, PartialEq, Eq)]
#[template(path = "signup.html")]
pub struct SignupPage {
pub annual_link: Option<String>,
pub monthly_link: Option<String>,
pub invitation: Option<String>,
}

View file

@ -16,6 +16,7 @@
{% block header %}{% endblock %} {% block header %}{% endblock %}
</div> </div>
<div id="content"> <div id="content">
<hr />
{% block content %}{% endblock %} {% block content %}{% endblock %}
</div> </div>
<div id="footer"> <div id="footer">

26
templates/index.html Normal file
View file

@ -0,0 +1,26 @@
{% extends "base.html" %}
{% block title %}Welcome to What 2 Watch, Bish{% endblock %}
{% block content %}
<h1>Welcome to What 2 Watch</h1>
{% match user %}
{% when Some with (usr) %}
<p>
Hello, {{ usr.username }}! It's nice to see you. <a href="watches">Let's get watchin'!</a>
</p>
</br>
<p>
<form action="/logout" enctype="application/x-www-form-urlencoded" method="post">
<input type="submit" value="sign out?">
</form>
</p>
{% when None %}
<p>
Heya, why don't you <a href="/login">log in</a> or <a href="/signup">sign up</a>?
</p>
{% endmatch %}
{% endblock %}

View file

@ -1,35 +1,26 @@
{% extends "base.html" %} {% extends "base.html" %}
{% block title %}Welcome to the Kitten Collective!{% endblock %} {% block title %}Welcome, friend, to git.kittenclause.com{% endblock %}
{% block header %} {% endblock %}
{% block content %} {% block content %}
<!-- Show the monthly link if we have it --> <p>
<form action="/signup" enctype="application/x-www-form-urlencoded" method="post">
{% match monthly_link %} <input type="hidden" value="{{ self.receipt }}" name="receipt">
{% when Some with (link) %} <label for="username">Username</label>
<div> <input type="text" name="username" id="username" minlength="1" maxlength="20" required></br>
<p><a href="{{ link }}">Just $3/month!</a></p> <label for="displayname">Displayname (optional)</label>
</div> <input type="text" name="displayname" id="displayname"></br>
{% else %} <label for="email">Email</label>
{% endmatch %} <input type="text" name="email"></br>
<label for="password">Password</label>
<!-- Show the annual link if we have it --> <input type="password" name="password" id="password" required></br>
{% match annual_link %} <label for="confirm_password">Confirm Password</label>
{% when Some with (link) %} <input type="password" name="pw_verify" id="pw_verify" required></br>
<div> <input type="submit" value="Signup">
<p><a href="{{ link }}">Just $30/year!</a></p> </form>
</div> </p>
{% else %}
{% endmatch %}
<!-- the invitation is also a URL -->
{% match invitation %}
{% when Some with (link) %}
<div>
<p><a href="{{ link }}">Free, limited account for collaborating with other kittens</a></p>
</div>
{% else %}
{% endmatch %}
{% endblock %} {% endblock %}

View file

@ -1,6 +1,6 @@
{% extends "base.html" %} {% extends "base.html" %}
{% block title %}Dang!{% endblock %} {% block title %}Dang, Bish{% endblock %}
{% block content %} {% block content %}
{% block header %}{% endblock %} {% block header %}{% endblock %}

View file

@ -1,6 +1,6 @@
{% extends "base.html" %} {% extends "base.html" %}
{% block title %}Thanks for Signing up for the Kitten Collective!{% endblock %} {% block title %}Thanks for Signing Up for What 2 Watch, Bish{% endblock %}
{% block content %} {% block content %}
{% block header %}{% endblock %} {% block header %}{% endblock %}
@ -13,8 +13,7 @@
</p> </p>
</div> </div>
<p>Now, head on over to <a href="/user/login?redirect_to=%2F{{ self.0.username|escape }}">the login page</a> and git <p>Now, head on over to <a href="/login?redirect_to=%2F{{ self.0.username|escape }}">the login page</a> and git going!
going!
</p> </p>
{% endblock %} {% endblock %}

View file

@ -1,26 +0,0 @@
{% extends "base.html" %}
{% block title %}Welcome, friend, to git.kittenclause.com{% endblock %}
{% block header %} {% endblock %}
{% block content %}
<p>
<form action="/signup" enctype="application/x-www-form-urlencoded" method="post">
<input type="hidden" value="{{ self.receipt }}" name="receipt">
<label for="username">Username</label>
<input type="text" name="username" id="username" minlength="1" maxlength="20" required></br>
<label for="displayname">Displayname (optional)</label>
<input type="text" name="displayname" id="displayname"></br>
<label for="email">Email</label>
<input type="text" name="email"></br>
<label for="password">Password</label>
<input type="password" name="password" id="password" required></br>
<label for="confirm_password">Confirm Password</label>
<input type="password" name="pw_verify" id="pw_verify" required></br>
<input type="submit" value="Signup">
</form>
</p>
{% endblock %}