diff --git a/src/auth.rs b/src/auth.rs index a2df43b..a22f51b 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,13 +1,9 @@ +use argon2::Argon2; use async_trait::async_trait; - use axum_login::{AuthnBackend, UserId}; - +use password_hash::{PasswordHash, PasswordVerifier}; use sqlx::SqlitePool; - -use tower_sessions::{ - cookie::time::Duration, Expiry, SessionManagerLayer, - SqliteStore, -}; +use tower_sessions::{cookie::time::Duration, Expiry, SessionManagerLayer, SqliteStore}; use crate::User; @@ -38,23 +34,48 @@ impl std::ops::Deref for AuthStore { } } +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct Credentials { + pub username: String, + pub password: String, +} + +#[Error] +pub struct AuthError; + #[async_trait] impl AuthnBackend for AuthStore { type User = User; - type Credentials = String; + type Credentials = Credentials; - type Error = sqlx::Error; + type Error = AuthError; async fn authenticate( &self, - _creds: Self::Credentials, + creds: Self::Credentials, ) -> Result, Self::Error> { - todo!() + let username = creds.username.trim(); + let password = creds.password.trim(); + let user = User::try_get(username, &self) + .await + .map_err(|_| AuthError)?; + + let verifier = Argon2::default(); + let hash = PasswordHash::new(&user.pwhash).map_err(|_| AuthError)?; + match verifier.verify_password(password.as_bytes(), &hash) { + Ok(_) => Ok(Some(user)), + _ => Ok(None), + } } - async fn get_user(&self, _user_id: &UserId) -> Result, Self::Error> { - todo!() + async fn get_user(&self, user_id: &UserId) -> Result, Self::Error> { + let user = sqlx::query_as("select * from users where id = ?") + .bind(user_id) + .fetch_optional(&self.0) + .await + .map_err(|_| AuthError)?; + Ok(user) } } diff --git a/src/login.rs b/src/login.rs index af8ca9e..6a1dc6f 100644 --- a/src/login.rs +++ b/src/login.rs @@ -1,17 +1,13 @@ -use argon2::{ - password_hash::{PasswordHash, PasswordVerifier}, - Argon2, -}; + use axum::{ - extract::State, http::StatusCode, response::{IntoResponse, Redirect, Response}, Form, }; use serde::{Deserialize, Serialize}; -use sqlx::SqlitePool; -use crate::{AuthSession, LoginPage, LogoutPage, LogoutSuccessPage, User}; + +use crate::{auth::Credentials, AuthSession, LoginPage, LogoutPage, LogoutSuccessPage}; //-************************************************************************ // Constants @@ -52,6 +48,16 @@ impl IntoResponse for LoginError { pub struct LoginPostForm { pub username: String, pub password: String, + pub destination: Option, +} + +impl From for Credentials { + fn from(value: LoginPostForm) -> Self { + Self { + username: value.username, + password: value.password, + } + } } //-************************************************************************ @@ -62,29 +68,19 @@ pub struct LoginPostForm { #[axum::debug_handler] pub async fn post_login( mut auth: AuthSession, - State(pool): State, Form(login): Form, ) -> Result { - let username = &login.username; - let username = username.trim(); - - let pw = &login.password; - let pw = pw.trim(); - - let user = User::try_get(username, &pool).await.map_err(|e| { + let user = auth.authenticate(login.into()).await.map_err(|e| { tracing::debug!("{e}"); LoginErrorKind::Unknown })?; - let verifier = Argon2::default(); - let hash = PasswordHash::new(&user.pwhash).map_err(|_| LoginErrorKind::Internal)?; - match verifier.verify_password(pw.as_bytes(), &hash) { - Ok(_) => { + match user { + Some(user) => { // log them in and set a session cookie auth.login(&user) .await .map_err(|_| LoginErrorKind::Internal)?; - Ok(Redirect::to("/")) } _ => Err(LoginErrorKind::BadPassword.into()), diff --git a/src/main.rs b/src/main.rs index 9e0ae43..4f0a871 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,5 @@ use std::net::SocketAddr; -use tokio::signal; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use what2watch::get_db_pool; @@ -28,32 +27,7 @@ fn main() { let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); axum::serve(listener, app).await.unwrap(); - //.with_graceful_shutdown(shutdown_signal()) // removed in 0.7 because of upstream dep changes - }); rt.block_on(pool.close()); } - -async fn shutdown_signal() { - let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; - - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; - - tokio::select! { - _ = ctrl_c => {}, - _ = terminate => {}, - } - - println!(" signal received, starting graceful shutdown"); -}