simplify auth and login
This commit is contained in:
parent
c133031123
commit
dd5ae09ab8
6 changed files with 67 additions and 65 deletions
15
Cargo.lock
generated
15
Cargo.lock
generated
|
@ -948,8 +948,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
|
checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if 1.0.0",
|
"cfg-if 1.0.0",
|
||||||
|
"js-sys",
|
||||||
"libc",
|
"libc",
|
||||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||||
|
"wasm-bindgen",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1521,6 +1523,18 @@ dependencies = [
|
||||||
"windows-targets 0.48.5",
|
"windows-targets 0.48.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "password-auth"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1a2a4764cc1f8d961d802af27193c6f4f0124bd0e76e8393cf818e18880f0524"
|
||||||
|
dependencies = [
|
||||||
|
"argon2",
|
||||||
|
"getrandom",
|
||||||
|
"password-hash",
|
||||||
|
"rand_core",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "password-hash"
|
name = "password-hash"
|
||||||
version = "0.5.0"
|
version = "0.5.0"
|
||||||
|
@ -2870,6 +2884,7 @@ dependencies = [
|
||||||
"julid-rs",
|
"julid-rs",
|
||||||
"justerror",
|
"justerror",
|
||||||
"optional_optional_user",
|
"optional_optional_user",
|
||||||
|
"password-auth",
|
||||||
"password-hash",
|
"password-hash",
|
||||||
"rand",
|
"rand",
|
||||||
"rand_distr",
|
"rand_distr",
|
||||||
|
|
|
@ -23,6 +23,7 @@ clap = { version = "4", features = ["derive", "env", "unicode", "suggestions", "
|
||||||
http = "1.0.0"
|
http = "1.0.0"
|
||||||
julid-rs = "1"
|
julid-rs = "1"
|
||||||
justerror = "1"
|
justerror = "1"
|
||||||
|
password-auth = "1"
|
||||||
password-hash = { version = "0.5", features = ["std", "getrandom"] }
|
password-hash = { version = "0.5", features = ["std", "getrandom"] }
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
rand_distr = "0.4"
|
rand_distr = "0.4"
|
||||||
|
|
27
src/auth.rs
27
src/auth.rs
|
@ -1,8 +1,7 @@
|
||||||
use argon2::Argon2;
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum_login::{AuthUser, AuthnBackend, UserId};
|
use axum_login::{AuthUser, AuthnBackend, UserId};
|
||||||
use julid::Julid;
|
use julid::Julid;
|
||||||
use password_hash::{PasswordHash, PasswordVerifier};
|
use password_auth::verify_password;
|
||||||
use sqlx::SqlitePool;
|
use sqlx::SqlitePool;
|
||||||
use tower_sessions::{cookie::time::Duration, Expiry, SessionManagerLayer, SqliteStore};
|
use tower_sessions::{cookie::time::Duration, Expiry, SessionManagerLayer, SqliteStore};
|
||||||
|
|
||||||
|
@ -28,7 +27,14 @@ pub struct Credentials {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[Error]
|
#[Error]
|
||||||
pub struct AuthError;
|
pub struct AuthError(#[from] pub AuthErrorKind);
|
||||||
|
|
||||||
|
#[Error]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub enum AuthErrorKind {
|
||||||
|
Internal,
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl AuthnBackend for AuthStore {
|
impl AuthnBackend for AuthStore {
|
||||||
|
@ -45,16 +51,9 @@ impl AuthnBackend for AuthStore {
|
||||||
|
|
||||||
let user = User::try_get(username, &self.0)
|
let user = User::try_get(username, &self.0)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| AuthError)?;
|
.map_err(|_| AuthErrorKind::Internal)?;
|
||||||
let verifier = Argon2::default();
|
|
||||||
let hash = PasswordHash::new(&user.pwhash).map_err(|_| AuthError)?;
|
Ok(user.filter(|user| verify_password(password, &user.pwhash).is_ok()))
|
||||||
Ok(
|
|
||||||
if verifier.verify_password(password.as_bytes(), &hash).is_ok() {
|
|
||||||
Some(user)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
|
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
|
||||||
|
@ -62,7 +61,7 @@ impl AuthnBackend for AuthStore {
|
||||||
.bind(user_id)
|
.bind(user_id)
|
||||||
.fetch_optional(&self.0)
|
.fetch_optional(&self.0)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| AuthError)
|
.map_err(|_| AuthErrorKind::Unknown.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
57
src/login.rs
57
src/login.rs
|
@ -5,34 +5,24 @@ use axum::{
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{auth::Credentials, AuthSession, LoginPage, LogoutPage, LogoutSuccessPage};
|
use crate::{
|
||||||
|
auth::{AuthError, AuthErrorKind, Credentials},
|
||||||
|
AuthSession, LoginPage, LogoutPage, LogoutSuccessPage,
|
||||||
|
};
|
||||||
|
|
||||||
//-************************************************************************
|
//-************************************************************************
|
||||||
// Login error and success types
|
// Login error and success types
|
||||||
//-************************************************************************
|
//-************************************************************************
|
||||||
|
|
||||||
#[Error]
|
impl IntoResponse for AuthError {
|
||||||
pub struct LoginError(#[from] LoginErrorKind);
|
|
||||||
|
|
||||||
#[Error]
|
|
||||||
#[non_exhaustive]
|
|
||||||
pub enum LoginErrorKind {
|
|
||||||
Internal,
|
|
||||||
BadPassword,
|
|
||||||
Unknown,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IntoResponse for LoginError {
|
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
match self.0 {
|
match self.0 {
|
||||||
LoginErrorKind::Internal => (
|
AuthErrorKind::Internal => (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
"An unknown error occurred; you cursed, brah?",
|
"An unknown error occurred; you cursed, brah?",
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response(),
|
||||||
LoginErrorKind::Unknown => (StatusCode::OK, "Not successful.").into_response(),
|
AuthErrorKind::Unknown => (StatusCode::OK, "Not successful.").into_response(),
|
||||||
// we don't say it's a bad password, we just silently fail
|
|
||||||
_ => (StatusCode::OK, format!("{self}")).into_response(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -63,25 +53,22 @@ impl From<LoginPostForm> for Credentials {
|
||||||
pub async fn post_login(
|
pub async fn post_login(
|
||||||
mut auth: AuthSession,
|
mut auth: AuthSession,
|
||||||
Form(mut login_form): Form<LoginPostForm>,
|
Form(mut login_form): Form<LoginPostForm>,
|
||||||
) -> Result<impl IntoResponse, LoginError> {
|
) -> Result<impl IntoResponse, AuthError> {
|
||||||
let dest = login_form.destination.take();
|
let dest = login_form.destination.take();
|
||||||
let user = auth.authenticate(login_form.into()).await.map_err(|e| {
|
let user = match auth.authenticate(login_form.clone().into()).await {
|
||||||
tracing::debug!("{e}");
|
Ok(Some(user)) => user,
|
||||||
LoginErrorKind::Unknown
|
Ok(None) => return Ok(LoginPage::default().into_response()),
|
||||||
})?;
|
Err(_) => return Err(AuthErrorKind::Internal.into()),
|
||||||
|
};
|
||||||
|
|
||||||
match user {
|
if auth.login(&user).await.is_err() {
|
||||||
Some(user) => {
|
return Err(AuthErrorKind::Internal.into());
|
||||||
// log them in and set a session cookie
|
|
||||||
auth.login(&user)
|
|
||||||
.await
|
|
||||||
.map_err(|_| LoginErrorKind::Internal)?;
|
|
||||||
match dest {
|
|
||||||
Some(dest) => Ok(Redirect::to(&dest)),
|
|
||||||
_ => Ok(Redirect::to("/")),
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
_ => Err(LoginErrorKind::BadPassword.into()),
|
if let Some(ref next) = dest {
|
||||||
|
Ok(Redirect::to(next).into_response())
|
||||||
|
} else {
|
||||||
|
Ok(Redirect::to("/").into_response())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,7 +85,7 @@ pub async fn post_logout(mut auth: AuthSession) -> impl IntoResponse {
|
||||||
Ok(_) => LogoutSuccessPage.into_response(),
|
Ok(_) => LogoutSuccessPage.into_response(),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::debug!("{e}");
|
tracing::debug!("{e}");
|
||||||
let e: LoginError = LoginErrorKind::Internal.into();
|
let e: AuthError = AuthErrorKind::Internal.into();
|
||||||
e.into_response()
|
e.into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -259,7 +246,7 @@ mod test {
|
||||||
|
|
||||||
let user = User::try_get("test_user", &db).await.unwrap();
|
let user = User::try_get("test_user", &db).await.unwrap();
|
||||||
|
|
||||||
let logged_in = MainPage { user: Some(user) }.to_string();
|
let logged_in = MainPage { user }.to_string();
|
||||||
|
|
||||||
let main_page = s.get("/").await;
|
let main_page = s.get("/").await;
|
||||||
let body = std::str::from_utf8(main_page.as_bytes()).unwrap();
|
let body = std::str::from_utf8(main_page.as_bytes()).unwrap();
|
||||||
|
|
|
@ -259,7 +259,7 @@ mod test {
|
||||||
assert_eq!(StatusCode::SEE_OTHER, resp.status_code());
|
assert_eq!(StatusCode::SEE_OTHER, resp.status_code());
|
||||||
|
|
||||||
// get the new user from the db
|
// get the new user from the db
|
||||||
let user = User::try_get("good_user", &pool).await.unwrap();
|
let user = User::try_get("good_user", &pool).await.unwrap().unwrap();
|
||||||
let id = user.id;
|
let id = user.id;
|
||||||
|
|
||||||
let path = format!("/signup_success/{id}");
|
let path = format!("/signup_success/{id}");
|
||||||
|
@ -308,7 +308,7 @@ mod test {
|
||||||
|
|
||||||
// no user in db
|
// no user in db
|
||||||
let user = User::try_get("bad_user", &pool).await;
|
let user = User::try_get("bad_user", &pool).await;
|
||||||
assert!(user.is_err());
|
assert!(user.is_ok() && user.unwrap().is_none());
|
||||||
|
|
||||||
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
||||||
let expected = CreateUserError(CreateUserErrorKind::PasswordMismatch).to_string();
|
let expected = CreateUserError(CreateUserErrorKind::PasswordMismatch).to_string();
|
||||||
|
@ -334,7 +334,7 @@ mod test {
|
||||||
|
|
||||||
// no user in db
|
// no user in db
|
||||||
let user = User::try_get("bad_user", &pool).await;
|
let user = User::try_get("bad_user", &pool).await;
|
||||||
assert!(user.is_err());
|
assert!(user.is_ok() && user.unwrap().is_none());
|
||||||
|
|
||||||
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
||||||
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
|
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
|
||||||
|
@ -360,7 +360,7 @@ mod test {
|
||||||
|
|
||||||
// no user in db
|
// no user in db
|
||||||
let user = User::try_get("bad_user", &pool).await;
|
let user = User::try_get("bad_user", &pool).await;
|
||||||
assert!(user.is_err());
|
assert!(user.is_ok() && user.unwrap().is_none());
|
||||||
|
|
||||||
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
||||||
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
|
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
|
||||||
|
@ -371,7 +371,8 @@ mod test {
|
||||||
#[test]
|
#[test]
|
||||||
fn multibyte_password_too_short() {
|
fn multibyte_password_too_short() {
|
||||||
let pw = "🤡";
|
let pw = "🤡";
|
||||||
// min length is 4
|
// min length is 4 distinct graphemes; this is one grapheme that is four bytes,
|
||||||
|
// so it's not valid
|
||||||
assert_eq!(pw.len(), 4);
|
assert_eq!(pw.len(), 4);
|
||||||
|
|
||||||
let pool = get_db_pool();
|
let pool = get_db_pool();
|
||||||
|
@ -379,9 +380,8 @@ mod test {
|
||||||
|
|
||||||
rt.block_on(async {
|
rt.block_on(async {
|
||||||
let server = server_with_pool(&pool).await;
|
let server = server_with_pool(&pool).await;
|
||||||
let form = format!(
|
let form =
|
||||||
"username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}"
|
format!("username=bad_user&displayname=Test+User&password={pw}&pw_verify={pw}");
|
||||||
);
|
|
||||||
let body = massage(&form);
|
let body = massage(&form);
|
||||||
|
|
||||||
let resp = server
|
let resp = server
|
||||||
|
@ -394,7 +394,7 @@ mod test {
|
||||||
|
|
||||||
// no user in db
|
// no user in db
|
||||||
let user = User::try_get("bad_user", &pool).await;
|
let user = User::try_get("bad_user", &pool).await;
|
||||||
assert!(user.is_err());
|
assert!(user.is_ok() && user.unwrap().is_none());
|
||||||
|
|
||||||
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
||||||
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
|
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
|
||||||
|
@ -420,7 +420,7 @@ mod test {
|
||||||
|
|
||||||
// no user in db
|
// no user in db
|
||||||
let user = User::try_get("bad_user", &pool).await;
|
let user = User::try_get("bad_user", &pool).await;
|
||||||
assert!(user.is_err());
|
assert!(user.is_ok() && user.unwrap().is_none());
|
||||||
|
|
||||||
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
||||||
let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string();
|
let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string();
|
||||||
|
@ -446,7 +446,7 @@ mod test {
|
||||||
|
|
||||||
// no user in db
|
// no user in db
|
||||||
let user = User::try_get("bad_user", &pool).await;
|
let user = User::try_get("bad_user", &pool).await;
|
||||||
assert!(user.is_err());
|
assert!(user.is_ok() && user.unwrap().is_none());
|
||||||
|
|
||||||
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
||||||
let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string();
|
let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string();
|
||||||
|
@ -506,7 +506,7 @@ mod test {
|
||||||
|
|
||||||
// no user in db
|
// no user in db
|
||||||
let user = User::try_get("bad_user", &pool).await;
|
let user = User::try_get("bad_user", &pool).await;
|
||||||
assert!(user.is_err());
|
assert!(user.is_ok() && user.unwrap().is_none());
|
||||||
|
|
||||||
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
let body = std::str::from_utf8(resp.as_bytes()).unwrap();
|
||||||
let expected = CreateUserError(CreateUserErrorKind::BadDisplayname).to_string();
|
let expected = CreateUserError(CreateUserErrorKind::BadDisplayname).to_string();
|
||||||
|
@ -522,7 +522,7 @@ mod test {
|
||||||
rt.block_on(async {
|
rt.block_on(async {
|
||||||
let server = server_with_pool(&pool).await;
|
let server = server_with_pool(&pool).await;
|
||||||
|
|
||||||
let path = format!("/signup_success/nope");
|
let path = "/signup_success/nope";
|
||||||
|
|
||||||
let resp = server.get(&path).expect_failure().await;
|
let resp = server.get(&path).expect_failure().await;
|
||||||
assert_eq!(resp.status_code(), StatusCode::SEE_OTHER);
|
assert_eq!(resp.status_code(), StatusCode::SEE_OTHER);
|
||||||
|
|
|
@ -68,10 +68,10 @@ impl Display for User {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl User {
|
impl User {
|
||||||
pub async fn try_get(username: &str, db: &SqlitePool) -> Result<Self, sqlx::Error> {
|
pub async fn try_get(username: &str, db: &SqlitePool) -> Result<Option<Self>, sqlx::Error> {
|
||||||
sqlx::query_as(USERNAME_QUERY)
|
sqlx::query_as(USERNAME_QUERY)
|
||||||
.bind(username)
|
.bind(username)
|
||||||
.fetch_one(db)
|
.fetch_optional(db)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue