use axum::{ http::StatusCode, response::{IntoResponse, Redirect, Response}, Form, }; use serde::{Deserialize, Serialize}; use crate::{ auth::{AuthError, AuthErrorKind, Credentials}, AuthSession, LoginPage, LogoutPage, LogoutSuccessPage, }; //-************************************************************************ // Login error and success types //-************************************************************************ impl IntoResponse for AuthError { fn into_response(self) -> Response { match self.0 { AuthErrorKind::Internal => ( StatusCode::INTERNAL_SERVER_ERROR, "An unknown error occurred; you cursed, brah?", ) .into_response(), AuthErrorKind::Unknown => (StatusCode::OK, "Not successful.").into_response(), } } } // for receiving form submissions #[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, Eq)] pub struct LoginPostForm { pub username: String, pub password: String, pub destination: Option, } impl From for Credentials { fn from(form: LoginPostForm) -> Self { Self { username: form.username, password: form.password, } } } //-************************************************************************ // Login handlers //-************************************************************************ /// Handle login queries #[axum::debug_handler] pub async fn post_login( mut auth: AuthSession, Form(mut login_form): Form, ) -> Result { let dest = login_form.destination.take(); let user = match auth.authenticate(login_form.clone().into()).await { Ok(Some(user)) => user, Ok(None) => return Ok(LoginPage::default().into_response()), Err(_) => return Err(AuthErrorKind::Internal.into()), }; if auth.login(&user).await.is_err() { return Err(AuthErrorKind::Internal.into()); } if let Some(ref next) = dest { Ok(Redirect::to(next).into_response()) } else { Ok(Redirect::to("/").into_response()) } } pub async fn get_login() -> impl IntoResponse { LoginPage::default() } pub async fn get_logout() -> impl IntoResponse { LogoutPage } pub async fn post_logout(mut auth: AuthSession) -> impl IntoResponse { match auth.logout().await { Ok(_) => LogoutSuccessPage.into_response(), Err(e) => { tracing::debug!("{e}"); let e: AuthError = AuthErrorKind::Internal.into(); e.into_response() } } } //-************************************************************************ // tests //-************************************************************************ #[cfg(test)] mod test { use crate::{ get_db_pool, templates::{LoginPage, LogoutPage, LogoutSuccessPage, MainPage}, test_utils::{massage, server_with_pool, FORM_CONTENT_TYPE}, User, }; const LOGIN_FORM: &str = "username=test_user&password=a"; #[test] fn get_login() { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); let db = get_db_pool(); rt.block_on(async { let s = server_with_pool(&db).await; let resp = s.get("/login").await; let body = std::str::from_utf8(resp.as_bytes()).unwrap().to_string(); assert_eq!(body, LoginPage::default().to_string()); }) } #[test] fn post_login_success() { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); let body = massage(LOGIN_FORM); let db = get_db_pool(); rt.block_on(async { let s = server_with_pool(&db).await; let resp = s .post("/login") .expect_failure() .content_type(FORM_CONTENT_TYPE) .bytes(body) .await; assert_eq!(resp.status_code(), 303); }) } #[test] fn post_login_bad_user() { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); let form = "username=test_LOSER&password=aaaa"; let body = massage(form); let db = get_db_pool(); rt.block_on(async { let s = server_with_pool(&db).await; let resp = s .post("/login") .expect_success() .content_type(FORM_CONTENT_TYPE) .bytes(body) .await; assert_eq!(resp.status_code(), 200); }) } #[test] fn post_login_bad_password() { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); let form = "username=test_user&password=bbbb"; let body = massage(form); let db = get_db_pool(); rt.block_on(async { let s = server_with_pool(&db).await; let resp = s .post("/login") .expect_success() .content_type(FORM_CONTENT_TYPE) .bytes(body) .await; assert_eq!(resp.status_code(), 200); }) } #[test] fn get_logout() { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); let db = get_db_pool(); rt.block_on(async { let s = server_with_pool(&db).await; let resp = s.get("/logout").await; let body = std::str::from_utf8(resp.as_bytes()).unwrap().to_string(); assert_eq!(body, LogoutPage.to_string()); }) } #[test] fn post_logout_not_logged_in() { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); let db = get_db_pool(); rt.block_on(async { let s = server_with_pool(&db).await; let resp = s.post("/logout").await; resp.assert_status_ok(); let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let default = LogoutSuccessPage.to_string(); assert_eq!(body, &default); }) } #[test] fn post_logout_logged_in() { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); // log in and prove it let db = get_db_pool(); rt.block_on(async { let s = server_with_pool(&db).await; let body = massage(LOGIN_FORM); let resp = s .post("/login") .expect_failure() .content_type(FORM_CONTENT_TYPE) .bytes(body) .await; assert_eq!(resp.status_code(), 303); let user = User::try_get("test_user", &db).await.unwrap(); let logged_in = MainPage { user }.to_string(); let main_page = s.get("/").await; let body = std::str::from_utf8(main_page.as_bytes()).unwrap(); assert_eq!(&logged_in, body); let resp = s.post("/logout").await; let body = std::str::from_utf8(resp.as_bytes()).unwrap(); let default = LogoutSuccessPage.to_string(); assert_eq!(body, &default); }) } }