use argon2::{ password_hash::{PasswordHash, PasswordVerifier}, Argon2, }; use axum::{ extract::State, http::StatusCode, response::{IntoResponse, Redirect, Response}, Form, }; use sqlx::SqlitePool; use crate::{ templates::{LoginGet, LoginPost, LogoutGet, LogoutPost}, util::form_decode, AuthContext, User, }; //-************************************************************************ // Constants //-************************************************************************ //-************************************************************************ // Login error and success types //-************************************************************************ #[Error] pub struct LoginError(#[from] LoginErrorKind); #[Error] #[non_exhaustive] pub enum LoginErrorKind { Internal, BadPassword, BadUsername, Unknown, } impl IntoResponse for LoginError { fn into_response(self) -> Response { match self.0 { LoginErrorKind::Internal => ( StatusCode::INTERNAL_SERVER_ERROR, "An unknown error occurred; you cursed, brah?", ) .into_response(), LoginErrorKind::Unknown => (StatusCode::OK, "Not successful.").into_response(), _ => (StatusCode::OK, format!("{self}")).into_response(), } } } //-************************************************************************ // Login handlers //-************************************************************************ /// Handle login queries #[axum::debug_handler] pub async fn post_login( mut auth: AuthContext, State(pool): State, Form(login): Form, ) -> Result { let username = form_decode(&login.username, LoginErrorKind::BadUsername)?; let username = username.trim(); let pw = form_decode(&login.password, LoginErrorKind::BadPassword)?; let pw = pw.trim(); let user = User::try_get(username, &pool) .await .map_err(|_| LoginErrorKind::Unknown)?; let verifier = Argon2::default(); let hash = PasswordHash::new(&user.pwhash).map_err(|_| LoginErrorKind::Internal)?; match verifier.verify_password(pw.as_bytes(), &hash) { Ok(_) => { // log them in and set a session cookie auth.login(&user) .await .map_err(|_| LoginErrorKind::Internal)?; Ok(Redirect::to("/")) } _ => Err(LoginErrorKind::BadPassword.into()), } } pub async fn get_login() -> impl IntoResponse { LoginGet::default() } pub async fn get_logout() -> impl IntoResponse { LogoutGet } pub async fn post_logout(mut auth: AuthContext) -> impl IntoResponse { if auth.current_user.is_some() { auth.logout().await; } LogoutPost } #[cfg(test)] mod test { use std::time::Duration; use axum::body::Bytes; use axum_test::TestServer; use crate::{ db, signup::create_user, templates::{LoginGet, LogoutGet, LogoutPost}, }; async fn tserver() -> TestServer { let pool = db::get_pool().await; let secret = [0u8; 64]; tokio::time::sleep(Duration::from_secs(2)).await; let _user = create_user( "test_user", &Some("Test User".to_string()), &Some("mail@email".to_string()), "aaaa".as_bytes(), &pool, ) .await .unwrap(); let r = sqlx::query("select count(*) from witches") .fetch_one(&pool) .await; assert!(r.is_ok()); let app = crate::app(pool, &secret).await.into_make_service(); TestServer::new(app).unwrap() } #[tokio::test] async fn get_login() { let s = tserver().await; let resp = s.get("/login").await; let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); assert_eq!(body, LoginGet::default().to_string()); } #[tokio::test] async fn post_login_success() { let s = tserver().await; let form = "username=test_user&password=aaaa".to_string(); let bytes = form.as_bytes(); let body = Bytes::copy_from_slice(bytes); let resp = s .post("/login") .expect_failure() .content_type("application/x-www-form-urlencoded") .bytes(body) .await; assert_eq!(resp.status_code(), 303); } #[tokio::test] async fn post_login_bad_user() { let s = tserver().await; let form = "username=test_LOSER&password=aaaa".to_string(); let bytes = form.as_bytes(); let body = Bytes::copy_from_slice(bytes); let resp = s .post("/login") .expect_success() .content_type("application/x-www-form-urlencoded") .bytes(body) .await; assert_eq!(resp.status_code(), 200); } #[tokio::test] async fn get_logout() { let s = tserver().await; let resp = s.get("/logout").await; let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); assert_eq!(body, LogoutGet.to_string()); } #[tokio::test] async fn post_logout() { let s = tserver().await; let resp = s.post("/logout").await; resp.assert_status_ok(); let body = std::str::from_utf8(resp.bytes()).unwrap(); let default = LogoutPost.to_string(); assert_eq!(body, &default); } }