From 48308aa169de74b1a61ab72869ad75437f6309fb Mon Sep 17 00:00:00 2001 From: Joe Ardent Date: Tue, 18 Jul 2023 17:37:24 -0700 Subject: [PATCH] make db acquire sync; tests fail but app runs correctly --- src/bin/import_omega.rs | 53 ++++--- src/bin/import_users.rs | 47 +++--- src/db.rs | 313 ++++++++++++++++++++++++---------------- src/generic_handlers.rs | 41 ++++-- src/import_utils.rs | 2 +- src/login.rs | 165 +++++++++++++-------- src/main.rs | 18 ++- src/signup.rs | 24 +-- src/test_utils.rs | 49 ++++--- 9 files changed, 437 insertions(+), 275 deletions(-) diff --git a/src/bin/import_omega.rs b/src/bin/import_omega.rs index 3e20746..065cc35 100644 --- a/src/bin/import_omega.rs +++ b/src/bin/import_omega.rs @@ -10,32 +10,47 @@ struct Cli { pub db_path: OsString, } -#[tokio::main] -async fn main() { +fn main() { let cli = Cli::parse(); let path = cli.db_path; - let opts = SqliteConnectOptions::new().filename(&path).read_only(true); - let movie_db = SqlitePoolOptions::new() - .idle_timeout(Duration::from_secs(90)) - .connect_with(opts) - .await - .expect("could not open movies db"); + let opts = SqliteConnectOptions::new().filename(path).read_only(true); + let movie_db = { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); - let w2w_db = get_db_pool().await; + rt.block_on( + SqlitePoolOptions::new() + .idle_timeout(Duration::from_secs(90)) + .connect_with(opts), + ) + .expect("could not open movies db") + }; + + let w2w_db = get_db_pool(); let start = std::time::Instant::now(); - add_omega_watches(&w2w_db, &movie_db).await.unwrap(); + let rows = { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + add_omega_watches(&w2w_db, &movie_db).await.unwrap(); + + let rows: i32 = sqlx::query_scalar("select count(*) from watches") + .fetch_one(&w2w_db) + .await + .unwrap(); + w2w_db.close().await; + rows + }) + }; + let end = std::time::Instant::now(); - let dur = (end - start).as_secs_f32(); - - let rows: i32 = sqlx::query_scalar("select count(*) from watches") - .fetch_one(&w2w_db) - .await - .unwrap(); - println!("Added {rows} movies in {dur} seconds"); - - w2w_db.close().await; } diff --git a/src/bin/import_users.rs b/src/bin/import_users.rs index b65365b..4f0b8b0 100644 --- a/src/bin/import_users.rs +++ b/src/bin/import_users.rs @@ -7,7 +7,7 @@ use sqlx::{ sqlite::{SqliteConnectOptions, SqlitePoolOptions}, SqlitePool, }; -use tokio::task::JoinSet; +use tokio::{runtime, task::JoinSet}; use tokio_retry::Retry; use what2watch::{ get_db_pool, @@ -15,8 +15,7 @@ use what2watch::{ DbId, User, WatchQuest, }; -#[tokio::main] -async fn main() { +fn main() { let cli = Cli::parse(); let path = cli.db_path; let num_users = cli.users; @@ -31,29 +30,39 @@ async fn main() { let words: Vec<&str> = words.split('\n').collect(); let opts = SqliteConnectOptions::new().filename(&path).read_only(true); - let movie_db = SqlitePoolOptions::new() - .idle_timeout(Duration::from_secs(3)) - .connect_with(opts) - .await - .expect("could not open movies db"); - let w2w_db = get_db_pool().await; - let users = &gen_users(num_users, &words, &w2w_db).await; - let movies = &add_omega_watches(&w2w_db, &movie_db).await.unwrap(); + let rt = runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let movie_db = rt + .block_on( + SqlitePoolOptions::new() + .idle_timeout(Duration::from_secs(3)) + .connect_with(opts), + ) + .expect("could not open movies db"); + let w2w_db = get_db_pool(); + + let users = &rt.block_on(gen_users(num_users, &words, &w2w_db)); + + let movies = &rt.block_on(add_omega_watches(&w2w_db, &movie_db)).unwrap(); let rng = &mut thread_rng(); let normal = Normal::new(mpu, mpu / 10.0).unwrap(); let start = std::time::Instant::now(); - for &user in users { - add_quests(user, movies, &w2w_db, rng, normal).await; - } - let end = std::time::Instant::now(); - let rows: i32 = sqlx::query_scalar("select count(*) from watch_quests") - .fetch_one(&w2w_db) - .await + rt.block_on(async { + for &user in users { + add_quests(user, movies, &w2w_db, rng, normal).await; + } + }); + let rows: i32 = rt + .block_on(sqlx::query_scalar("select count(*) from watch_quests").fetch_one(&w2w_db)) .unwrap(); - w2w_db.close().await; + rt.block_on(w2w_db.close()); + let end = std::time::Instant::now(); let dur = (end - start).as_secs_f32(); println!("Added {rows} quests in {dur} seconds"); } diff --git a/src/db.rs b/src/db.rs index dfcf2bb..0571db2 100644 --- a/src/db.rs +++ b/src/db.rs @@ -19,7 +19,7 @@ const MIN_CONNS: u32 = 5; const TIMEOUT: u64 = 11; const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64); -pub async fn get_db_pool() -> SqlitePool { +pub fn get_db_pool() -> SqlitePool { let db_filename = { std::env::var("DATABASE_FILE").unwrap_or_else(|_| { #[cfg(not(test))] @@ -57,25 +57,34 @@ pub async fn get_db_pool() -> SqlitePool { .min_connections(MIN_CONNS) .idle_timeout(Some(Duration::from_secs(30))) .max_lifetime(Some(Duration::from_secs(3600))) - .connect_with(conn_opts) - .await - .expect("can't connect to database"); + .connect_with(conn_opts); - // let the filesystem settle before trying anything - // possibly not effective? - tokio::time::sleep(Duration::from_millis(500)).await; + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let pool = { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(pool).unwrap() + }; + + let _con = rt.block_on(pool.acquire()).unwrap(); { - let mut m = Migrator::new(std::path::Path::new("./migrations")) - .await - .expect("Should be able to read the migration directory."); - - let m = m.set_locking(true); - - m.run(&pool) - .await - .expect("Should be able to run the migration."); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + let m = Migrator::new(std::path::Path::new("./migrations")); + let mut m = rt.block_on(m).unwrap(); + let m = m.set_locking(false); + rt.block_on(m.run(&pool)).unwrap(); tracing::info!("Ran migrations"); } @@ -120,7 +129,7 @@ mod tests { #[tokio::test] async fn it_migrates_the_db() { - let db = super::get_db_pool().await; + let db = super::get_db_pool(); let r = sqlx::query("select count(*) from users") .fetch_one(&db) .await; @@ -373,169 +382,223 @@ mod session_store { use super::*; - async fn test_store() -> SqliteSessionStore { - let store = SqliteSessionStore::new("sqlite::memory:") - .await - .expect("building a sqlite :memory: SqliteSessionStore"); - store - .migrate() - .await - .expect("migrating a brand new :memory: SqliteSessionStore"); + fn test_store() -> SqliteSessionStore { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let store = rt.block_on(async { + let store = SqliteSessionStore::new("sqlite::memory:") + .await + .expect("building a sqlite :memory: SqliteSessionStore"); + store + .migrate() + .await + .expect("migrating a brand new :memory: SqliteSessionStore"); + store + }); + dbg!("got the store"); store } - #[tokio::test] - async fn creating_a_new_session_with_no_expiry() -> Result { - let store = test_store().await; + #[test] + fn creating_a_new_session_with_no_expiry() -> Result { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let store = test_store(); let mut session = Session::new(); + dbg!("new session"); session.insert("key", "value")?; let cloned = session.clone(); - let cookie_value = store.store_session(session).await?.unwrap(); - let (id, expires, serialized, count): (String, Option, String, i64) = - sqlx::query_as("select id, expires, session, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; + rt.block_on(async { + let cookie_value = store.store_session(session).await?.unwrap(); - assert_eq!(1, count); - assert_eq!(id, cloned.id()); - assert_eq!(expires, None); + let (id, expires, serialized, count): (String, Option, String, i64) = + sqlx::query_as("select id, expires, session, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); + assert_eq!(1, count); + assert_eq!(id, cloned.id()); + assert_eq!(expires, None); - let loaded_session = store.load_session(cookie_value).await?.unwrap(); - assert_eq!(cloned.id(), loaded_session.id()); - assert_eq!("value", &loaded_session.get::("key").unwrap()); + let deserialized_session: Session = serde_json::from_str(&serialized)?; + assert_eq!(cloned.id(), deserialized_session.id()); + assert_eq!("value", &deserialized_session.get::("key").unwrap()); - assert!(!loaded_session.is_expired()); - Ok(()) + let loaded_session = store.load_session(cookie_value).await?.unwrap(); + assert_eq!(cloned.id(), loaded_session.id()); + assert_eq!("value", &loaded_session.get::("key").unwrap()); + + assert!(!loaded_session.is_expired()); + Ok(()) + }) } - #[tokio::test] - async fn updating_a_session() -> Result { - let store = test_store().await; + #[test] + fn updating_a_session() -> Result { + let store = test_store(); let mut session = Session::new(); let original_id = session.id().to_owned(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); session.insert("key", "value")?; - let cookie_value = store.store_session(session).await?.unwrap(); - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); - session.insert("key", "other value")?; - assert_eq!(None, store.store_session(session).await?); + rt.block_on(async { + let cookie_value = store.store_session(session).await?.unwrap(); - let session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.get::("key").unwrap(), "other value"); + let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); + session.insert("key", "other value")?; + assert_eq!(None, store.store_session(session).await?); - let (id, count): (String, i64) = - sqlx::query_as("select id, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; + let session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(session.get::("key").unwrap(), "other value"); - assert_eq!(1, count); - assert_eq!(original_id, id); + let (id, count): (String, i64) = + sqlx::query_as("select id, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; - Ok(()) + assert_eq!(1, count); + assert_eq!(original_id, id); + + Ok(()) + }) } - #[tokio::test] - async fn updating_a_session_extending_expiry() -> Result { - let store = test_store().await; + #[test] + fn updating_a_session_extending_expiry() -> Result { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let store = test_store(); let mut session = Session::new(); session.expire_in(Duration::from_secs(10)); let original_id = session.id().to_owned(); let original_expires = session.expiry().unwrap().clone(); - let cookie_value = store.store_session(session).await?.unwrap(); - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &original_expires); - session.expire_in(Duration::from_secs(20)); - let new_expires = session.expiry().unwrap().clone(); - store.store_session(session).await?; + rt.block_on(async { + let cookie_value = store.store_session(session).await?.unwrap(); - let session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &new_expires); + let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(session.expiry().unwrap(), &original_expires); + session.expire_in(Duration::from_secs(20)); + let new_expires = session.expiry().unwrap().clone(); + store.store_session(session).await?; - let (id, expires, count): (String, i64, i64) = - sqlx::query_as("select id, expires, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; + let session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(session.expiry().unwrap(), &new_expires); - assert_eq!(1, count); - assert_eq!(expires, new_expires.timestamp()); - assert_eq!(original_id, id); + let (id, expires, count): (String, i64, i64) = + sqlx::query_as("select id, expires, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; - Ok(()) + assert_eq!(1, count); + assert_eq!(expires, new_expires.timestamp()); + assert_eq!(original_id, id); + + Ok(()) + }) } - #[tokio::test] - async fn creating_a_new_session_with_expiry() -> Result { - let store = test_store().await; + #[test] + fn creating_a_new_session_with_expiry() -> Result { + let store = test_store(); let mut session = Session::new(); session.expire_in(Duration::from_secs(1)); session.insert("key", "value")?; let cloned = session.clone(); - let cookie_value = store.store_session(session).await?.unwrap(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); - let (id, expires, serialized, count): (String, Option, String, i64) = - sqlx::query_as("select id, expires, session, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; + rt.block_on(async { + let cookie_value = store.store_session(session).await?.unwrap(); - assert_eq!(1, count); - assert_eq!(id, cloned.id()); - assert!(expires.unwrap() > Utc::now().timestamp()); + let (id, expires, serialized, count): (String, Option, String, i64) = + sqlx::query_as("select id, expires, session, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); + assert_eq!(1, count); + assert_eq!(id, cloned.id()); + assert!(expires.unwrap() > Utc::now().timestamp()); - let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(cloned.id(), loaded_session.id()); - assert_eq!("value", &loaded_session.get::("key").unwrap()); + let deserialized_session: Session = serde_json::from_str(&serialized)?; + assert_eq!(cloned.id(), deserialized_session.id()); + assert_eq!("value", &deserialized_session.get::("key").unwrap()); - assert!(!loaded_session.is_expired()); + let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(cloned.id(), loaded_session.id()); + assert_eq!("value", &loaded_session.get::("key").unwrap()); - tokio::time::sleep(Duration::from_secs(1)).await; - assert_eq!(None, store.load_session(cookie_value).await?); + assert!(!loaded_session.is_expired()); - Ok(()) + tokio::time::sleep(Duration::from_secs(1)).await; + assert_eq!(None, store.load_session(cookie_value).await?); + + Ok(()) + }) } - #[tokio::test] - async fn destroying_a_single_session() -> Result { - let store = test_store().await; - for _ in 0..3i8 { - store.store_session(Session::new()).await?; - } + #[test] + fn destroying_a_single_session() -> Result { + let store = test_store(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); - let cookie = store.store_session(Session::new()).await?.unwrap(); - assert_eq!(4, store.count().await?); - let session = store.load_session(cookie.clone()).await?.unwrap(); - store.destroy_session(session.clone()).await.unwrap(); - assert_eq!(None, store.load_session(cookie).await?); - assert_eq!(3, store.count().await?); + rt.block_on(async { + for _ in 0..3i8 { + store.store_session(Session::new()).await?; + } - // // attempting to destroy the session again is not an error - // assert!(store.destroy_session(session).await.is_ok()); - Ok(()) + let cookie = store.store_session(Session::new()).await?.unwrap(); + assert_eq!(4, store.count().await?); + let session = store.load_session(cookie.clone()).await?.unwrap(); + store.destroy_session(session.clone()).await.unwrap(); + assert_eq!(None, store.load_session(cookie).await?); + assert_eq!(3, store.count().await?); + + // // attempting to destroy the session again is not an error + // assert!(store.destroy_session(session).await.is_ok()); + Ok(()) + }) } - #[tokio::test] - async fn clearing_the_whole_store() -> Result { - let store = test_store().await; - for _ in 0..3i8 { - store.store_session(Session::new()).await?; - } + #[test] + fn clearing_the_whole_store() -> Result { + let store = test_store(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); - assert_eq!(3, store.count().await?); - store.clear_store().await.unwrap(); - assert_eq!(0, store.count().await?); + rt.block_on(async { + for _ in 0..3i8 { + store.store_session(Session::new()).await?; + } - Ok(()) + assert_eq!(3, store.count().await?); + store.clear_store().await.unwrap(); + assert_eq!(0, store.count().await?); + + Ok(()) + }) } } } diff --git a/src/generic_handlers.rs b/src/generic_handlers.rs index cbefcb1..b60bebe 100644 --- a/src/generic_handlers.rs +++ b/src/generic_handlers.rs @@ -20,34 +20,47 @@ pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse { #[cfg(test)] mod test { + use std::future::IntoFuture; + use axum_test::TestServer; + use tokio::runtime::Runtime; use crate::db; - #[tokio::test] - async fn slash_is_ok() { - let pool = db::get_db_pool().await; + #[test] + fn slash_is_ok() { + let pool = db::get_db_pool(); let secret = [0u8; 64]; - let app = crate::app(pool.clone(), &secret).await.into_make_service(); + let rt = Runtime::new().unwrap(); + let app = rt + .block_on(crate::app(pool.clone(), &secret)) + .into_make_service(); let server = TestServer::new(app).unwrap(); - server.get("/").await.assert_status_ok(); + rt.block_on(server.get("/").into_future()) + .assert_status_ok(); } - #[tokio::test] - async fn not_found_is_303() { - let pool = db::get_db_pool().await; + #[test] + fn not_found_is_303() { + let pool = db::get_db_pool(); let secret = [0u8; 64]; - let app = crate::app(pool, &secret).await.into_make_service(); + + let rt = Runtime::new().unwrap(); + let app = rt + .block_on(crate::app(pool.clone(), &secret)) + .into_make_service(); let server = TestServer::new(app).unwrap(); assert_eq!( - server - .get("/no-actual-route") - .expect_failure() - .await - .status_code(), + rt.block_on( + server + .get("/no-actual-route") + .expect_failure() + .into_future() + ) + .status_code(), 303 ); } diff --git a/src/import_utils.rs b/src/import_utils.rs index 65734c9..05c3575 100644 --- a/src/import_utils.rs +++ b/src/import_utils.rs @@ -158,7 +158,7 @@ mod test { #[tokio::test] async fn ensure_omega_user() { - let p = crate::db::get_db_pool().await; + let p = crate::db::get_db_pool(); assert!(!check_omega_exists(&p).await); ensure_omega(&p).await; assert!(check_omega_exists(&p).await); diff --git a/src/login.rs b/src/login.rs index 4bd0a26..4b01011 100644 --- a/src/login.rs +++ b/src/login.rs @@ -105,6 +105,10 @@ pub async fn post_logout(mut auth: AuthContext) -> impl IntoResponse { LogoutSuccessPage } +//-************************************************************************ +// tests +//-************************************************************************ + #[cfg(test)] mod test { use crate::{ @@ -114,85 +118,128 @@ mod test { const LOGIN_FORM: &str = "username=test_user&password=a"; - #[tokio::test] - async fn get_login() { - let s = server().await; - let resp = s.get("/login").await; - let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); - assert_eq!(body, LoginPage::default().to_string()); + #[test] + fn get_login() { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let s = server(); + rt.block_on(async { + let resp = s.get("/login").await; + let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); + assert_eq!(body, LoginPage::default().to_string()); + }) } - #[tokio::test] - async fn post_login_success() { - let s = server().await; + #[test] + fn post_login_success() { + let s = server(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); let body = massage(LOGIN_FORM); - let resp = s - .post("/login") - .expect_failure() - .content_type(FORM_CONTENT_TYPE) - .bytes(body) - .await; - assert_eq!(resp.status_code(), 303); + rt.block_on(async { + let resp = s + .post("/login") + .expect_failure() + .content_type(FORM_CONTENT_TYPE) + .bytes(body) + .await; + assert_eq!(resp.status_code(), 303); + }) } - #[tokio::test] - async fn post_login_bad_user() { - let s = server().await; + #[test] + fn post_login_bad_user() { + let s = server(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); let form = "username=test_LOSER&password=aaaa"; let body = massage(form); - let resp = s - .post("/login") - .expect_success() - .content_type(FORM_CONTENT_TYPE) - .bytes(body) - .await; - assert_eq!(resp.status_code(), 200); + rt.block_on(async { + let resp = s + .post("/login") + .expect_success() + .content_type(FORM_CONTENT_TYPE) + .bytes(body) + .await; + assert_eq!(resp.status_code(), 200); + }) } - #[tokio::test] - async fn post_login_bad_password() { - let s = server().await; + #[test] + fn post_login_bad_password() { + let s = server(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); let form = "username=test_user&password=bbbb"; let body = massage(form); - let resp = s - .post("/login") - .expect_success() - .content_type(FORM_CONTENT_TYPE) - .bytes(body) - .await; - assert_eq!(resp.status_code(), 200); + rt.block_on(async { + let resp = s + .post("/login") + .expect_success() + .content_type(FORM_CONTENT_TYPE) + .bytes(body) + .await; + assert_eq!(resp.status_code(), 200); + }) } - #[tokio::test] - async fn get_logout() { - let s = server().await; - let resp = s.get("/logout").await; - let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); - assert_eq!(body, LogoutPage.to_string()); + #[test] + fn get_logout() { + let s = server(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + let resp = s.get("/logout").await; + let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); + assert_eq!(body, LogoutPage.to_string()); + }) } - #[tokio::test] - async fn post_logout_not_logged_in() { - let s = server().await; - let resp = s.post("/logout").await; - resp.assert_status_ok(); - let body = std::str::from_utf8(resp.bytes()).unwrap(); - let default = LogoutSuccessPage.to_string(); - assert_eq!(body, &default); + #[test] + fn post_logout_not_logged_in() { + let s = server(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + let resp = s.post("/logout").await; + resp.assert_status_ok(); + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let default = LogoutSuccessPage.to_string(); + assert_eq!(body, &default); + }) } - #[tokio::test] - async fn post_logout_logged_in() { - let s = server().await; + #[test] + fn post_logout_logged_in() { + let s = server(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); // log in and prove it - { + rt.block_on(async { let body = massage(LOGIN_FORM); let resp = s .post("/login") @@ -210,11 +257,11 @@ mod test { let main_page = s.get("/").await; let body = std::str::from_utf8(main_page.bytes()).unwrap(); assert_eq!(&logged_in, body); - } - let resp = s.post("/logout").await; - let body = std::str::from_utf8(resp.bytes()).unwrap(); - let default = LogoutSuccessPage.to_string(); - assert_eq!(body, &default); + let resp = s.post("/logout").await; + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let default = LogoutSuccessPage.to_string(); + assert_eq!(body, &default); + }) } } diff --git a/src/main.rs b/src/main.rs index 350029b..4548be1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,12 @@ use std::net::SocketAddr; use rand::{thread_rng, RngCore}; +use sqlx::SqlitePool; use tokio::signal; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use what2watch::get_db_pool; -#[tokio::main] -async fn main() { +fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() @@ -15,8 +15,18 @@ async fn main() { .with(tracing_subscriber::fmt::layer()) .init(); - let pool = get_db_pool().await; + let pool = get_db_pool(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(runner(&pool)); + rt.block_on(pool.close()); +} + +async fn runner(pool: &SqlitePool) { let secret = { let mut bytes = [0u8; 64]; let mut rng = thread_rng(); @@ -34,8 +44,6 @@ async fn main() { .with_graceful_shutdown(shutdown_signal()) .await .unwrap_or_default(); - - pool.close().await; } async fn shutdown_signal() { diff --git a/src/signup.rs b/src/signup.rs index a3f1275..2b9c6db 100644 --- a/src/signup.rs +++ b/src/signup.rs @@ -223,7 +223,7 @@ mod test { #[tokio::test] async fn post_create_user() { - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let body = massage(GOOD_FORM); @@ -243,7 +243,7 @@ mod test { #[tokio::test] async fn get_create_user() { - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let resp = server.get("/signup").await; @@ -254,7 +254,7 @@ mod test { #[tokio::test] async fn handle_signup_success() { - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let user = get_test_user(); @@ -290,7 +290,7 @@ mod test { #[tokio::test] async fn password_mismatch() { - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let body = massage(PASSWORD_MISMATCH_FORM); @@ -313,7 +313,7 @@ mod test { #[tokio::test] async fn password_short() { - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let body = massage(PASSWORD_SHORT_FORM); @@ -336,7 +336,7 @@ mod test { #[tokio::test] async fn password_long() { - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let body = massage(PASSWORD_LONG_FORM); @@ -363,7 +363,7 @@ mod test { // min length is 4 assert_eq!(pw.len(), 4); - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let form = format!("username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}"); @@ -388,7 +388,7 @@ mod test { #[tokio::test] async fn username_short() { - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let body = massage(USERNAME_SHORT_FORM); @@ -411,7 +411,7 @@ mod test { #[tokio::test] async fn username_long() { - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let body = massage(USERNAME_LONG_FORM); @@ -434,7 +434,7 @@ mod test { #[tokio::test] async fn username_duplicate() { - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let body = massage(GOOD_FORM); @@ -465,7 +465,7 @@ mod test { #[tokio::test] async fn displayname_long() { - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let body = massage(DISPLAYNAME_LONG_FORM); @@ -488,7 +488,7 @@ mod test { #[tokio::test] async fn handle_signup_success() { - let pool = get_db_pool().await; + let pool = get_db_pool(); let server = server_with_pool(&pool).await; let path = format!("/signup_success/nope"); diff --git a/src/test_utils.rs b/src/test_utils.rs index 9e022b0..b503301 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -17,33 +17,40 @@ pub fn get_test_user() -> User { } } -pub async fn server() -> TestServer { - let pool = crate::db::get_db_pool().await; +pub fn server() -> TestServer { + let pool = crate::db::get_db_pool(); let secret = [0u8; 64]; - let user = get_test_user(); - sqlx::query(crate::signup::CREATE_QUERY) - .bind(user.id) - .bind(&user.username) - .bind(&user.displayname) - .bind(&user.email) - .bind(&user.pwhash) - .execute(&pool) - .await + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() .unwrap(); - let r = sqlx::query("select count(*) from users") - .fetch_one(&pool) - .await; - assert!(r.is_ok()); + let user = get_test_user(); + rt.block_on(async { + sqlx::query(crate::signup::CREATE_QUERY) + .bind(user.id) + .bind(&user.username) + .bind(&user.displayname) + .bind(&user.email) + .bind(&user.pwhash) + .execute(&pool) + .await + .unwrap(); - let app = crate::app(pool, &secret).await.into_make_service(); + let r = sqlx::query("select count(*) from users") + .fetch_one(&pool) + .await; + assert!(r.is_ok()); - let config = TestServerConfig { - save_cookies: true, - ..Default::default() - }; - TestServer::new_with_config(app, config).unwrap() + let app = crate::app(pool, &secret).await.into_make_service(); + + let config = TestServerConfig { + save_cookies: true, + ..Default::default() + }; + TestServer::new_with_config(app, config).unwrap() + }) } pub async fn server_with_pool(pool: &SqlitePool) -> TestServer {