use std::{error::Error, fmt::Debug, ops::RangeInclusive};

use axum::{
    extract::{Form, Path},
    http::StatusCode,
    response::{IntoResponse, Redirect, Response},
};
use serde::Deserialize;
use tower_sessions::Session;
use unicode_segmentation::UnicodeSegmentation;

use crate::{templates::*, User};

/// Session key under which a validated signup is stored by `post_signup` and
/// read back by `signup_success`.
const SIGNUP_KEY: &str = "meow";

// Inclusive length bounds for the signup fields, checked after trimming
// surrounding whitespace. Keep the `desc` strings on `CreateUserErrorKind` in
// sync when changing these.
const PASSWORD_LEN: RangeInclusive<usize> = 4..=100;
const USERNAME_LEN: RangeInclusive<usize> = 1..=50;
const DISPLAYNAME_LEN: RangeInclusive<usize> = 0..=100;
const EMAIL_LEN: RangeInclusive<usize> = 4..=50;

/// Error returned when a signup submission is rejected.
///
/// Wraps a [`CreateUserErrorKind`] naming the specific reason.
#[Error(desc = "Could not create user.")]
#[non_exhaustive]
pub struct CreateUserError(#[from] CreateUserErrorKind);

impl IntoResponse for CreateUserError {
    /// Renders the rejection as `403 Forbidden` whose body is the `Debug`
    /// rendering of the underlying [`CreateUserErrorKind`].
    fn into_response(self) -> Response {
        // NOTE(review): these are validation failures; 400/422 may fit better
        // than 403 — confirm what the frontend expects before changing.
        (StatusCode::FORBIDDEN, format!("{:?}", self.0)).into_response()
    }
}

/// The specific reason a signup was rejected.
#[Error]
#[non_exhaustive]
pub enum CreateUserErrorKind {
    /// Presumably: the requested user already exists (not raised in this file).
    AlreadyExists,
    /// Username length is outside `USERNAME_LEN`.
    #[error(desc = "Usernames must be between 1 and 50 characters long")]
    BadUsername,
    /// `password` and `pw_verify` differ.
    PasswordMismatch,
    /// Password length is outside `PASSWORD_LEN`.
    #[error(desc = "Password must have at least 4 and at most 100 characters")]
    BadPassword,
    /// Display name length is outside `DISPLAYNAME_LEN` (which includes 100).
    #[error(desc = "Display name must be at most 100 characters long")]
    BadDisplayname,
    /// Email length is outside `EMAIL_LEN`.
    BadEmail,
    /// Presumably: payment could not be confirmed (not raised in this file).
    BadPayment,
}

/// Fields submitted by the signup form.
#[derive(Debug, Default, Deserialize, PartialEq, Eq)]
pub struct SignupForm {
    pub username: String,
    /// Optional display name; a missing or empty value deserializes to `None`.
    #[serde(default, deserialize_with = "empty_string_as_none")]
    pub displayname: Option<String>,
    pub email: String,
    pub password: String,
    /// Password confirmation; must match `password` after trimming.
    pub pw_verify: String,
}

/// Displays the signup form.
pub async fn get_signup() -> impl IntoResponse {
    SignupPage {
        ..Default::default()
    }
}

/// Receives the form with the user signup fields filled out.
pub async fn post_signup( session: Session, Form(form): Form<SignupForm>, ) -> Result<impl IntoResponse, CreateUserError> { let user = validate_signup(&form).await?; session.insert(SIGNUP_KEY, user).await.unwrap(); Ok(Redirect::to( "https://buy.stripe.com/test_eVa6rrb7ygjNbwk000", )) } pub async fn get_edit_signup( session: Session, receipt: Option<Path<String>>, ) -> Result<impl IntoResponse, CreateUserError> { Ok(()) } pub async fn post_edit_signup( session: Session, Form(form): Form<SignupForm>, ) -> Result<impl IntoResponse, CreateUserError> { Ok(()) } /// Called from Stripe with the receipt of payment. pub async fn signup_success(session: Session, receipt: Option<Path<String>>) -> impl IntoResponse { let user: User = session.get(SIGNUP_KEY).await.unwrap().unwrap_or_default(); if user == User::default() { return SignupErrorPage("who you?".to_string()).into_response(); } // TODO: check Stripe for the receipt, verify it's legit SignupSuccessPage(user).into_response() } //-************************************************************************ // helpers //-************************************************************************ async fn validate_signup(form: &SignupForm) -> Result<User, CreateUserError> { let username = form.username.trim(); let password = form.password.trim(); let verify = form.pw_verify.trim(); let name_len = username.graphemes(true).size_hint().1.unwrap(); // we are not ascii exclusivists around here if !USERNAME_LEN.contains(&name_len) { return Err(CreateUserErrorKind::BadUsername.into()); } if password != verify { return Err(CreateUserErrorKind::PasswordMismatch.into()); } let pwlen = password.graphemes(true).size_hint().1.unwrap_or(0); if !PASSWORD_LEN.contains(&pwlen) { return Err(CreateUserErrorKind::BadPassword.into()); } // clean up the optionals let displayname = validate_optional_length( &form.displayname, DISPLAYNAME_LEN, CreateUserErrorKind::BadDisplayname, )?; let email = validate_length(&form.email, EMAIL_LEN, 
CreateUserErrorKind::BadEmail)?; let user = User { username: username.to_string(), displayname, email, password: password.to_string(), pw_verify: verify.to_string(), }; Ok(user) } pub(crate) fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error> where D: serde::Deserializer<'de>, T: std::str::FromStr, T::Err: std::fmt::Display, { let opt = <Option<String> as serde::Deserialize>::deserialize(de)?; match opt.as_deref() { None | Some("") => Ok(None), Some(s) => std::str::FromStr::from_str(s) .map_err(serde::de::Error::custom) .map(Some), } } pub(crate) fn validate_optional_length<E: Error>( opt: &Option<String>, len_range: RangeInclusive<usize>, err: E, ) -> Result<Option<String>, E> { if let Some(opt) = opt { let opt = opt.trim(); let len = opt.graphemes(true).size_hint().1.unwrap(); if !len_range.contains(&len) { Err(err) } else { Ok(Some(opt.to_string())) } } else { Ok(None) } } pub(crate) fn validate_length<E: Error>( thing: &str, len_range: RangeInclusive<usize>, err: E, ) -> Result<String, E> { let thing = thing.trim(); let len = thing.graphemes(true).size_hint().1.unwrap(); if !len_range.contains(&len) { Err(err) } else { Ok(thing.to_string()) } }