make db acquire sync; tests fail but app runs correctly

This commit is contained in:
Joe Ardent 2023-07-18 17:37:24 -07:00
parent c685dc1a6b
commit 48308aa169
9 changed files with 437 additions and 275 deletions

View file

@ -10,32 +10,47 @@ struct Cli {
pub db_path: OsString, pub db_path: OsString,
} }
#[tokio::main] fn main() {
async fn main() {
let cli = Cli::parse(); let cli = Cli::parse();
let path = cli.db_path; let path = cli.db_path;
let opts = SqliteConnectOptions::new().filename(&path).read_only(true); let opts = SqliteConnectOptions::new().filename(path).read_only(true);
let movie_db = SqlitePoolOptions::new() let movie_db = {
.idle_timeout(Duration::from_secs(90)) let rt = tokio::runtime::Builder::new_multi_thread()
.connect_with(opts) .enable_all()
.await .build()
.expect("could not open movies db"); .unwrap();
let w2w_db = get_db_pool().await; rt.block_on(
SqlitePoolOptions::new()
.idle_timeout(Duration::from_secs(90))
.connect_with(opts),
)
.expect("could not open movies db")
};
let w2w_db = get_db_pool();
let start = std::time::Instant::now(); let start = std::time::Instant::now();
add_omega_watches(&w2w_db, &movie_db).await.unwrap(); let rows = {
let end = std::time::Instant::now(); let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let dur = (end - start).as_secs_f32(); rt.block_on(async {
add_omega_watches(&w2w_db, &movie_db).await.unwrap();
let rows: i32 = sqlx::query_scalar("select count(*) from watches") let rows: i32 = sqlx::query_scalar("select count(*) from watches")
.fetch_one(&w2w_db) .fetch_one(&w2w_db)
.await .await
.unwrap(); .unwrap();
println!("Added {rows} movies in {dur} seconds");
w2w_db.close().await; w2w_db.close().await;
rows
})
};
let end = std::time::Instant::now();
let dur = (end - start).as_secs_f32();
println!("Added {rows} movies in {dur} seconds");
} }

View file

@ -7,7 +7,7 @@ use sqlx::{
sqlite::{SqliteConnectOptions, SqlitePoolOptions}, sqlite::{SqliteConnectOptions, SqlitePoolOptions},
SqlitePool, SqlitePool,
}; };
use tokio::task::JoinSet; use tokio::{runtime, task::JoinSet};
use tokio_retry::Retry; use tokio_retry::Retry;
use what2watch::{ use what2watch::{
get_db_pool, get_db_pool,
@ -15,8 +15,7 @@ use what2watch::{
DbId, User, WatchQuest, DbId, User, WatchQuest,
}; };
#[tokio::main] fn main() {
async fn main() {
let cli = Cli::parse(); let cli = Cli::parse();
let path = cli.db_path; let path = cli.db_path;
let num_users = cli.users; let num_users = cli.users;
@ -31,29 +30,39 @@ async fn main() {
let words: Vec<&str> = words.split('\n').collect(); let words: Vec<&str> = words.split('\n').collect();
let opts = SqliteConnectOptions::new().filename(&path).read_only(true); let opts = SqliteConnectOptions::new().filename(&path).read_only(true);
let movie_db = SqlitePoolOptions::new()
.idle_timeout(Duration::from_secs(3))
.connect_with(opts)
.await
.expect("could not open movies db");
let w2w_db = get_db_pool().await;
let users = &gen_users(num_users, &words, &w2w_db).await; let rt = runtime::Builder::new_multi_thread()
let movies = &add_omega_watches(&w2w_db, &movie_db).await.unwrap(); .enable_all()
.build()
.unwrap();
let movie_db = rt
.block_on(
SqlitePoolOptions::new()
.idle_timeout(Duration::from_secs(3))
.connect_with(opts),
)
.expect("could not open movies db");
let w2w_db = get_db_pool();
let users = &rt.block_on(gen_users(num_users, &words, &w2w_db));
let movies = &rt.block_on(add_omega_watches(&w2w_db, &movie_db)).unwrap();
let rng = &mut thread_rng(); let rng = &mut thread_rng();
let normal = Normal::new(mpu, mpu / 10.0).unwrap(); let normal = Normal::new(mpu, mpu / 10.0).unwrap();
let start = std::time::Instant::now(); let start = std::time::Instant::now();
rt.block_on(async {
for &user in users { for &user in users {
add_quests(user, movies, &w2w_db, rng, normal).await; add_quests(user, movies, &w2w_db, rng, normal).await;
} }
let end = std::time::Instant::now(); });
let rows: i32 = sqlx::query_scalar("select count(*) from watch_quests") let rows: i32 = rt
.fetch_one(&w2w_db) .block_on(sqlx::query_scalar("select count(*) from watch_quests").fetch_one(&w2w_db))
.await
.unwrap(); .unwrap();
w2w_db.close().await; rt.block_on(w2w_db.close());
let end = std::time::Instant::now();
let dur = (end - start).as_secs_f32(); let dur = (end - start).as_secs_f32();
println!("Added {rows} quests in {dur} seconds"); println!("Added {rows} quests in {dur} seconds");
} }

135
src/db.rs
View file

@ -19,7 +19,7 @@ const MIN_CONNS: u32 = 5;
const TIMEOUT: u64 = 11; const TIMEOUT: u64 = 11;
const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64); const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64);
pub async fn get_db_pool() -> SqlitePool { pub fn get_db_pool() -> SqlitePool {
let db_filename = { let db_filename = {
std::env::var("DATABASE_FILE").unwrap_or_else(|_| { std::env::var("DATABASE_FILE").unwrap_or_else(|_| {
#[cfg(not(test))] #[cfg(not(test))]
@ -57,25 +57,34 @@ pub async fn get_db_pool() -> SqlitePool {
.min_connections(MIN_CONNS) .min_connections(MIN_CONNS)
.idle_timeout(Some(Duration::from_secs(30))) .idle_timeout(Some(Duration::from_secs(30)))
.max_lifetime(Some(Duration::from_secs(3600))) .max_lifetime(Some(Duration::from_secs(3600)))
.connect_with(conn_opts) .connect_with(conn_opts);
.await
.expect("can't connect to database");
// let the filesystem settle before trying anything let rt = tokio::runtime::Builder::new_multi_thread()
// possibly not effective? .enable_all()
tokio::time::sleep(Duration::from_millis(500)).await; .build()
.unwrap();
let pool = {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(pool).unwrap()
};
let _con = rt.block_on(pool.acquire()).unwrap();
{ {
let mut m = Migrator::new(std::path::Path::new("./migrations")) let rt = tokio::runtime::Builder::new_multi_thread()
.await .enable_all()
.expect("Should be able to read the migration directory."); .build()
.unwrap();
let m = m.set_locking(true);
m.run(&pool)
.await
.expect("Should be able to run the migration.");
let m = Migrator::new(std::path::Path::new("./migrations"));
let mut m = rt.block_on(m).unwrap();
let m = m.set_locking(false);
rt.block_on(m.run(&pool)).unwrap();
tracing::info!("Ran migrations"); tracing::info!("Ran migrations");
} }
@ -120,7 +129,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn it_migrates_the_db() { async fn it_migrates_the_db() {
let db = super::get_db_pool().await; let db = super::get_db_pool();
let r = sqlx::query("select count(*) from users") let r = sqlx::query("select count(*) from users")
.fetch_one(&db) .fetch_one(&db)
.await; .await;
@ -373,7 +382,13 @@ mod session_store {
use super::*; use super::*;
async fn test_store() -> SqliteSessionStore { fn test_store() -> SqliteSessionStore {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let store = rt.block_on(async {
let store = SqliteSessionStore::new("sqlite::memory:") let store = SqliteSessionStore::new("sqlite::memory:")
.await .await
.expect("building a sqlite :memory: SqliteSessionStore"); .expect("building a sqlite :memory: SqliteSessionStore");
@ -382,14 +397,25 @@ mod session_store {
.await .await
.expect("migrating a brand new :memory: SqliteSessionStore"); .expect("migrating a brand new :memory: SqliteSessionStore");
store store
});
dbg!("got the store");
store
} }
#[tokio::test] #[test]
async fn creating_a_new_session_with_no_expiry() -> Result { fn creating_a_new_session_with_no_expiry() -> Result {
let store = test_store().await; let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let store = test_store();
let mut session = Session::new(); let mut session = Session::new();
dbg!("new session");
session.insert("key", "value")?; session.insert("key", "value")?;
let cloned = session.clone(); let cloned = session.clone();
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap(); let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) = let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
@ -411,15 +437,22 @@ mod session_store {
assert!(!loaded_session.is_expired()); assert!(!loaded_session.is_expired());
Ok(()) Ok(())
})
} }
#[tokio::test] #[test]
async fn updating_a_session() -> Result { fn updating_a_session() -> Result {
let store = test_store().await; let store = test_store();
let mut session = Session::new(); let mut session = Session::new();
let original_id = session.id().to_owned(); let original_id = session.id().to_owned();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
session.insert("key", "value")?; session.insert("key", "value")?;
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap(); let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
@ -438,15 +471,23 @@ mod session_store {
assert_eq!(original_id, id); assert_eq!(original_id, id);
Ok(()) Ok(())
})
} }
#[tokio::test] #[test]
async fn updating_a_session_extending_expiry() -> Result { fn updating_a_session_extending_expiry() -> Result {
let store = test_store().await; let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let store = test_store();
let mut session = Session::new(); let mut session = Session::new();
session.expire_in(Duration::from_secs(10)); session.expire_in(Duration::from_secs(10));
let original_id = session.id().to_owned(); let original_id = session.id().to_owned();
let original_expires = session.expiry().unwrap().clone(); let original_expires = session.expiry().unwrap().clone();
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap(); let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
@ -468,16 +509,23 @@ mod session_store {
assert_eq!(original_id, id); assert_eq!(original_id, id);
Ok(()) Ok(())
})
} }
#[tokio::test] #[test]
async fn creating_a_new_session_with_expiry() -> Result { fn creating_a_new_session_with_expiry() -> Result {
let store = test_store().await; let store = test_store();
let mut session = Session::new(); let mut session = Session::new();
session.expire_in(Duration::from_secs(1)); session.expire_in(Duration::from_secs(1));
session.insert("key", "value")?; session.insert("key", "value")?;
let cloned = session.clone(); let cloned = session.clone();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap(); let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) = let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
@ -503,11 +551,18 @@ mod session_store {
assert_eq!(None, store.load_session(cookie_value).await?); assert_eq!(None, store.load_session(cookie_value).await?);
Ok(()) Ok(())
})
} }
#[tokio::test] #[test]
async fn destroying_a_single_session() -> Result { fn destroying_a_single_session() -> Result {
let store = test_store().await; let store = test_store();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
for _ in 0..3i8 { for _ in 0..3i8 {
store.store_session(Session::new()).await?; store.store_session(Session::new()).await?;
} }
@ -522,11 +577,18 @@ mod session_store {
// // attempting to destroy the session again is not an error // // attempting to destroy the session again is not an error
// assert!(store.destroy_session(session).await.is_ok()); // assert!(store.destroy_session(session).await.is_ok());
Ok(()) Ok(())
})
} }
#[tokio::test] #[test]
async fn clearing_the_whole_store() -> Result { fn clearing_the_whole_store() -> Result {
let store = test_store().await; let store = test_store();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
for _ in 0..3i8 { for _ in 0..3i8 {
store.store_session(Session::new()).await?; store.store_session(Session::new()).await?;
} }
@ -536,6 +598,7 @@ mod session_store {
assert_eq!(0, store.count().await?); assert_eq!(0, store.count().await?);
Ok(()) Ok(())
})
} }
} }
} }

View file

@ -20,33 +20,46 @@ pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::future::IntoFuture;
use axum_test::TestServer; use axum_test::TestServer;
use tokio::runtime::Runtime;
use crate::db; use crate::db;
#[tokio::test] #[test]
async fn slash_is_ok() { fn slash_is_ok() {
let pool = db::get_db_pool().await; let pool = db::get_db_pool();
let secret = [0u8; 64]; let secret = [0u8; 64];
let app = crate::app(pool.clone(), &secret).await.into_make_service(); let rt = Runtime::new().unwrap();
let app = rt
.block_on(crate::app(pool.clone(), &secret))
.into_make_service();
let server = TestServer::new(app).unwrap(); let server = TestServer::new(app).unwrap();
server.get("/").await.assert_status_ok(); rt.block_on(server.get("/").into_future())
.assert_status_ok();
} }
#[tokio::test] #[test]
async fn not_found_is_303() { fn not_found_is_303() {
let pool = db::get_db_pool().await; let pool = db::get_db_pool();
let secret = [0u8; 64]; let secret = [0u8; 64];
let app = crate::app(pool, &secret).await.into_make_service();
let rt = Runtime::new().unwrap();
let app = rt
.block_on(crate::app(pool.clone(), &secret))
.into_make_service();
let server = TestServer::new(app).unwrap(); let server = TestServer::new(app).unwrap();
assert_eq!( assert_eq!(
rt.block_on(
server server
.get("/no-actual-route") .get("/no-actual-route")
.expect_failure() .expect_failure()
.await .into_future()
)
.status_code(), .status_code(),
303 303
); );

View file

@ -158,7 +158,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn ensure_omega_user() { async fn ensure_omega_user() {
let p = crate::db::get_db_pool().await; let p = crate::db::get_db_pool();
assert!(!check_omega_exists(&p).await); assert!(!check_omega_exists(&p).await);
ensure_omega(&p).await; ensure_omega(&p).await;
assert!(check_omega_exists(&p).await); assert!(check_omega_exists(&p).await);

View file

@ -105,6 +105,10 @@ pub async fn post_logout(mut auth: AuthContext) -> impl IntoResponse {
LogoutSuccessPage LogoutSuccessPage
} }
//-************************************************************************
// tests
//-************************************************************************
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::{ use crate::{
@ -114,20 +118,32 @@ mod test {
const LOGIN_FORM: &str = "username=test_user&password=a"; const LOGIN_FORM: &str = "username=test_user&password=a";
#[tokio::test] #[test]
async fn get_login() { fn get_login() {
let s = server().await; let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let s = server();
rt.block_on(async {
let resp = s.get("/login").await; let resp = s.get("/login").await;
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); let body = std::str::from_utf8(resp.bytes()).unwrap().to_string();
assert_eq!(body, LoginPage::default().to_string()); assert_eq!(body, LoginPage::default().to_string());
})
} }
#[tokio::test] #[test]
async fn post_login_success() { fn post_login_success() {
let s = server().await; let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let body = massage(LOGIN_FORM); let body = massage(LOGIN_FORM);
rt.block_on(async {
let resp = s let resp = s
.post("/login") .post("/login")
.expect_failure() .expect_failure()
@ -135,15 +151,21 @@ mod test {
.bytes(body) .bytes(body)
.await; .await;
assert_eq!(resp.status_code(), 303); assert_eq!(resp.status_code(), 303);
})
} }
#[tokio::test] #[test]
async fn post_login_bad_user() { fn post_login_bad_user() {
let s = server().await; let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let form = "username=test_LOSER&password=aaaa"; let form = "username=test_LOSER&password=aaaa";
let body = massage(form); let body = massage(form);
rt.block_on(async {
let resp = s let resp = s
.post("/login") .post("/login")
.expect_success() .expect_success()
@ -151,15 +173,21 @@ mod test {
.bytes(body) .bytes(body)
.await; .await;
assert_eq!(resp.status_code(), 200); assert_eq!(resp.status_code(), 200);
})
} }
#[tokio::test] #[test]
async fn post_login_bad_password() { fn post_login_bad_password() {
let s = server().await; let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let form = "username=test_user&password=bbbb"; let form = "username=test_user&password=bbbb";
let body = massage(form); let body = massage(form);
rt.block_on(async {
let resp = s let resp = s
.post("/login") .post("/login")
.expect_success() .expect_success()
@ -167,32 +195,51 @@ mod test {
.bytes(body) .bytes(body)
.await; .await;
assert_eq!(resp.status_code(), 200); assert_eq!(resp.status_code(), 200);
})
} }
#[tokio::test] #[test]
async fn get_logout() { fn get_logout() {
let s = server().await; let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
let resp = s.get("/logout").await; let resp = s.get("/logout").await;
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); let body = std::str::from_utf8(resp.bytes()).unwrap().to_string();
assert_eq!(body, LogoutPage.to_string()); assert_eq!(body, LogoutPage.to_string());
})
} }
#[tokio::test] #[test]
async fn post_logout_not_logged_in() { fn post_logout_not_logged_in() {
let s = server().await; let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
let resp = s.post("/logout").await; let resp = s.post("/logout").await;
resp.assert_status_ok(); resp.assert_status_ok();
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let default = LogoutSuccessPage.to_string(); let default = LogoutSuccessPage.to_string();
assert_eq!(body, &default); assert_eq!(body, &default);
})
} }
#[tokio::test] #[test]
async fn post_logout_logged_in() { fn post_logout_logged_in() {
let s = server().await; let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
// log in and prove it // log in and prove it
{ rt.block_on(async {
let body = massage(LOGIN_FORM); let body = massage(LOGIN_FORM);
let resp = s let resp = s
.post("/login") .post("/login")
@ -210,11 +257,11 @@ mod test {
let main_page = s.get("/").await; let main_page = s.get("/").await;
let body = std::str::from_utf8(main_page.bytes()).unwrap(); let body = std::str::from_utf8(main_page.bytes()).unwrap();
assert_eq!(&logged_in, body); assert_eq!(&logged_in, body);
}
let resp = s.post("/logout").await; let resp = s.post("/logout").await;
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let default = LogoutSuccessPage.to_string(); let default = LogoutSuccessPage.to_string();
assert_eq!(body, &default); assert_eq!(body, &default);
})
} }
} }

View file

@ -1,12 +1,12 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use rand::{thread_rng, RngCore}; use rand::{thread_rng, RngCore};
use sqlx::SqlitePool;
use tokio::signal; use tokio::signal;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use what2watch::get_db_pool; use what2watch::get_db_pool;
#[tokio::main] fn main() {
async fn main() {
tracing_subscriber::registry() tracing_subscriber::registry()
.with( .with(
tracing_subscriber::EnvFilter::try_from_default_env() tracing_subscriber::EnvFilter::try_from_default_env()
@ -15,8 +15,18 @@ async fn main() {
.with(tracing_subscriber::fmt::layer()) .with(tracing_subscriber::fmt::layer())
.init(); .init();
let pool = get_db_pool().await; let pool = get_db_pool();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(runner(&pool));
rt.block_on(pool.close());
}
async fn runner(pool: &SqlitePool) {
let secret = { let secret = {
let mut bytes = [0u8; 64]; let mut bytes = [0u8; 64];
let mut rng = thread_rng(); let mut rng = thread_rng();
@ -34,8 +44,6 @@ async fn main() {
.with_graceful_shutdown(shutdown_signal()) .with_graceful_shutdown(shutdown_signal())
.await .await
.unwrap_or_default(); .unwrap_or_default();
pool.close().await;
} }
async fn shutdown_signal() { async fn shutdown_signal() {

View file

@ -223,7 +223,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn post_create_user() { async fn post_create_user() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(GOOD_FORM); let body = massage(GOOD_FORM);
@ -243,7 +243,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn get_create_user() { async fn get_create_user() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let resp = server.get("/signup").await; let resp = server.get("/signup").await;
@ -254,7 +254,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn handle_signup_success() { async fn handle_signup_success() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let user = get_test_user(); let user = get_test_user();
@ -290,7 +290,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn password_mismatch() { async fn password_mismatch() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_MISMATCH_FORM); let body = massage(PASSWORD_MISMATCH_FORM);
@ -313,7 +313,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn password_short() { async fn password_short() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_SHORT_FORM); let body = massage(PASSWORD_SHORT_FORM);
@ -336,7 +336,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn password_long() { async fn password_long() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_LONG_FORM); let body = massage(PASSWORD_LONG_FORM);
@ -363,7 +363,7 @@ mod test {
// min length is 4 // min length is 4
assert_eq!(pw.len(), 4); assert_eq!(pw.len(), 4);
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let form = let form =
format!("username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}"); format!("username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}");
@ -388,7 +388,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn username_short() { async fn username_short() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(USERNAME_SHORT_FORM); let body = massage(USERNAME_SHORT_FORM);
@ -411,7 +411,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn username_long() { async fn username_long() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(USERNAME_LONG_FORM); let body = massage(USERNAME_LONG_FORM);
@ -434,7 +434,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn username_duplicate() { async fn username_duplicate() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(GOOD_FORM); let body = massage(GOOD_FORM);
@ -465,7 +465,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn displayname_long() { async fn displayname_long() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(DISPLAYNAME_LONG_FORM); let body = massage(DISPLAYNAME_LONG_FORM);
@ -488,7 +488,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn handle_signup_success() { async fn handle_signup_success() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let path = format!("/signup_success/nope"); let path = format!("/signup_success/nope");

View file

@ -17,11 +17,17 @@ pub fn get_test_user() -> User {
} }
} }
pub async fn server() -> TestServer { pub fn server() -> TestServer {
let pool = crate::db::get_db_pool().await; let pool = crate::db::get_db_pool();
let secret = [0u8; 64]; let secret = [0u8; 64];
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let user = get_test_user(); let user = get_test_user();
rt.block_on(async {
sqlx::query(crate::signup::CREATE_QUERY) sqlx::query(crate::signup::CREATE_QUERY)
.bind(user.id) .bind(user.id)
.bind(&user.username) .bind(&user.username)
@ -44,6 +50,7 @@ pub async fn server() -> TestServer {
..Default::default() ..Default::default()
}; };
TestServer::new_with_config(app, config).unwrap() TestServer::new_with_config(app, config).unwrap()
})
} }
pub async fn server_with_pool(pool: &SqlitePool) -> TestServer { pub async fn server_with_pool(pool: &SqlitePool) -> TestServer {