simplify auth and login

This commit is contained in:
Joe Ardent 2023-12-24 16:21:55 -08:00
parent c133031123
commit dd5ae09ab8
6 changed files with 67 additions and 65 deletions

15
Cargo.lock generated
View file

@ -948,8 +948,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
dependencies = [ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
"js-sys",
"libc", "libc",
"wasi 0.11.0+wasi-snapshot-preview1", "wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
] ]
[[package]] [[package]]
@ -1521,6 +1523,18 @@ dependencies = [
"windows-targets 0.48.5", "windows-targets 0.48.5",
] ]
[[package]]
name = "password-auth"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a2a4764cc1f8d961d802af27193c6f4f0124bd0e76e8393cf818e18880f0524"
dependencies = [
"argon2",
"getrandom",
"password-hash",
"rand_core",
]
[[package]] [[package]]
name = "password-hash" name = "password-hash"
version = "0.5.0" version = "0.5.0"
@ -2870,6 +2884,7 @@ dependencies = [
"julid-rs", "julid-rs",
"justerror", "justerror",
"optional_optional_user", "optional_optional_user",
"password-auth",
"password-hash", "password-hash",
"rand", "rand",
"rand_distr", "rand_distr",

View file

@ -23,6 +23,7 @@ clap = { version = "4", features = ["derive", "env", "unicode", "suggestions", "
http = "1.0.0" http = "1.0.0"
julid-rs = "1" julid-rs = "1"
justerror = "1" justerror = "1"
password-auth = "1"
password-hash = { version = "0.5", features = ["std", "getrandom"] } password-hash = { version = "0.5", features = ["std", "getrandom"] }
rand = "0.8" rand = "0.8"
rand_distr = "0.4" rand_distr = "0.4"

View file

@ -1,8 +1,7 @@
use argon2::Argon2;
use async_trait::async_trait; use async_trait::async_trait;
use axum_login::{AuthUser, AuthnBackend, UserId}; use axum_login::{AuthUser, AuthnBackend, UserId};
use julid::Julid; use julid::Julid;
use password_hash::{PasswordHash, PasswordVerifier}; use password_auth::verify_password;
use sqlx::SqlitePool; use sqlx::SqlitePool;
use tower_sessions::{cookie::time::Duration, Expiry, SessionManagerLayer, SqliteStore}; use tower_sessions::{cookie::time::Duration, Expiry, SessionManagerLayer, SqliteStore};
@ -28,7 +27,14 @@ pub struct Credentials {
} }
#[Error] #[Error]
pub struct AuthError; pub struct AuthError(#[from] pub AuthErrorKind);
#[Error]
#[non_exhaustive]
pub enum AuthErrorKind {
Internal,
Unknown,
}
#[async_trait] #[async_trait]
impl AuthnBackend for AuthStore { impl AuthnBackend for AuthStore {
@ -45,16 +51,9 @@ impl AuthnBackend for AuthStore {
let user = User::try_get(username, &self.0) let user = User::try_get(username, &self.0)
.await .await
.map_err(|_| AuthError)?; .map_err(|_| AuthErrorKind::Internal)?;
let verifier = Argon2::default();
let hash = PasswordHash::new(&user.pwhash).map_err(|_| AuthError)?; Ok(user.filter(|user| verify_password(password, &user.pwhash).is_ok()))
Ok(
if verifier.verify_password(password.as_bytes(), &hash).is_ok() {
Some(user)
} else {
None
},
)
} }
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> { async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
@ -62,7 +61,7 @@ impl AuthnBackend for AuthStore {
.bind(user_id) .bind(user_id)
.fetch_optional(&self.0) .fetch_optional(&self.0)
.await .await
.map_err(|_| AuthError) .map_err(|_| AuthErrorKind::Unknown.into())
} }
} }

View file

@ -5,34 +5,24 @@ use axum::{
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{auth::Credentials, AuthSession, LoginPage, LogoutPage, LogoutSuccessPage}; use crate::{
auth::{AuthError, AuthErrorKind, Credentials},
AuthSession, LoginPage, LogoutPage, LogoutSuccessPage,
};
//-************************************************************************ //-************************************************************************
// Login error and success types // Login error and success types
//-************************************************************************ //-************************************************************************
#[Error] impl IntoResponse for AuthError {
pub struct LoginError(#[from] LoginErrorKind);
#[Error]
#[non_exhaustive]
pub enum LoginErrorKind {
Internal,
BadPassword,
Unknown,
}
impl IntoResponse for LoginError {
fn into_response(self) -> Response { fn into_response(self) -> Response {
match self.0 { match self.0 {
LoginErrorKind::Internal => ( AuthErrorKind::Internal => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
"An unknown error occurred; you cursed, brah?", "An unknown error occurred; you cursed, brah?",
) )
.into_response(), .into_response(),
LoginErrorKind::Unknown => (StatusCode::OK, "Not successful.").into_response(), AuthErrorKind::Unknown => (StatusCode::OK, "Not successful.").into_response(),
// we don't say it's a bad password, we just silently fail
_ => (StatusCode::OK, format!("{self}")).into_response(),
} }
} }
} }
@ -63,25 +53,22 @@ impl From<LoginPostForm> for Credentials {
pub async fn post_login( pub async fn post_login(
mut auth: AuthSession, mut auth: AuthSession,
Form(mut login_form): Form<LoginPostForm>, Form(mut login_form): Form<LoginPostForm>,
) -> Result<impl IntoResponse, LoginError> { ) -> Result<impl IntoResponse, AuthError> {
let dest = login_form.destination.take(); let dest = login_form.destination.take();
let user = auth.authenticate(login_form.into()).await.map_err(|e| { let user = match auth.authenticate(login_form.clone().into()).await {
tracing::debug!("{e}"); Ok(Some(user)) => user,
LoginErrorKind::Unknown Ok(None) => return Ok(LoginPage::default().into_response()),
})?; Err(_) => return Err(AuthErrorKind::Internal.into()),
};
match user { if auth.login(&user).await.is_err() {
Some(user) => { return Err(AuthErrorKind::Internal.into());
// log them in and set a session cookie
auth.login(&user)
.await
.map_err(|_| LoginErrorKind::Internal)?;
match dest {
Some(dest) => Ok(Redirect::to(&dest)),
_ => Ok(Redirect::to("/")),
} }
}
_ => Err(LoginErrorKind::BadPassword.into()), if let Some(ref next) = dest {
Ok(Redirect::to(next).into_response())
} else {
Ok(Redirect::to("/").into_response())
} }
} }
@ -98,7 +85,7 @@ pub async fn post_logout(mut auth: AuthSession) -> impl IntoResponse {
Ok(_) => LogoutSuccessPage.into_response(), Ok(_) => LogoutSuccessPage.into_response(),
Err(e) => { Err(e) => {
tracing::debug!("{e}"); tracing::debug!("{e}");
let e: LoginError = LoginErrorKind::Internal.into(); let e: AuthError = AuthErrorKind::Internal.into();
e.into_response() e.into_response()
} }
} }
@ -259,7 +246,7 @@ mod test {
let user = User::try_get("test_user", &db).await.unwrap(); let user = User::try_get("test_user", &db).await.unwrap();
let logged_in = MainPage { user: Some(user) }.to_string(); let logged_in = MainPage { user }.to_string();
let main_page = s.get("/").await; let main_page = s.get("/").await;
let body = std::str::from_utf8(main_page.as_bytes()).unwrap(); let body = std::str::from_utf8(main_page.as_bytes()).unwrap();

View file

@ -259,7 +259,7 @@ mod test {
assert_eq!(StatusCode::SEE_OTHER, resp.status_code()); assert_eq!(StatusCode::SEE_OTHER, resp.status_code());
// get the new user from the db // get the new user from the db
let user = User::try_get("good_user", &pool).await.unwrap(); let user = User::try_get("good_user", &pool).await.unwrap().unwrap();
let id = user.id; let id = user.id;
let path = format!("/signup_success/{id}"); let path = format!("/signup_success/{id}");
@ -308,7 +308,7 @@ mod test {
// no user in db // no user in db
let user = User::try_get("bad_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_ok() && user.unwrap().is_none());
let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let body = std::str::from_utf8(resp.as_bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::PasswordMismatch).to_string(); let expected = CreateUserError(CreateUserErrorKind::PasswordMismatch).to_string();
@ -334,7 +334,7 @@ mod test {
// no user in db // no user in db
let user = User::try_get("bad_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_ok() && user.unwrap().is_none());
let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let body = std::str::from_utf8(resp.as_bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
@ -360,7 +360,7 @@ mod test {
// no user in db // no user in db
let user = User::try_get("bad_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_ok() && user.unwrap().is_none());
let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let body = std::str::from_utf8(resp.as_bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
@ -371,7 +371,8 @@ mod test {
#[test] #[test]
fn multibyte_password_too_short() { fn multibyte_password_too_short() {
let pw = "🤡"; let pw = "🤡";
// min length is 4 // min length is 4 distinct graphemes; this is one grapheme that is four bytes,
// so it's not valid
assert_eq!(pw.len(), 4); assert_eq!(pw.len(), 4);
let pool = get_db_pool(); let pool = get_db_pool();
@ -379,9 +380,8 @@ mod test {
rt.block_on(async { rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let form = format!( let form =
"username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}" format!("username=bad_user&displayname=Test+User&password={pw}&pw_verify={pw}");
);
let body = massage(&form); let body = massage(&form);
let resp = server let resp = server
@ -394,7 +394,7 @@ mod test {
// no user in db // no user in db
let user = User::try_get("bad_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_ok() && user.unwrap().is_none());
let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let body = std::str::from_utf8(resp.as_bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
@ -420,7 +420,7 @@ mod test {
// no user in db // no user in db
let user = User::try_get("bad_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_ok() && user.unwrap().is_none());
let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let body = std::str::from_utf8(resp.as_bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string();
@ -446,7 +446,7 @@ mod test {
// no user in db // no user in db
let user = User::try_get("bad_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_ok() && user.unwrap().is_none());
let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let body = std::str::from_utf8(resp.as_bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string();
@ -506,7 +506,7 @@ mod test {
// no user in db // no user in db
let user = User::try_get("bad_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_ok() && user.unwrap().is_none());
let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let body = std::str::from_utf8(resp.as_bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadDisplayname).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadDisplayname).to_string();
@ -522,7 +522,7 @@ mod test {
rt.block_on(async { rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let path = format!("/signup_success/nope"); let path = "/signup_success/nope";
let resp = server.get(&path).expect_failure().await; let resp = server.get(&path).expect_failure().await;
assert_eq!(resp.status_code(), StatusCode::SEE_OTHER); assert_eq!(resp.status_code(), StatusCode::SEE_OTHER);

View file

@ -68,10 +68,10 @@ impl Display for User {
} }
impl User { impl User {
pub async fn try_get(username: &str, db: &SqlitePool) -> Result<Self, sqlx::Error> { pub async fn try_get(username: &str, db: &SqlitePool) -> Result<Option<Self>, sqlx::Error> {
sqlx::query_as(USERNAME_QUERY) sqlx::query_as(USERNAME_QUERY)
.bind(username) .bind(username)
.fetch_one(db) .fetch_optional(db)
.await .await
} }