use std::fmt::Display; use argon2::{ password_hash::{rand_core::OsRng, PasswordHasher, SaltString}, Argon2, }; use askama::Template; use axum::{ extract::{Form, Path, State}, http::StatusCode, response::{IntoResponse, Response}, }; use axum_login::{secrecy::SecretVec, AuthUser, SqliteStore}; use sqlx::{query_as, SqlitePool}; use unicode_segmentation::UnicodeSegmentation; use uuid::Uuid; use crate::templates::{CreateUser, LoginGet}; const CREATE_QUERY: &str = "insert into witches (id, username, displayname, email, pwhash) values ($1, $2, $3, $4, $5)"; const ID_QUERY: &str = "select * from witches where id = $1"; // const PW_QUERY: &str = "select pwhash from witches where id = $1"; #[derive(Debug, Default, Clone, PartialEq, Eq, sqlx::FromRow)] pub struct User { pub id: Uuid, pub username: String, pub displayname: Option, pub email: Option, pub last_seen: Option, pwhash: String, } impl Display for User { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let uname = &self.username; let dname = if let Some(ref n) = self.displayname { n } else { "" }; let email = if let Some(ref e) = self.email { e } else { "" }; write!(f, "Username: {uname}\nDisplayname: {dname}\nEmail: {email}") } } pub type AuthContext = axum_login::extractors::AuthContext>; impl AuthUser for User { fn get_id(&self) -> Uuid { self.id } fn get_password_hash(&self) -> SecretVec { SecretVec::new(self.pwhash.as_bytes().to_vec()) } } //-------------------------------------------------------------------------- // Result types for user creation //-------------------------------------------------------------------------- #[derive(Debug, Clone, Template)] #[template(path = "signup_success.html")] pub struct CreateUserSuccess(User); #[Error(desc = "Could not create user.")] #[non_exhaustive] pub struct CreateUserError(#[from] CreateUserErrorKind); impl IntoResponse for CreateUserError { fn into_response(self) -> askama_axum::Response { match self.0 { CreateUserErrorKind::UnknownDBError => { (StatusCode::INTERNAL_SERVER_ERROR, format!("{self}")).into_response() } _ => (StatusCode::BAD_REQUEST, format!("{self}")).into_response(), } } } #[Error] #[non_exhaustive] pub enum CreateUserErrorKind { AlreadyExists, #[error(desc = "Usernames must be between 1 and 20 non-whitespace characters long")] BadUsername, PasswordMismatch, #[error(desc = "Password must have at least 4 and at most 50 characters")] BadPassword, #[error(desc = "Display name must be less than 100 characters long")] BadDisplayname, BadEmail, MissingFields, UnknownDBError, } //-------------------------------------------------------------------------- // User creation route handlers //-------------------------------------------------------------------------- /// Get Handler: displays the form to create a user pub async fn get_create_user() -> CreateUser { CreateUser::default() } /// Post Handler: validates form values and calls the actual, private user /// creation function #[axum::debug_handler] pub async fn post_create_user( State(pool): State, Form(signup): Form, ) -> Result { let username = &signup.username; let displayname = &signup.displayname; let email = &signup.email; let password = &signup.password; let verify = &signup.pw_verify; let username = username.trim(); let name_len = username.graphemes(true).size_hint().1.unwrap(); // we are not ascii exclusivists around here if !(1..=20).contains(&name_len) { return Err(CreateUserErrorKind::BadUsername.into()); } if password != verify { return Err(CreateUserErrorKind::PasswordMismatch.into()); } let password = urlencoding::decode(password) .map_err(|_| CreateUserErrorKind::BadPassword)? .to_string(); let password = password.trim(); let password = password.as_bytes(); if !(4..=50).contains(&password.len()) { return Err(CreateUserErrorKind::BadPassword.into()); } let displayname = if let Some(dn) = displayname { let dn = urlencoding::decode(dn) .map_err(|_| CreateUserErrorKind::BadDisplayname)? .to_string() .trim() .to_string(); if dn.graphemes(true).size_hint().1.unwrap() > 100 { return Err(CreateUserErrorKind::BadDisplayname.into()); } Some(dn) } else { None }; let displayname = &displayname; // TODO(2023-05-17): validate email let email = if let Some(email) = email { let email = urlencoding::decode(email) .map_err(|_| CreateUserErrorKind::BadEmail)? .to_string(); Some(email) } else { None }; let email = &email; let user = create_user(username, displayname, email, password, &pool).await?; tracing::debug!("created {user:?}"); let id = user.id.as_simple().to_string(); let location = format!("/signup_success/{id}"); let resp = axum::response::Redirect::temporary(&location).into_response(); Ok(resp) } /// Get handler for successful signup pub async fn handle_signup_success( Path(id): Path, State(pool): State, ) -> Response { let user: User = { let id = id.trim(); let id = Uuid::try_parse(id).unwrap_or_default(); query_as(ID_QUERY) .bind(id) .fetch_one(&pool) .await .unwrap_or_default() }; let mut resp = CreateUserSuccess(user.clone()).into_response(); if user.username.is_empty() || id.is_empty() { // redirect to front page if we got here without a valid witch ID *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT; resp.headers_mut().insert("Location", "/".parse().unwrap()); } resp } //-------------------------------------------------------------------------- // Login error and success types //-------------------------------------------------------------------------- #[Error] pub struct LoginError(#[from] LoginErrorKind); #[Error] #[non_exhaustive] pub enum LoginErrorKind { BadPassword, Unknown, } impl IntoResponse for LoginError { fn into_response(self) -> Response { match self.0 { LoginErrorKind::Unknown => ( StatusCode::INTERNAL_SERVER_ERROR, "An unknown error occurred; you cursed, brah?", ) .into_response(), _ => (StatusCode::BAD_REQUEST, format!("{self}")).into_response(), } } } //-------------------------------------------------------------------------- // Login handlers //-------------------------------------------------------------------------- /// Handle login queries #[axum::debug_handler] pub async fn post_login( mut auth: AuthContext, State(pool): State, ) -> Result<(), LoginError> { Err(LoginErrorKind::Unknown.into()) } pub async fn get_login() -> impl IntoResponse { LoginGet::default() } //------------------------------------------------------------------------- // private fns //------------------------------------------------------------------------- async fn create_user( username: &str, displayname: &Option, email: &Option, password: &[u8], pool: &SqlitePool, ) -> Result { // Argon2 with default params (Argon2id v19) let argon2 = Argon2::default(); let salt = SaltString::generate(&mut OsRng); let pwhash = argon2 .hash_password(password, &salt) .unwrap() // safe to unwrap, we know the salt is valid .to_string(); let id = Uuid::new_v4(); let res = sqlx::query(CREATE_QUERY) .bind(id) .bind(username) .bind(displayname) .bind(email) .bind(&pwhash) .execute(pool) .await; match res { Ok(_) => { let user = User { id, username: username.to_string(), displayname: displayname.to_owned(), email: email.to_owned(), last_seen: None, pwhash, }; Ok(user) } Err(sqlx::Error::Database(db)) => { if let Some(exit) = db.code() { let exit = exit.parse().unwrap_or(0u32); // https://www.sqlite.org/rescode.html codes for unique constraint violations: if exit == 2067u32 || exit == 1555 { Err(CreateUserErrorKind::AlreadyExists.into()) } else { Err(CreateUserErrorKind::UnknownDBError.into()) } } else { Err(CreateUserErrorKind::UnknownDBError.into()) } } _ => Err(CreateUserErrorKind::UnknownDBError.into()), } }