diff --git a/Cargo.lock b/Cargo.lock index 6a426ed..b3e0b6a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -948,8 +948,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if 1.0.0", + "js-sys", "libc", "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -1521,6 +1523,18 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "password-auth" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a2a4764cc1f8d961d802af27193c6f4f0124bd0e76e8393cf818e18880f0524" +dependencies = [ + "argon2", + "getrandom", + "password-hash", + "rand_core", +] + [[package]] name = "password-hash" version = "0.5.0" @@ -2870,6 +2884,7 @@ dependencies = [ "julid-rs", "justerror", "optional_optional_user", + "password-auth", "password-hash", "rand", "rand_distr", diff --git a/Cargo.toml b/Cargo.toml index cfe8238..59fa8d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ clap = { version = "4", features = ["derive", "env", "unicode", "suggestions", " http = "1.0.0" julid-rs = "1" justerror = "1" +password-auth = "1" password-hash = { version = "0.5", features = ["std", "getrandom"] } rand = "0.8" rand_distr = "0.4" diff --git a/src/auth.rs b/src/auth.rs index ae11a6e..00de6e8 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,8 +1,7 @@ -use argon2::Argon2; use async_trait::async_trait; use axum_login::{AuthUser, AuthnBackend, UserId}; use julid::Julid; -use password_hash::{PasswordHash, PasswordVerifier}; +use password_auth::verify_password; use sqlx::SqlitePool; use tower_sessions::{cookie::time::Duration, Expiry, SessionManagerLayer, SqliteStore}; @@ -28,7 +27,14 @@ pub struct Credentials { } #[Error] -pub struct AuthError; +pub struct AuthError(#[from] pub AuthErrorKind); + +#[Error] +#[non_exhaustive] +pub enum AuthErrorKind { + Internal, + Unknown, +} #[async_trait] impl AuthnBackend for AuthStore { @@ -45,16 +51,9 @@ impl AuthnBackend for AuthStore { let user = User::try_get(username, &self.0) .await - .map_err(|_| AuthError)?; - let verifier = Argon2::default(); - let hash = PasswordHash::new(&user.pwhash).map_err(|_| AuthError)?; - Ok( - if verifier.verify_password(password.as_bytes(), &hash).is_ok() { - Some(user) - } else { - None - }, - ) + .map_err(|_| AuthErrorKind::Internal)?; + + Ok(user.filter(|user| verify_password(password, &user.pwhash).is_ok())) } async fn get_user(&self, user_id: &UserId) -> Result, Self::Error> { @@ -62,7 +61,7 @@ impl AuthnBackend for AuthStore { .bind(user_id) .fetch_optional(&self.0) .await - .map_err(|_| AuthError) + .map_err(|_| AuthErrorKind::Unknown.into()) } } diff --git a/src/login.rs b/src/login.rs index d1ce45a..0f5d723 100644 --- a/src/login.rs +++ b/src/login.rs @@ -5,34 +5,24 @@ use axum::{ }; use serde::{Deserialize, Serialize}; -use crate::{auth::Credentials, AuthSession, LoginPage, LogoutPage, LogoutSuccessPage}; +use crate::{ + auth::{AuthError, AuthErrorKind, Credentials}, + AuthSession, LoginPage, LogoutPage, LogoutSuccessPage, +}; //-************************************************************************ // Login error and success types //-************************************************************************ -#[Error] -pub struct LoginError(#[from] LoginErrorKind); - -#[Error] -#[non_exhaustive] -pub enum LoginErrorKind { - Internal, - BadPassword, - Unknown, -} - -impl IntoResponse for LoginError { +impl IntoResponse for AuthError { fn into_response(self) -> Response { match self.0 { - LoginErrorKind::Internal => ( + AuthErrorKind::Internal => ( StatusCode::INTERNAL_SERVER_ERROR, "An unknown error occurred; you cursed, brah?", ) .into_response(), - LoginErrorKind::Unknown => (StatusCode::OK, "Not successful.").into_response(), - // we don't say it's a bad password, we just silently fail - _ => (StatusCode::OK, format!("{self}")).into_response(), + AuthErrorKind::Unknown => (StatusCode::OK, "Not successful.").into_response(), } } } @@ -63,25 +53,22 @@ impl From for Credentials { pub async fn post_login( mut auth: AuthSession, Form(mut login_form): Form, -) -> Result { +) -> Result { let dest = login_form.destination.take(); - let user = auth.authenticate(login_form.into()).await.map_err(|e| { - tracing::debug!("{e}"); - LoginErrorKind::Unknown - })?; + let user = match auth.authenticate(login_form.clone().into()).await { + Ok(Some(user)) => user, + Ok(None) => return Ok(LoginPage::default().into_response()), + Err(_) => return Err(AuthErrorKind::Internal.into()), + }; - match user { - Some(user) => { - // log them in and set a session cookie - auth.login(&user) - .await - .map_err(|_| LoginErrorKind::Internal)?; - match dest { - Some(dest) => Ok(Redirect::to(&dest)), - _ => Ok(Redirect::to("/")), - } - } - _ => Err(LoginErrorKind::BadPassword.into()), + if auth.login(&user).await.is_err() { + return Err(AuthErrorKind::Internal.into()); + } + + if let Some(ref next) = dest { + Ok(Redirect::to(next).into_response()) + } else { + Ok(Redirect::to("/").into_response()) } } @@ -98,7 +85,7 @@ pub async fn post_logout(mut auth: AuthSession) -> impl IntoResponse { Ok(_) => LogoutSuccessPage.into_response(), Err(e) => { tracing::debug!("{e}"); - let e: LoginError = LoginErrorKind::Internal.into(); + let e: AuthError = AuthErrorKind::Internal.into(); e.into_response() } } @@ -259,7 +246,7 @@ mod test { let user = User::try_get("test_user", &db).await.unwrap(); - let logged_in = MainPage { user: Some(user) }.to_string(); + let logged_in = MainPage { user }.to_string(); let main_page = s.get("/").await; let body = std::str::from_utf8(main_page.as_bytes()).unwrap(); diff --git a/src/signup.rs b/src/signup.rs index 27ae94c..869cde1 100644 --- a/src/signup.rs +++ b/src/signup.rs @@ -259,7 +259,7 @@ mod test { assert_eq!(StatusCode::SEE_OTHER, resp.status_code()); // get the new user from the db - let user = User::try_get("good_user", &pool).await.unwrap(); + let user = User::try_get("good_user", &pool).await.unwrap().unwrap(); let id = user.id; let path = format!("/signup_success/{id}"); @@ -308,7 +308,7 @@ mod test { // no user in db let user = User::try_get("bad_user", &pool).await; - assert!(user.is_err()); + assert!(user.is_ok() && user.unwrap().is_none()); let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let expected = CreateUserError(CreateUserErrorKind::PasswordMismatch).to_string(); @@ -334,7 +334,7 @@ mod test { // no user in db let user = User::try_get("bad_user", &pool).await; - assert!(user.is_err()); + assert!(user.is_ok() && user.unwrap().is_none()); let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); @@ -360,7 +360,7 @@ mod test { // no user in db let user = User::try_get("bad_user", &pool).await; - assert!(user.is_err()); + assert!(user.is_ok() && user.unwrap().is_none()); let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); @@ -371,7 +371,8 @@ mod test { #[test] fn multibyte_password_too_short() { let pw = "🤡"; - // min length is 4 + // min length is 4 distinct graphemes; this is one grapheme that is four bytes, + // so it's not valid assert_eq!(pw.len(), 4); let pool = get_db_pool(); @@ -379,9 +380,8 @@ mod test { rt.block_on(async { let server = server_with_pool(&pool).await; - let form = format!( - "username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}" - ); + let form = + format!("username=bad_user&displayname=Test+User&password={pw}&pw_verify={pw}"); let body = massage(&form); let resp = server @@ -394,7 +394,7 @@ mod test { // no user in db let user = User::try_get("bad_user", &pool).await; - assert!(user.is_err()); + assert!(user.is_ok() && user.unwrap().is_none()); let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); @@ -420,7 +420,7 @@ mod test { // no user in db let user = User::try_get("bad_user", &pool).await; - assert!(user.is_err()); + assert!(user.is_ok() && user.unwrap().is_none()); let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string(); @@ -446,7 +446,7 @@ mod test { // no user in db let user = User::try_get("bad_user", &pool).await; - assert!(user.is_err()); + assert!(user.is_ok() && user.unwrap().is_none()); let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string(); @@ -506,7 +506,7 @@ mod test { // no user in db let user = User::try_get("bad_user", &pool).await; - assert!(user.is_err()); + assert!(user.is_ok() && user.unwrap().is_none()); let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let expected = CreateUserError(CreateUserErrorKind::BadDisplayname).to_string(); @@ -522,7 +522,7 @@ mod test { rt.block_on(async { let server = server_with_pool(&pool).await; - let path = format!("/signup_success/nope"); + let path = "/signup_success/nope"; let resp = server.get(&path).expect_failure().await; assert_eq!(resp.status_code(), StatusCode::SEE_OTHER); diff --git a/src/users.rs b/src/users.rs index 85c6569..d53a479 100644 --- a/src/users.rs +++ b/src/users.rs @@ -68,10 +68,10 @@ impl Display for User { } impl User { - pub async fn try_get(username: &str, db: &SqlitePool) -> Result { + pub async fn try_get(username: &str, db: &SqlitePool) -> Result, sqlx::Error> { sqlx::query_as(USERNAME_QUERY) .bind(username) - .fetch_one(db) + .fetch_optional(db) .await }