diff --git a/src/db.rs b/src/db.rs index 0571db2..8e078c9 100644 --- a/src/db.rs +++ b/src/db.rs @@ -126,14 +126,18 @@ pub async fn auth_layer( //-************************************************************************ #[cfg(test)] mod tests { + use tokio::runtime::Runtime; - #[tokio::test] - async fn it_migrates_the_db() { + #[test] + fn it_migrates_the_db() { + let rt = Runtime::new().unwrap(); let db = super::get_db_pool(); - let r = sqlx::query("select count(*) from users") - .fetch_one(&db) - .await; - assert!(r.is_ok()); + rt.block_on(async { + let r = sqlx::query("select count(*) from users") + .fetch_one(&db) + .await; + assert!(r.is_ok()); + }); } } @@ -382,223 +386,172 @@ mod session_store { use super::*; - fn test_store() -> SqliteSessionStore { - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); - - let store = rt.block_on(async { - let store = SqliteSessionStore::new("sqlite::memory:") - .await - .expect("building a sqlite :memory: SqliteSessionStore"); - store - .migrate() - .await - .expect("migrating a brand new :memory: SqliteSessionStore"); - store - }); - dbg!("got the store"); + async fn test_store() -> SqliteSessionStore { + let store = SqliteSessionStore::new("sqlite::memory:") + .await + .expect("building a sqlite :memory: SqliteSessionStore"); + store + .migrate() + .await + .expect("migrating a brand new :memory: SqliteSessionStore"); store } - #[test] - fn creating_a_new_session_with_no_expiry() -> Result { - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); - - let store = test_store(); + #[tokio::test] + async fn creating_a_new_session_with_no_expiry() -> Result { + let store = test_store().await; let mut session = Session::new(); - dbg!("new session"); session.insert("key", "value")?; let cloned = session.clone(); - rt.block_on(async { - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(session).await?.unwrap(); - let (id, expires, serialized, count): (String, Option, String, i64) = - sqlx::query_as("select id, expires, session, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; + let (id, expires, serialized, count): (String, Option, String, i64) = + sqlx::query_as("select id, expires, session, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; - assert_eq!(1, count); - assert_eq!(id, cloned.id()); - assert_eq!(expires, None); + assert_eq!(1, count); + assert_eq!(id, cloned.id()); + assert_eq!(expires, None); - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); + let deserialized_session: Session = serde_json::from_str(&serialized)?; + assert_eq!(cloned.id(), deserialized_session.id()); + assert_eq!("value", &deserialized_session.get::("key").unwrap()); - let loaded_session = store.load_session(cookie_value).await?.unwrap(); - assert_eq!(cloned.id(), loaded_session.id()); - assert_eq!("value", &loaded_session.get::("key").unwrap()); + let loaded_session = store.load_session(cookie_value).await?.unwrap(); + assert_eq!(cloned.id(), loaded_session.id()); + assert_eq!("value", &loaded_session.get::("key").unwrap()); - assert!(!loaded_session.is_expired()); - Ok(()) - }) + assert!(!loaded_session.is_expired()); + Ok(()) } - #[test] - fn updating_a_session() -> Result { - let store = test_store(); + #[tokio::test] + async fn updating_a_session() -> Result { + let store = test_store().await; let mut session = Session::new(); let original_id = session.id().to_owned(); - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); session.insert("key", "value")?; - rt.block_on(async { - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(session).await?.unwrap(); - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); - session.insert("key", "other value")?; - assert_eq!(None, store.store_session(session).await?); + let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); + session.insert("key", "other value")?; + assert_eq!(None, store.store_session(session).await?); - let session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.get::("key").unwrap(), "other value"); + let session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(session.get::("key").unwrap(), "other value"); - let (id, count): (String, i64) = - sqlx::query_as("select id, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; + let (id, count): (String, i64) = + sqlx::query_as("select id, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; - assert_eq!(1, count); - assert_eq!(original_id, id); + assert_eq!(1, count); + assert_eq!(original_id, id); - Ok(()) - }) + Ok(()) } - #[test] - fn updating_a_session_extending_expiry() -> Result { - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); - - let store = test_store(); + #[tokio::test] + async fn updating_a_session_extending_expiry() -> Result { + let store = test_store().await; let mut session = Session::new(); session.expire_in(Duration::from_secs(10)); let original_id = session.id().to_owned(); let original_expires = session.expiry().unwrap().clone(); - rt.block_on(async { - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(session).await?.unwrap(); - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &original_expires); - session.expire_in(Duration::from_secs(20)); - let new_expires = session.expiry().unwrap().clone(); - store.store_session(session).await?; + let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(session.expiry().unwrap(), &original_expires); + session.expire_in(Duration::from_secs(20)); + let new_expires = session.expiry().unwrap().clone(); + store.store_session(session).await?; - let session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &new_expires); + let session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(session.expiry().unwrap(), &new_expires); - let (id, expires, count): (String, i64, i64) = - sqlx::query_as("select id, expires, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; + let (id, expires, count): (String, i64, i64) = + sqlx::query_as("select id, expires, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; - assert_eq!(1, count); - assert_eq!(expires, new_expires.timestamp()); - assert_eq!(original_id, id); + assert_eq!(1, count); + assert_eq!(expires, new_expires.timestamp()); + assert_eq!(original_id, id); - Ok(()) - }) + Ok(()) } - #[test] - fn creating_a_new_session_with_expiry() -> Result { - let store = test_store(); + #[tokio::test] + async fn creating_a_new_session_with_expiry() -> Result { + let store = test_store().await; let mut session = Session::new(); session.expire_in(Duration::from_secs(1)); session.insert("key", "value")?; let cloned = session.clone(); - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); + let cookie_value = store.store_session(session).await?.unwrap(); - rt.block_on(async { - let cookie_value = store.store_session(session).await?.unwrap(); + let (id, expires, serialized, count): (String, Option, String, i64) = + sqlx::query_as("select id, expires, session, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; - let (id, expires, serialized, count): (String, Option, String, i64) = - sqlx::query_as("select id, expires, session, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; + assert_eq!(1, count); + assert_eq!(id, cloned.id()); + assert!(expires.unwrap() > Utc::now().timestamp()); - assert_eq!(1, count); - assert_eq!(id, cloned.id()); - assert!(expires.unwrap() > Utc::now().timestamp()); + let deserialized_session: Session = serde_json::from_str(&serialized)?; + assert_eq!(cloned.id(), deserialized_session.id()); + assert_eq!("value", &deserialized_session.get::("key").unwrap()); - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); + let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(cloned.id(), loaded_session.id()); + assert_eq!("value", &loaded_session.get::("key").unwrap()); - let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(cloned.id(), loaded_session.id()); - assert_eq!("value", &loaded_session.get::("key").unwrap()); + assert!(!loaded_session.is_expired()); - assert!(!loaded_session.is_expired()); + tokio::time::sleep(Duration::from_secs(1)).await; + assert_eq!(None, store.load_session(cookie_value).await?); - tokio::time::sleep(Duration::from_secs(1)).await; - assert_eq!(None, store.load_session(cookie_value).await?); - - Ok(()) - }) + Ok(()) } - #[test] - fn destroying_a_single_session() -> Result { - let store = test_store(); - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); + #[tokio::test] + async fn destroying_a_single_session() -> Result { + let store = test_store().await; + for _ in 0..3i8 { + store.store_session(Session::new()).await?; + } - rt.block_on(async { - for _ in 0..3i8 { - store.store_session(Session::new()).await?; - } + let cookie = store.store_session(Session::new()).await?.unwrap(); + assert_eq!(4, store.count().await?); + let session = store.load_session(cookie.clone()).await?.unwrap(); + store.destroy_session(session.clone()).await.unwrap(); + assert_eq!(None, store.load_session(cookie).await?); + assert_eq!(3, store.count().await?); - let cookie = store.store_session(Session::new()).await?.unwrap(); - assert_eq!(4, store.count().await?); - let session = store.load_session(cookie.clone()).await?.unwrap(); - store.destroy_session(session.clone()).await.unwrap(); - assert_eq!(None, store.load_session(cookie).await?); - assert_eq!(3, store.count().await?); - - // // attempting to destroy the session again is not an error - // assert!(store.destroy_session(session).await.is_ok()); - Ok(()) - }) + // // attempting to destroy the session again is not an error + // assert!(store.destroy_session(session).await.is_ok()); + Ok(()) } - #[test] - fn clearing_the_whole_store() -> Result { - let store = test_store(); - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); + #[tokio::test] + async fn clearing_the_whole_store() -> Result { + let store = test_store().await; + for _ in 0..3i8 { + store.store_session(Session::new()).await?; + } - rt.block_on(async { - for _ in 0..3i8 { - store.store_session(Session::new()).await?; - } + assert_eq!(3, store.count().await?); + store.clear_store().await.unwrap(); + assert_eq!(0, store.count().await?); - assert_eq!(3, store.count().await?); - store.clear_store().await.unwrap(); - assert_eq!(0, store.count().await?); - - Ok(()) - }) + Ok(()) } } } diff --git a/src/generic_handlers.rs b/src/generic_handlers.rs index b60bebe..15118c2 100644 --- a/src/generic_handlers.rs +++ b/src/generic_handlers.rs @@ -20,46 +20,33 @@ pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse { #[cfg(test)] mod test { - use std::future::IntoFuture; - use axum_test::TestServer; use tokio::runtime::Runtime; - use crate::db; + use crate::{get_db_pool, test_utils::server_with_pool}; #[test] fn slash_is_ok() { - let pool = db::get_db_pool(); - let secret = [0u8; 64]; + let db = get_db_pool(); + let rt = Runtime::new().unwrap(); - let app = rt - .block_on(crate::app(pool.clone(), &secret)) - .into_make_service(); - - let server = TestServer::new(app).unwrap(); - - rt.block_on(server.get("/").into_future()) - .assert_status_ok(); + rt.block_on(async { + let server = server_with_pool(&db).await; + server.get("/").await + }) + .assert_status_ok(); } #[test] fn not_found_is_303() { - let pool = db::get_db_pool(); - let secret = [0u8; 64]; - let rt = Runtime::new().unwrap(); - let app = rt - .block_on(crate::app(pool.clone(), &secret)) - .into_make_service(); - let server = TestServer::new(app).unwrap(); + let db = get_db_pool(); assert_eq!( - rt.block_on( - server - .get("/no-actual-route") - .expect_failure() - .into_future() - ) + rt.block_on(async { + let server = server_with_pool(&db).await; + server.get("/no-actual-route").expect_failure().await + }) .status_code(), 303 ); diff --git a/src/import_utils.rs b/src/import_utils.rs index 05c3575..f4e2b4c 100644 --- a/src/import_utils.rs +++ b/src/import_utils.rs @@ -154,13 +154,18 @@ async fn check_omega_exists(db_pool: &SqlitePool) -> bool { #[cfg(test)] mod test { + use tokio::runtime::Runtime; + use super::*; - #[tokio::test] - async fn ensure_omega_user() { + #[test] + fn ensure_omega_user() { let p = crate::db::get_db_pool(); - assert!(!check_omega_exists(&p).await); - ensure_omega(&p).await; - assert!(check_omega_exists(&p).await); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + assert!(!check_omega_exists(&p).await); + ensure_omega(&p).await; + assert!(check_omega_exists(&p).await); + }); } } diff --git a/src/lib.rs b/src/lib.rs index f1b65b8..c1e891d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +use axum::routing::IntoMakeService; +use sqlx::SqlitePool; #[macro_use] extern crate justerror; @@ -9,10 +11,11 @@ pub mod test_utils; pub use db::get_db_pool; pub use db_id::DbId; pub mod import_utils; - pub use users::User; pub use watches::{ShowKind, Watch, WatchQuest}; +pub type WWRouter = axum::Router; + // everything else is private to the crate mod db; mod db_id; @@ -32,11 +35,8 @@ use watches::templates::*; type AuthContext = axum_login::extractors::AuthContext>; /// Returns the router to be used as a service or test object, you do you. -pub async fn app(db_pool: sqlx::SqlitePool, session_secret: &[u8]) -> axum::Router { +pub async fn app(db_pool: sqlx::SqlitePool, secret: &[u8]) -> IntoMakeService { use axum::{middleware, routing::get}; - let session_layer = db::session_layer(db_pool.clone(), session_secret).await; - let auth_layer = db::auth_layer(db_pool.clone(), session_secret).await; - // don't bother bringing handlers into the whole crate namespace use generic_handlers::{handle_slash, handle_slash_redir}; use login::{get_login, get_logout, post_login, post_logout}; @@ -46,6 +46,12 @@ pub async fn app(db_pool: sqlx::SqlitePool, session_secret: &[u8]) -> axum::Rout post_add_watch_quest, }; + let (session_layer, auth_layer) = { + let session_layer = db::session_layer(db_pool.clone(), secret).await; + let auth_layer = db::auth_layer(db_pool.clone(), secret).await; + (session_layer, auth_layer) + }; + axum::Router::new() .route("/", get(handle_slash).post(handle_slash)) .route("/signup", get(get_create_user).post(post_create_user)) @@ -69,6 +75,7 @@ pub async fn app(db_pool: sqlx::SqlitePool, session_secret: &[u8]) -> axum::Rout .layer(auth_layer) .layer(session_layer) .with_state(db_pool) + .into_make_service() } //-************************************************************************ diff --git a/src/login.rs b/src/login.rs index 4b01011..87eb549 100644 --- a/src/login.rs +++ b/src/login.rs @@ -112,8 +112,9 @@ pub async fn post_logout(mut auth: AuthContext) -> impl IntoResponse { #[cfg(test)] mod test { use crate::{ + get_db_pool, templates::{LoginPage, LogoutPage, LogoutSuccessPage, MainPage}, - test_utils::{get_test_user, massage, server, FORM_CONTENT_TYPE}, + test_utils::{get_test_user, massage, server_with_pool, FORM_CONTENT_TYPE}, }; const LOGIN_FORM: &str = "username=test_user&password=a"; @@ -125,8 +126,9 @@ mod test { .build() .unwrap(); - let s = server(); + let db = get_db_pool(); rt.block_on(async { + let s = server_with_pool(&db).await; let resp = s.get("/login").await; let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); assert_eq!(body, LoginPage::default().to_string()); @@ -135,7 +137,6 @@ mod test { #[test] fn post_login_success() { - let s = server(); let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() @@ -143,7 +144,9 @@ mod test { let body = massage(LOGIN_FORM); + let db = get_db_pool(); rt.block_on(async { + let s = server_with_pool(&db).await; let resp = s .post("/login") .expect_failure() @@ -156,7 +159,6 @@ mod test { #[test] fn post_login_bad_user() { - let s = server(); let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() @@ -165,7 +167,9 @@ mod test { let form = "username=test_LOSER&password=aaaa"; let body = massage(form); + let db = get_db_pool(); rt.block_on(async { + let s = server_with_pool(&db).await; let resp = s .post("/login") .expect_success() @@ -178,7 +182,6 @@ mod test { #[test] fn post_login_bad_password() { - let s = server(); let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() @@ -187,7 +190,9 @@ mod test { let form = "username=test_user&password=bbbb"; let body = massage(form); + let db = get_db_pool(); rt.block_on(async { + let s = server_with_pool(&db).await; let resp = s .post("/login") .expect_success() @@ -200,13 +205,14 @@ mod test { #[test] fn get_logout() { - let s = server(); let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); + let db = get_db_pool(); rt.block_on(async { + let s = server_with_pool(&db).await; let resp = s.get("/logout").await; let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); assert_eq!(body, LogoutPage.to_string()); @@ -215,13 +221,14 @@ mod test { #[test] fn post_logout_not_logged_in() { - let s = server(); let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); + let db = get_db_pool(); rt.block_on(async { + let s = server_with_pool(&db).await; let resp = s.post("/logout").await; resp.assert_status_ok(); let body = std::str::from_utf8(resp.bytes()).unwrap(); @@ -232,14 +239,15 @@ mod test { #[test] fn post_logout_logged_in() { - let s = server(); let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); // log in and prove it + let db = get_db_pool(); rt.block_on(async { + let s = server_with_pool(&db).await; let body = massage(LOGIN_FORM); let resp = s .post("/login") diff --git a/src/main.rs b/src/main.rs index 4548be1..1d9e851 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ use std::net::SocketAddr; use rand::{thread_rng, RngCore}; -use sqlx::SqlitePool; use tokio::signal; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use what2watch::get_db_pool; @@ -22,11 +21,6 @@ fn main() { .build() .unwrap(); - rt.block_on(runner(&pool)); - rt.block_on(pool.close()); -} - -async fn runner(pool: &SqlitePool) { let secret = { let mut bytes = [0u8; 64]; let mut rng = thread_rng(); @@ -34,16 +28,20 @@ async fn runner(pool: &SqlitePool) { bytes }; - let app = what2watch::app(pool.clone(), &secret).await; + let app = rt.block_on(what2watch::app(pool.clone(), &secret)); - let addr: SocketAddr = ([0, 0, 0, 0], 3000).into(); - tracing::debug!("binding to {addr:?}"); + rt.block_on(async { + let addr: SocketAddr = ([0, 0, 0, 0], 3000).into(); + tracing::debug!("binding to {addr:?}"); - axum::Server::bind(&addr) - .serve(app.into_make_service()) - .with_graceful_shutdown(shutdown_signal()) - .await - .unwrap_or_default(); + axum::Server::bind(&addr) + .serve(app) + .with_graceful_shutdown(shutdown_signal()) + .await + .unwrap_or_default(); + }); + + rt.block_on(pool.close()); } async fn shutdown_signal() { @@ -66,5 +64,5 @@ async fn shutdown_signal() { _ = terminate => {}, } - println!("signal received, starting graceful shutdown"); + println!(" signal received, starting graceful shutdown"); } diff --git a/src/signup.rs b/src/signup.rs index 2b9c6db..7e217a3 100644 --- a/src/signup.rs +++ b/src/signup.rs @@ -211,62 +211,71 @@ pub(crate) async fn create_user( #[cfg(test)] mod test { use axum::http::StatusCode; + use tokio::runtime::Runtime; use crate::{ db::get_db_pool, templates::{SignupPage, SignupSuccessPage}, - test_utils::{get_test_user, insert_user, massage, server_with_pool, FORM_CONTENT_TYPE}, + test_utils::{get_test_user, massage, server_with_pool, FORM_CONTENT_TYPE}, User, }; - const GOOD_FORM: &str = "username=test_user&displayname=Test+User&password=aaaa&pw_verify=aaaa"; + const GOOD_FORM: &str = "username=good_user&displayname=Test+User&password=aaaa&pw_verify=aaaa"; - #[tokio::test] - async fn post_create_user() { + #[test] + fn post_create_user() { let pool = get_db_pool(); - let server = server_with_pool(&pool).await; - let body = massage(GOOD_FORM); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let server = server_with_pool(&pool).await; + let body = massage(GOOD_FORM); - let resp = server - .post("/signup") - .expect_failure() // 303 is "failure" - .bytes(body) - .content_type(FORM_CONTENT_TYPE) - .await; + let resp = server + .post("/signup") + .expect_failure() // 303 is "failure" + .bytes(body) + .content_type(FORM_CONTENT_TYPE) + .await; - assert_eq!(StatusCode::SEE_OTHER, resp.status_code()); + assert_eq!(StatusCode::SEE_OTHER, resp.status_code()); - // get the new user from the db - let user = User::try_get("test_user", &pool).await; - assert!(user.is_ok()); + // get the new user from the db + let user = User::try_get("good_user", &pool).await; + assert!(user.is_ok()); + }); } - #[tokio::test] - async fn get_create_user() { + #[test] + fn get_create_user() { let pool = get_db_pool(); - let server = server_with_pool(&pool).await; + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let server = server_with_pool(&pool).await; - let resp = server.get("/signup").await; - let body = std::str::from_utf8(resp.bytes()).unwrap(); - let expected = SignupPage::default().to_string(); - assert_eq!(&expected, body); + let resp = server.get("/signup").await; + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let expected = SignupPage::default().to_string(); + assert_eq!(&expected, body); + }); } - #[tokio::test] - async fn handle_signup_success() { + #[test] + fn handle_signup_success() { let pool = get_db_pool(); - let server = server_with_pool(&pool).await; + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let server = server_with_pool(&pool).await; - let user = get_test_user(); - insert_user(&user, &pool).await; - let id = user.id.0.to_string(); + let user = get_test_user(); + let id = user.id.0.to_string(); - let path = format!("/signup_success/{id}"); + let path = format!("/signup_success/{id}"); - let resp = server.get(&path).expect_success().await; - let body = std::str::from_utf8(resp.bytes()).unwrap(); - let expected = SignupSuccessPage(user).to_string(); - assert_eq!(&expected, body); + let resp = server.get(&path).expect_success().await; + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let expected = SignupSuccessPage(user).to_string(); + assert_eq!(&expected, body); + }); } //-************************************************************************ @@ -278,223 +287,253 @@ mod test { // various ways to fuck up signup const PASSWORD_MISMATCH_FORM: &str = - "username=test_user&displayname=Test+User&password=aaaa&pw_verify=bbbb"; + "username=bad_user&displayname=Test+User&password=aaaa&pw_verify=bbbb"; const PASSWORD_SHORT_FORM: &str = - "username=test_user&displayname=Test+User&password=a&pw_verify=a"; - const PASSWORD_LONG_FORM: &str = "username=test_user&displayname=Test+User&password=sphinx+of+black+qwartz+judge+my+vow+etc+etc+yadd+yadda&pw_verify=sphinx+of+black+qwartz+judge+my+vow+etc+etc+yadd+yadda"; + "username=bad_user&displayname=Test+User&password=a&pw_verify=a"; + const PASSWORD_LONG_FORM: &str = "username=bad_user&displayname=Test+User&password=sphinx+of+black+qwartz+judge+my+vow+etc+etc+yadd+yadda&pw_verify=sphinx+of+black+qwartz+judge+my+vow+etc+etc+yadd+yadda"; const USERNAME_SHORT_FORM: &str = "username=&displayname=Test+User&password=aaaa&pw_verify=aaaa"; const USERNAME_LONG_FORM: &str = - "username=test_user12345678901234567890&displayname=Test+User&password=aaaa&pw_verify=aaaa"; + "username=bad_user12345678901234567890&displayname=Test+User&password=aaaa&pw_verify=aaaa"; const DISPLAYNAME_LONG_FORM: &str = "username=test_user&displayname=Since+time+immemorial%2C+display+names+have+been+subject+to+a+number+of+conventions%2C+restrictions%2C+usages%2C+and+even+incentives.+Have+we+finally+gone+too+far%3F+In+this+essay%2C+&password=aaaa&pw_verify=aaaa"; - #[tokio::test] - async fn password_mismatch() { + #[test] + fn password_mismatch() { let pool = get_db_pool(); - let server = server_with_pool(&pool).await; - let body = massage(PASSWORD_MISMATCH_FORM); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let server = server_with_pool(&pool).await; + let body = massage(PASSWORD_MISMATCH_FORM); - let resp = server - .post("/signup") - // failure to sign up is not failure to submit the request - .expect_success() - .bytes(body) - .content_type(FORM_CONTENT_TYPE) - .await; + let resp = server + .post("/signup") + // failure to sign up is not failure to submit the request + .expect_success() + .bytes(body) + .content_type(FORM_CONTENT_TYPE) + .await; - // no user in db - let user = User::try_get("test_user", &pool).await; - assert!(user.is_err()); + // no user in db + let user = User::try_get("bad_user", &pool).await; + assert!(user.is_err()); - let body = std::str::from_utf8(resp.bytes()).unwrap(); - let expected = CreateUserError(CreateUserErrorKind::PasswordMismatch).to_string(); - assert_eq!(&expected, body); + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let expected = CreateUserError(CreateUserErrorKind::PasswordMismatch).to_string(); + assert_eq!(&expected, body); + }); } - #[tokio::test] - async fn password_short() { + #[test] + fn password_short() { let pool = get_db_pool(); - let server = server_with_pool(&pool).await; - let body = massage(PASSWORD_SHORT_FORM); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let server = server_with_pool(&pool).await; + let body = massage(PASSWORD_SHORT_FORM); - let resp = server - .post("/signup") - // failure to sign up is not failure to submit the request - .expect_success() - .bytes(body) - .content_type(FORM_CONTENT_TYPE) - .await; + let resp = server + .post("/signup") + // failure to sign up is not failure to submit the request + .expect_success() + .bytes(body) + .content_type(FORM_CONTENT_TYPE) + .await; - // no user in db - let user = User::try_get("test_user", &pool).await; - assert!(user.is_err()); + // no user in db + let user = User::try_get("bad_user", &pool).await; + assert!(user.is_err()); - let body = std::str::from_utf8(resp.bytes()).unwrap(); - let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); - assert_eq!(&expected, body); + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); + assert_eq!(&expected, body); + }); } - #[tokio::test] - async fn password_long() { + #[test] + fn password_long() { let pool = get_db_pool(); - let server = server_with_pool(&pool).await; - let body = massage(PASSWORD_LONG_FORM); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let server = server_with_pool(&pool).await; + let body = massage(PASSWORD_LONG_FORM); - let resp = server - .post("/signup") - // failure to sign up is not failure to submit the request - .expect_success() - .bytes(body) - .content_type(FORM_CONTENT_TYPE) - .await; + let resp = server + .post("/signup") + // failure to sign up is not failure to submit the request + .expect_success() + .bytes(body) + .content_type(FORM_CONTENT_TYPE) + .await; - // no user in db - let user = User::try_get("test_user", &pool).await; - assert!(user.is_err()); + // no user in db + let user = User::try_get("bad_user", &pool).await; + assert!(user.is_err()); - let body = std::str::from_utf8(resp.bytes()).unwrap(); - let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); - assert_eq!(&expected, body); + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); + assert_eq!(&expected, body); + }); } - #[tokio::test] - async fn multibyte_password_too_short() { + #[test] + fn multibyte_password_too_short() { let pw = "🤡"; // min length is 4 assert_eq!(pw.len(), 4); let pool = get_db_pool(); - let server = server_with_pool(&pool).await; - let form = - format!("username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}"); - let body = massage(&form); + let rt = Runtime::new().unwrap(); - let resp = server - .post("/signup") - // failure to sign up is not failure to submit the request - .expect_success() - .bytes(body) - .content_type(FORM_CONTENT_TYPE) - .await; + rt.block_on(async { + let server = server_with_pool(&pool).await; + let form = format!( + "username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}" + ); + let body = massage(&form); - // no user in db - let user = User::try_get("test_user", &pool).await; - assert!(user.is_err()); + let resp = server + .post("/signup") + // failure to sign up is not failure to submit the request + .expect_success() + .bytes(body) + .content_type(FORM_CONTENT_TYPE) + .await; - let body = std::str::from_utf8(resp.bytes()).unwrap(); - let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); - assert_eq!(&expected, body); + // no user in db + let user = User::try_get("bad_user", &pool).await; + assert!(user.is_err()); + + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); + assert_eq!(&expected, body); + }); } - #[tokio::test] - async fn username_short() { + #[test] + fn username_short() { let pool = get_db_pool(); - let server = server_with_pool(&pool).await; - let body = massage(USERNAME_SHORT_FORM); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let server = server_with_pool(&pool).await; + let body = massage(USERNAME_SHORT_FORM); - let resp = server - .post("/signup") - // failure to sign up is not failure to submit the request - .expect_success() - .bytes(body) - .content_type(FORM_CONTENT_TYPE) - .await; + let resp = server + .post("/signup") + // failure to sign up is not failure to submit the request + .expect_success() + .bytes(body) + .content_type(FORM_CONTENT_TYPE) + .await; - // no user in db - let user = User::try_get("test_user", &pool).await; - assert!(user.is_err()); + // no user in db + let user = User::try_get("bad_user", &pool).await; + assert!(user.is_err()); - let body = std::str::from_utf8(resp.bytes()).unwrap(); - let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string(); - assert_eq!(&expected, body); + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string(); + assert_eq!(&expected, body); + }); } - #[tokio::test] - async fn username_long() { + #[test] + fn username_long() { let pool = get_db_pool(); - let server = server_with_pool(&pool).await; - let body = massage(USERNAME_LONG_FORM); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let server = server_with_pool(&pool).await; + let body = massage(USERNAME_LONG_FORM); - let resp = server - .post("/signup") - // failure to sign up is not failure to submit the request - .expect_success() - .bytes(body) - .content_type(FORM_CONTENT_TYPE) - .await; + let resp = server + .post("/signup") + // failure to sign up is not failure to submit the request + .expect_success() + .bytes(body) + .content_type(FORM_CONTENT_TYPE) + .await; - // no user in db - let user = User::try_get("test_user", &pool).await; - assert!(user.is_err()); + // no user in db + let user = User::try_get("bad_user", &pool).await; + assert!(user.is_err()); - let body = std::str::from_utf8(resp.bytes()).unwrap(); - let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string(); - assert_eq!(&expected, body); + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string(); + assert_eq!(&expected, body); + }); } - #[tokio::test] - async fn username_duplicate() { + #[test] + fn username_duplicate() { let pool = get_db_pool(); - let server = server_with_pool(&pool).await; - let body = massage(GOOD_FORM); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let server = server_with_pool(&pool).await; + let body = massage(GOOD_FORM); - let _resp = server - .post("/signup") - .expect_failure() // 303 is "failure" - .bytes(body.clone()) - .content_type(FORM_CONTENT_TYPE) - .await; + let _resp = server + .post("/signup") + .expect_failure() // 303 is "failure" + .bytes(body.clone()) + .content_type(FORM_CONTENT_TYPE) + .await; - // get the new user from the db - let user = User::try_get("test_user", &pool).await; - assert!(user.is_ok()); + // get the new user from the db + let user = User::try_get("good_user", &pool).await; + assert!(user.is_ok()); - // now try again - let resp = server - .post("/signup") - .expect_success() - .bytes(body) - .content_type(FORM_CONTENT_TYPE) - .await; + // now try again + let resp = server + .post("/signup") + .expect_success() + .bytes(body) + .content_type(FORM_CONTENT_TYPE) + .await; - assert_eq!(resp.status_code(), StatusCode::OK); - let expected = CreateUserError(CreateUserErrorKind::AlreadyExists).to_string(); - let body = std::str::from_utf8(resp.bytes()).unwrap(); - assert_eq!(&expected, body); + assert_eq!(resp.status_code(), StatusCode::OK); + let expected = CreateUserError(CreateUserErrorKind::AlreadyExists).to_string(); + let body = std::str::from_utf8(resp.bytes()).unwrap(); + assert_eq!(&expected, body); + }); } - #[tokio::test] - async fn displayname_long() { + #[test] + fn displayname_long() { let pool = get_db_pool(); - let server = server_with_pool(&pool).await; - let body = massage(DISPLAYNAME_LONG_FORM); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let server = server_with_pool(&pool).await; + let body = massage(DISPLAYNAME_LONG_FORM); - let resp = server - .post("/signup") - // failure to sign up is not failure to submit the request - .expect_success() - .bytes(body) - .content_type(FORM_CONTENT_TYPE) - .await; + let resp = server + .post("/signup") + // failure to sign up is not failure to submit the request + .expect_success() + .bytes(body) + .content_type(FORM_CONTENT_TYPE) + .await; - // no user in db - let user = User::try_get("test_user", &pool).await; - assert!(user.is_err()); + // no user in db + let user = User::try_get("bad_user", &pool).await; + assert!(user.is_err()); - let body = std::str::from_utf8(resp.bytes()).unwrap(); - let expected = CreateUserError(CreateUserErrorKind::BadDisplayname).to_string(); - assert_eq!(&expected, body); + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let expected = CreateUserError(CreateUserErrorKind::BadDisplayname).to_string(); + assert_eq!(&expected, body); + }); } - #[tokio::test] - async fn handle_signup_success() { + #[test] + fn handle_signup_success() { let pool = get_db_pool(); - let server = server_with_pool(&pool).await; + let rt = Runtime::new().unwrap(); - let path = format!("/signup_success/nope"); + rt.block_on(async { + let server = server_with_pool(&pool).await; - let resp = server.get(&path).expect_failure().await; - assert_eq!(resp.status_code(), StatusCode::SEE_OTHER); + let path = format!("/signup_success/nope"); + + let resp = server.get(&path).expect_failure().await; + assert_eq!(resp.status_code(), StatusCode::SEE_OTHER); + }); } } } diff --git a/src/test_utils.rs b/src/test_utils.rs index b503301..d758522 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -17,51 +17,19 @@ pub fn get_test_user() -> User { } } -pub fn server() -> TestServer { - let pool = crate::db::get_db_pool(); - let secret = [0u8; 64]; - - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); - - let user = get_test_user(); - rt.block_on(async { - sqlx::query(crate::signup::CREATE_QUERY) - .bind(user.id) - .bind(&user.username) - .bind(&user.displayname) - .bind(&user.email) - .bind(&user.pwhash) - .execute(&pool) - .await - .unwrap(); - - let r = sqlx::query("select count(*) from users") - .fetch_one(&pool) - .await; - assert!(r.is_ok()); - - let app = crate::app(pool, &secret).await.into_make_service(); - - let config = TestServerConfig { - save_cookies: true, - ..Default::default() - }; - TestServer::new_with_config(app, config).unwrap() - }) -} - pub async fn server_with_pool(pool: &SqlitePool) -> TestServer { let secret = [0u8; 64]; - let r = sqlx::query("select count(*) from users") - .fetch_one(pool) - .await; - assert!(r.is_ok()); + let user = get_test_user(); - let app = crate::app(pool.clone(), &secret).await.into_make_service(); + insert_user(&user, pool).await; + let r: i32 = sqlx::query_scalar("select count(*) from users") + .fetch_one(pool) + .await + .unwrap_or_default(); + assert!(r == 1); + + let app = crate::app(pool.clone(), &secret).await; let config = TestServerConfig { save_cookies: true,