use std::{
    error::Error,
    fmt::{Debug, Display},
    net::SocketAddr,
    ops::Range,
};

use askama::Template;
use axum::{
    extract::{Form, Path},
    http::StatusCode,
    response::{IntoResponse, Redirect, Response},
    routing::get,
    Router,
};
use serde::{Deserialize, Serialize};
use tower_sessions::{MemoryStore, Session, SessionManagerLayer};
use unicode_segmentation::UnicodeSegmentation;

#[macro_use]
extern crate justerror;

/// Session key under which the pending (not yet paid-for) signup is stored.
const SIGNUP_KEY: &str = "meow";

// NOTE(review): not referenced anywhere in the visible code; remove if truly unused.
#[derive(Default, Deserialize, Serialize)]
struct Counter(usize);

/// Displays the signup form.
async fn get_signup() -> impl IntoResponse {
    SignupPage::default()
}

/// Receives the form with the user signup fields filled out.
///
/// All fields are trimmed before validation. Lengths are measured in user-perceived
/// characters (extended grapheme clusters), not bytes.
///
/// On success the validated [`User`] is stored in the session under [`SIGNUP_KEY`] and
/// the client is redirected to the Stripe payment link.
///
/// # Errors
///
/// Returns a [`CreateUserError`] wrapping:
/// - `BadUsername` if the username is not 1–20 characters long,
/// - `PasswordMismatch` if the password and its confirmation differ,
/// - `BadPassword` if the password is not 4–50 characters long,
/// - `BadDisplayname` / `BadEmail` if a provided optional field is out of range.
///
/// # Panics
///
/// Panics if the session layer fails to store the pending signup.
async fn post_signup(
    session: Session,
    Form(form): Form<SignupForm>,
) -> Result<impl IntoResponse, CreateUserError> {
    let username = form.username.trim();
    let password = form.password.trim();
    let verify = form.pw_verify.trim();

    // We are not ASCII exclusivists around here, so count graphemes rather than bytes.
    // Note: `graphemes(..).size_hint().1` is only an upper bound (the remaining byte
    // length), not a count; it would reject e.g. a 7-character CJK name as "21 long".
    let name_len = username.graphemes(true).count();
    if !(1..=20).contains(&name_len) {
        return Err(CreateUserErrorKind::BadUsername.into());
    }

    if password != verify {
        return Err(CreateUserErrorKind::PasswordMismatch.into());
    }

    let pwlen = password.graphemes(true).count();
    if !(4..=50).contains(&pwlen) {
        return Err(CreateUserErrorKind::BadPassword.into());
    }

    // clean up the optionals
    let displayname = validate_optional_length(
        &form.displayname,
        0..100,
        CreateUserErrorKind::BadDisplayname,
    )?;
    let email = validate_optional_length(&form.email, 5..30, CreateUserErrorKind::BadEmail)?;

    let user = User {
        username: username.to_string(),
        displayname,
        email,
        password: password.to_string(),
        pw_verify: verify.to_string(),
    };

    // NOTE(review): the raw password is serialized into the session store as-is;
    // consider hashing it before it leaves this handler.
    session
        .insert(SIGNUP_KEY, user)
        .await
        .expect("session layer failed to store the pending signup");

    Ok(Redirect::to(
        "https://buy.stripe.com/test_eVa6rrb7ygjNbwk000",
    ))
}

/// Called from Stripe with the receipt of payment.
async fn signup_success(session: Session, receipt: Option>) -> impl IntoResponse { let user: User = session.get(SIGNUP_KEY).await.unwrap().unwrap_or_default(); if user != User::default() { SignupSuccessPage(user).into_response() } else { SignupErrorPage("who you?".to_string()).into_response() } } #[tokio::main] async fn main() { let session_store = MemoryStore::default(); let session_layer = SessionManagerLayer::new(session_store).with_secure(false); let app = Router::new() //.nest_service("/assets", assets_svc) .route("/signup", get(get_signup).post(post_signup)) .route("/signup_success/:receipt", get(signup_success)) .layer(session_layer); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); axum::serve(listener, app.into_make_service()) .await .unwrap(); } #[Error(desc = "Could not create user.")] #[non_exhaustive] pub struct CreateUserError(#[from] CreateUserErrorKind); impl IntoResponse for CreateUserError { fn into_response(self) -> Response { (StatusCode::FORBIDDEN, format!("{:?}", self.0)).into_response() } } #[Error] #[non_exhaustive] pub enum CreateUserErrorKind { AlreadyExists, #[error(desc = "Usernames must be between 1 and 20 characters long")] BadUsername, PasswordMismatch, #[error(desc = "Password must have at least 4 and at most 50 characters")] BadPassword, #[error(desc = "Display name must be less than 100 characters long")] BadDisplayname, BadEmail, BadPayment, } #[derive(Debug, Default, Deserialize, PartialEq, Eq)] pub struct SignupForm { pub username: String, #[serde(default, deserialize_with = "empty_string_as_none")] pub displayname: Option, #[serde(default, deserialize_with = "empty_string_as_none")] pub email: Option, pub password: String, pub pw_verify: String, } #[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)] #[template(path = "signup.html")] pub struct SignupPage { pub username: String, pub displayname: Option, pub email: Option, pub password: 
String, pub pw_verify: String, } #[derive(Debug, Clone, Template, Default, Deserialize, Serialize, PartialEq, Eq)] #[template(path = "signup_success.html")] pub struct SignupSuccessPage(pub User); #[derive(Debug, Clone, Template, Default, Deserialize, Serialize, PartialEq, Eq)] #[template(path = "signup_error.html")] pub struct SignupErrorPage(pub String); #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct User { pub username: String, pub displayname: Option, pub email: Option, pub password: String, pub pw_verify: String, } impl Debug for User { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let pw_check = if self.password == self.pw_verify { "password matched" } else { "PASSWORD MISMATCH" }; f.debug_struct("User") .field("username", &self.username) .field("displayname", &self.displayname) .field("email", &self.email) .field("pw-check", &pw_check) .finish() } } impl Display for User { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let uname = &self.username; let dname = if let Some(ref n) = self.displayname { n } else { "" }; let email = if let Some(ref e) = self.email { e } else { "" }; write!(f, "Username: {uname}\nDisplayname: {dname}\nEmail: {email}") } } pub(crate) fn empty_string_as_none<'de, D, T>(de: D) -> Result, D::Error> where D: serde::Deserializer<'de>, T: std::str::FromStr, T::Err: std::fmt::Display, { let opt = as serde::Deserialize>::deserialize(de)?; match opt.as_deref() { None | Some("") => Ok(None), Some(s) => std::str::FromStr::from_str(s) .map_err(serde::de::Error::custom) .map(Some), } } pub(crate) fn validate_optional_length( opt: &Option, len_range: Range, err: E, ) -> Result, E> { if let Some(opt) = opt { let opt = opt.trim(); let len = opt.graphemes(true).size_hint().1.unwrap(); if !len_range.contains(&len) { Err(err) } else { Ok(Some(opt.to_string())) } } else { Ok(None) } }