make db acquire sync; tests fail but app runs correctly

This commit is contained in:
Joe Ardent 2023-07-18 17:37:24 -07:00
parent c685dc1a6b
commit 48308aa169
9 changed files with 437 additions and 275 deletions

View file

@ -10,32 +10,47 @@ struct Cli {
pub db_path: OsString, pub db_path: OsString,
} }
#[tokio::main] fn main() {
async fn main() {
let cli = Cli::parse(); let cli = Cli::parse();
let path = cli.db_path; let path = cli.db_path;
let opts = SqliteConnectOptions::new().filename(&path).read_only(true); let opts = SqliteConnectOptions::new().filename(path).read_only(true);
let movie_db = SqlitePoolOptions::new() let movie_db = {
.idle_timeout(Duration::from_secs(90)) let rt = tokio::runtime::Builder::new_multi_thread()
.connect_with(opts) .enable_all()
.await .build()
.expect("could not open movies db"); .unwrap();
let w2w_db = get_db_pool().await; rt.block_on(
SqlitePoolOptions::new()
.idle_timeout(Duration::from_secs(90))
.connect_with(opts),
)
.expect("could not open movies db")
};
let w2w_db = get_db_pool();
let start = std::time::Instant::now(); let start = std::time::Instant::now();
add_omega_watches(&w2w_db, &movie_db).await.unwrap(); let rows = {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
add_omega_watches(&w2w_db, &movie_db).await.unwrap();
let rows: i32 = sqlx::query_scalar("select count(*) from watches")
.fetch_one(&w2w_db)
.await
.unwrap();
w2w_db.close().await;
rows
})
};
let end = std::time::Instant::now(); let end = std::time::Instant::now();
let dur = (end - start).as_secs_f32(); let dur = (end - start).as_secs_f32();
let rows: i32 = sqlx::query_scalar("select count(*) from watches")
.fetch_one(&w2w_db)
.await
.unwrap();
println!("Added {rows} movies in {dur} seconds"); println!("Added {rows} movies in {dur} seconds");
w2w_db.close().await;
} }

View file

@ -7,7 +7,7 @@ use sqlx::{
sqlite::{SqliteConnectOptions, SqlitePoolOptions}, sqlite::{SqliteConnectOptions, SqlitePoolOptions},
SqlitePool, SqlitePool,
}; };
use tokio::task::JoinSet; use tokio::{runtime, task::JoinSet};
use tokio_retry::Retry; use tokio_retry::Retry;
use what2watch::{ use what2watch::{
get_db_pool, get_db_pool,
@ -15,8 +15,7 @@ use what2watch::{
DbId, User, WatchQuest, DbId, User, WatchQuest,
}; };
#[tokio::main] fn main() {
async fn main() {
let cli = Cli::parse(); let cli = Cli::parse();
let path = cli.db_path; let path = cli.db_path;
let num_users = cli.users; let num_users = cli.users;
@ -31,29 +30,39 @@ async fn main() {
let words: Vec<&str> = words.split('\n').collect(); let words: Vec<&str> = words.split('\n').collect();
let opts = SqliteConnectOptions::new().filename(&path).read_only(true); let opts = SqliteConnectOptions::new().filename(&path).read_only(true);
let movie_db = SqlitePoolOptions::new()
.idle_timeout(Duration::from_secs(3))
.connect_with(opts)
.await
.expect("could not open movies db");
let w2w_db = get_db_pool().await;
let users = &gen_users(num_users, &words, &w2w_db).await; let rt = runtime::Builder::new_multi_thread()
let movies = &add_omega_watches(&w2w_db, &movie_db).await.unwrap(); .enable_all()
.build()
.unwrap();
let movie_db = rt
.block_on(
SqlitePoolOptions::new()
.idle_timeout(Duration::from_secs(3))
.connect_with(opts),
)
.expect("could not open movies db");
let w2w_db = get_db_pool();
let users = &rt.block_on(gen_users(num_users, &words, &w2w_db));
let movies = &rt.block_on(add_omega_watches(&w2w_db, &movie_db)).unwrap();
let rng = &mut thread_rng(); let rng = &mut thread_rng();
let normal = Normal::new(mpu, mpu / 10.0).unwrap(); let normal = Normal::new(mpu, mpu / 10.0).unwrap();
let start = std::time::Instant::now(); let start = std::time::Instant::now();
for &user in users { rt.block_on(async {
add_quests(user, movies, &w2w_db, rng, normal).await; for &user in users {
} add_quests(user, movies, &w2w_db, rng, normal).await;
let end = std::time::Instant::now(); }
let rows: i32 = sqlx::query_scalar("select count(*) from watch_quests") });
.fetch_one(&w2w_db) let rows: i32 = rt
.await .block_on(sqlx::query_scalar("select count(*) from watch_quests").fetch_one(&w2w_db))
.unwrap(); .unwrap();
w2w_db.close().await; rt.block_on(w2w_db.close());
let end = std::time::Instant::now();
let dur = (end - start).as_secs_f32(); let dur = (end - start).as_secs_f32();
println!("Added {rows} quests in {dur} seconds"); println!("Added {rows} quests in {dur} seconds");
} }

313
src/db.rs
View file

@ -19,7 +19,7 @@ const MIN_CONNS: u32 = 5;
const TIMEOUT: u64 = 11; const TIMEOUT: u64 = 11;
const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64); const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64);
pub async fn get_db_pool() -> SqlitePool { pub fn get_db_pool() -> SqlitePool {
let db_filename = { let db_filename = {
std::env::var("DATABASE_FILE").unwrap_or_else(|_| { std::env::var("DATABASE_FILE").unwrap_or_else(|_| {
#[cfg(not(test))] #[cfg(not(test))]
@ -57,25 +57,34 @@ pub async fn get_db_pool() -> SqlitePool {
.min_connections(MIN_CONNS) .min_connections(MIN_CONNS)
.idle_timeout(Some(Duration::from_secs(30))) .idle_timeout(Some(Duration::from_secs(30)))
.max_lifetime(Some(Duration::from_secs(3600))) .max_lifetime(Some(Duration::from_secs(3600)))
.connect_with(conn_opts) .connect_with(conn_opts);
.await
.expect("can't connect to database");
// let the filesystem settle before trying anything let rt = tokio::runtime::Builder::new_multi_thread()
// possibly not effective? .enable_all()
tokio::time::sleep(Duration::from_millis(500)).await; .build()
.unwrap();
let pool = {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(pool).unwrap()
};
let _con = rt.block_on(pool.acquire()).unwrap();
{ {
let mut m = Migrator::new(std::path::Path::new("./migrations")) let rt = tokio::runtime::Builder::new_multi_thread()
.await .enable_all()
.expect("Should be able to read the migration directory."); .build()
.unwrap();
let m = m.set_locking(true);
m.run(&pool)
.await
.expect("Should be able to run the migration.");
let m = Migrator::new(std::path::Path::new("./migrations"));
let mut m = rt.block_on(m).unwrap();
let m = m.set_locking(false);
rt.block_on(m.run(&pool)).unwrap();
tracing::info!("Ran migrations"); tracing::info!("Ran migrations");
} }
@ -120,7 +129,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn it_migrates_the_db() { async fn it_migrates_the_db() {
let db = super::get_db_pool().await; let db = super::get_db_pool();
let r = sqlx::query("select count(*) from users") let r = sqlx::query("select count(*) from users")
.fetch_one(&db) .fetch_one(&db)
.await; .await;
@ -373,169 +382,223 @@ mod session_store {
use super::*; use super::*;
async fn test_store() -> SqliteSessionStore { fn test_store() -> SqliteSessionStore {
let store = SqliteSessionStore::new("sqlite::memory:") let rt = tokio::runtime::Builder::new_multi_thread()
.await .enable_all()
.expect("building a sqlite :memory: SqliteSessionStore"); .build()
store .unwrap();
.migrate()
.await let store = rt.block_on(async {
.expect("migrating a brand new :memory: SqliteSessionStore"); let store = SqliteSessionStore::new("sqlite::memory:")
.await
.expect("building a sqlite :memory: SqliteSessionStore");
store
.migrate()
.await
.expect("migrating a brand new :memory: SqliteSessionStore");
store
});
dbg!("got the store");
store store
} }
#[tokio::test] #[test]
async fn creating_a_new_session_with_no_expiry() -> Result { fn creating_a_new_session_with_no_expiry() -> Result {
let store = test_store().await; let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let store = test_store();
let mut session = Session::new(); let mut session = Session::new();
dbg!("new session");
session.insert("key", "value")?; session.insert("key", "value")?;
let cloned = session.clone(); let cloned = session.clone();
let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) = rt.block_on(async {
sqlx::query_as("select id, expires, session, count(*) from async_sessions") let cookie_value = store.store_session(session).await?.unwrap();
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count); let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
assert_eq!(id, cloned.id()); sqlx::query_as("select id, expires, session, count(*) from async_sessions")
assert_eq!(expires, None); .fetch_one(&mut store.connection().await?)
.await?;
let deserialized_session: Session = serde_json::from_str(&serialized)?; assert_eq!(1, count);
assert_eq!(cloned.id(), deserialized_session.id()); assert_eq!(id, cloned.id());
assert_eq!("value", &deserialized_session.get::<String>("key").unwrap()); assert_eq!(expires, None);
let loaded_session = store.load_session(cookie_value).await?.unwrap(); let deserialized_session: Session = serde_json::from_str(&serialized)?;
assert_eq!(cloned.id(), loaded_session.id()); assert_eq!(cloned.id(), deserialized_session.id());
assert_eq!("value", &loaded_session.get::<String>("key").unwrap()); assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
assert!(!loaded_session.is_expired()); let loaded_session = store.load_session(cookie_value).await?.unwrap();
Ok(()) assert_eq!(cloned.id(), loaded_session.id());
assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
assert!(!loaded_session.is_expired());
Ok(())
})
} }
#[tokio::test] #[test]
async fn updating_a_session() -> Result { fn updating_a_session() -> Result {
let store = test_store().await; let store = test_store();
let mut session = Session::new(); let mut session = Session::new();
let original_id = session.id().to_owned(); let original_id = session.id().to_owned();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
session.insert("key", "value")?; session.insert("key", "value")?;
let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); rt.block_on(async {
session.insert("key", "other value")?; let cookie_value = store.store_session(session).await?.unwrap();
assert_eq!(None, store.store_session(session).await?);
let session = store.load_session(cookie_value.clone()).await?.unwrap(); let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(session.get::<String>("key").unwrap(), "other value"); session.insert("key", "other value")?;
assert_eq!(None, store.store_session(session).await?);
let (id, count): (String, i64) = let session = store.load_session(cookie_value.clone()).await?.unwrap();
sqlx::query_as("select id, count(*) from async_sessions") assert_eq!(session.get::<String>("key").unwrap(), "other value");
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count); let (id, count): (String, i64) =
assert_eq!(original_id, id); sqlx::query_as("select id, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
Ok(()) assert_eq!(1, count);
assert_eq!(original_id, id);
Ok(())
})
} }
#[tokio::test] #[test]
async fn updating_a_session_extending_expiry() -> Result { fn updating_a_session_extending_expiry() -> Result {
let store = test_store().await; let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let store = test_store();
let mut session = Session::new(); let mut session = Session::new();
session.expire_in(Duration::from_secs(10)); session.expire_in(Duration::from_secs(10));
let original_id = session.id().to_owned(); let original_id = session.id().to_owned();
let original_expires = session.expiry().unwrap().clone(); let original_expires = session.expiry().unwrap().clone();
let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); rt.block_on(async {
assert_eq!(session.expiry().unwrap(), &original_expires); let cookie_value = store.store_session(session).await?.unwrap();
session.expire_in(Duration::from_secs(20));
let new_expires = session.expiry().unwrap().clone();
store.store_session(session).await?;
let session = store.load_session(cookie_value.clone()).await?.unwrap(); let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(session.expiry().unwrap(), &new_expires); assert_eq!(session.expiry().unwrap(), &original_expires);
session.expire_in(Duration::from_secs(20));
let new_expires = session.expiry().unwrap().clone();
store.store_session(session).await?;
let (id, expires, count): (String, i64, i64) = let session = store.load_session(cookie_value.clone()).await?.unwrap();
sqlx::query_as("select id, expires, count(*) from async_sessions") assert_eq!(session.expiry().unwrap(), &new_expires);
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count); let (id, expires, count): (String, i64, i64) =
assert_eq!(expires, new_expires.timestamp()); sqlx::query_as("select id, expires, count(*) from async_sessions")
assert_eq!(original_id, id); .fetch_one(&mut store.connection().await?)
.await?;
Ok(()) assert_eq!(1, count);
assert_eq!(expires, new_expires.timestamp());
assert_eq!(original_id, id);
Ok(())
})
} }
#[tokio::test] #[test]
async fn creating_a_new_session_with_expiry() -> Result { fn creating_a_new_session_with_expiry() -> Result {
let store = test_store().await; let store = test_store();
let mut session = Session::new(); let mut session = Session::new();
session.expire_in(Duration::from_secs(1)); session.expire_in(Duration::from_secs(1));
session.insert("key", "value")?; session.insert("key", "value")?;
let cloned = session.clone(); let cloned = session.clone();
let cookie_value = store.store_session(session).await?.unwrap(); let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) = rt.block_on(async {
sqlx::query_as("select id, expires, session, count(*) from async_sessions") let cookie_value = store.store_session(session).await?.unwrap();
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count); let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
assert_eq!(id, cloned.id()); sqlx::query_as("select id, expires, session, count(*) from async_sessions")
assert!(expires.unwrap() > Utc::now().timestamp()); .fetch_one(&mut store.connection().await?)
.await?;
let deserialized_session: Session = serde_json::from_str(&serialized)?; assert_eq!(1, count);
assert_eq!(cloned.id(), deserialized_session.id()); assert_eq!(id, cloned.id());
assert_eq!("value", &deserialized_session.get::<String>("key").unwrap()); assert!(expires.unwrap() > Utc::now().timestamp());
let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap(); let deserialized_session: Session = serde_json::from_str(&serialized)?;
assert_eq!(cloned.id(), loaded_session.id()); assert_eq!(cloned.id(), deserialized_session.id());
assert_eq!("value", &loaded_session.get::<String>("key").unwrap()); assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
assert!(!loaded_session.is_expired()); let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(cloned.id(), loaded_session.id());
assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
tokio::time::sleep(Duration::from_secs(1)).await; assert!(!loaded_session.is_expired());
assert_eq!(None, store.load_session(cookie_value).await?);
Ok(()) tokio::time::sleep(Duration::from_secs(1)).await;
assert_eq!(None, store.load_session(cookie_value).await?);
Ok(())
})
} }
#[tokio::test] #[test]
async fn destroying_a_single_session() -> Result { fn destroying_a_single_session() -> Result {
let store = test_store().await; let store = test_store();
for _ in 0..3i8 { let rt = tokio::runtime::Builder::new_multi_thread()
store.store_session(Session::new()).await?; .enable_all()
} .build()
.unwrap();
let cookie = store.store_session(Session::new()).await?.unwrap(); rt.block_on(async {
assert_eq!(4, store.count().await?); for _ in 0..3i8 {
let session = store.load_session(cookie.clone()).await?.unwrap(); store.store_session(Session::new()).await?;
store.destroy_session(session.clone()).await.unwrap(); }
assert_eq!(None, store.load_session(cookie).await?);
assert_eq!(3, store.count().await?);
// // attempting to destroy the session again is not an error let cookie = store.store_session(Session::new()).await?.unwrap();
// assert!(store.destroy_session(session).await.is_ok()); assert_eq!(4, store.count().await?);
Ok(()) let session = store.load_session(cookie.clone()).await?.unwrap();
store.destroy_session(session.clone()).await.unwrap();
assert_eq!(None, store.load_session(cookie).await?);
assert_eq!(3, store.count().await?);
// // attempting to destroy the session again is not an error
// assert!(store.destroy_session(session).await.is_ok());
Ok(())
})
} }
#[tokio::test] #[test]
async fn clearing_the_whole_store() -> Result { fn clearing_the_whole_store() -> Result {
let store = test_store().await; let store = test_store();
for _ in 0..3i8 { let rt = tokio::runtime::Builder::new_multi_thread()
store.store_session(Session::new()).await?; .enable_all()
} .build()
.unwrap();
assert_eq!(3, store.count().await?); rt.block_on(async {
store.clear_store().await.unwrap(); for _ in 0..3i8 {
assert_eq!(0, store.count().await?); store.store_session(Session::new()).await?;
}
Ok(()) assert_eq!(3, store.count().await?);
store.clear_store().await.unwrap();
assert_eq!(0, store.count().await?);
Ok(())
})
} }
} }
} }

View file

@ -20,34 +20,47 @@ pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::future::IntoFuture;
use axum_test::TestServer; use axum_test::TestServer;
use tokio::runtime::Runtime;
use crate::db; use crate::db;
#[tokio::test] #[test]
async fn slash_is_ok() { fn slash_is_ok() {
let pool = db::get_db_pool().await; let pool = db::get_db_pool();
let secret = [0u8; 64]; let secret = [0u8; 64];
let app = crate::app(pool.clone(), &secret).await.into_make_service(); let rt = Runtime::new().unwrap();
let app = rt
.block_on(crate::app(pool.clone(), &secret))
.into_make_service();
let server = TestServer::new(app).unwrap(); let server = TestServer::new(app).unwrap();
server.get("/").await.assert_status_ok(); rt.block_on(server.get("/").into_future())
.assert_status_ok();
} }
#[tokio::test] #[test]
async fn not_found_is_303() { fn not_found_is_303() {
let pool = db::get_db_pool().await; let pool = db::get_db_pool();
let secret = [0u8; 64]; let secret = [0u8; 64];
let app = crate::app(pool, &secret).await.into_make_service();
let rt = Runtime::new().unwrap();
let app = rt
.block_on(crate::app(pool.clone(), &secret))
.into_make_service();
let server = TestServer::new(app).unwrap(); let server = TestServer::new(app).unwrap();
assert_eq!( assert_eq!(
server rt.block_on(
.get("/no-actual-route") server
.expect_failure() .get("/no-actual-route")
.await .expect_failure()
.status_code(), .into_future()
)
.status_code(),
303 303
); );
} }

View file

@ -158,7 +158,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn ensure_omega_user() { async fn ensure_omega_user() {
let p = crate::db::get_db_pool().await; let p = crate::db::get_db_pool();
assert!(!check_omega_exists(&p).await); assert!(!check_omega_exists(&p).await);
ensure_omega(&p).await; ensure_omega(&p).await;
assert!(check_omega_exists(&p).await); assert!(check_omega_exists(&p).await);

View file

@ -105,6 +105,10 @@ pub async fn post_logout(mut auth: AuthContext) -> impl IntoResponse {
LogoutSuccessPage LogoutSuccessPage
} }
//-************************************************************************
// tests
//-************************************************************************
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::{ use crate::{
@ -114,85 +118,128 @@ mod test {
const LOGIN_FORM: &str = "username=test_user&password=a"; const LOGIN_FORM: &str = "username=test_user&password=a";
#[tokio::test] #[test]
async fn get_login() { fn get_login() {
let s = server().await; let rt = tokio::runtime::Builder::new_multi_thread()
let resp = s.get("/login").await; .enable_all()
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); .build()
assert_eq!(body, LoginPage::default().to_string()); .unwrap();
let s = server();
rt.block_on(async {
let resp = s.get("/login").await;
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string();
assert_eq!(body, LoginPage::default().to_string());
})
} }
#[tokio::test] #[test]
async fn post_login_success() { fn post_login_success() {
let s = server().await; let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let body = massage(LOGIN_FORM); let body = massage(LOGIN_FORM);
let resp = s rt.block_on(async {
.post("/login") let resp = s
.expect_failure() .post("/login")
.content_type(FORM_CONTENT_TYPE) .expect_failure()
.bytes(body) .content_type(FORM_CONTENT_TYPE)
.await; .bytes(body)
assert_eq!(resp.status_code(), 303); .await;
assert_eq!(resp.status_code(), 303);
})
} }
#[tokio::test] #[test]
async fn post_login_bad_user() { fn post_login_bad_user() {
let s = server().await; let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let form = "username=test_LOSER&password=aaaa"; let form = "username=test_LOSER&password=aaaa";
let body = massage(form); let body = massage(form);
let resp = s rt.block_on(async {
.post("/login") let resp = s
.expect_success() .post("/login")
.content_type(FORM_CONTENT_TYPE) .expect_success()
.bytes(body) .content_type(FORM_CONTENT_TYPE)
.await; .bytes(body)
assert_eq!(resp.status_code(), 200); .await;
assert_eq!(resp.status_code(), 200);
})
} }
#[tokio::test] #[test]
async fn post_login_bad_password() { fn post_login_bad_password() {
let s = server().await; let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let form = "username=test_user&password=bbbb"; let form = "username=test_user&password=bbbb";
let body = massage(form); let body = massage(form);
let resp = s rt.block_on(async {
.post("/login") let resp = s
.expect_success() .post("/login")
.content_type(FORM_CONTENT_TYPE) .expect_success()
.bytes(body) .content_type(FORM_CONTENT_TYPE)
.await; .bytes(body)
assert_eq!(resp.status_code(), 200); .await;
assert_eq!(resp.status_code(), 200);
})
} }
#[tokio::test] #[test]
async fn get_logout() { fn get_logout() {
let s = server().await; let s = server();
let resp = s.get("/logout").await; let rt = tokio::runtime::Builder::new_multi_thread()
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); .enable_all()
assert_eq!(body, LogoutPage.to_string()); .build()
.unwrap();
rt.block_on(async {
let resp = s.get("/logout").await;
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string();
assert_eq!(body, LogoutPage.to_string());
})
} }
#[tokio::test] #[test]
async fn post_logout_not_logged_in() { fn post_logout_not_logged_in() {
let s = server().await; let s = server();
let resp = s.post("/logout").await; let rt = tokio::runtime::Builder::new_multi_thread()
resp.assert_status_ok(); .enable_all()
let body = std::str::from_utf8(resp.bytes()).unwrap(); .build()
let default = LogoutSuccessPage.to_string(); .unwrap();
assert_eq!(body, &default);
rt.block_on(async {
let resp = s.post("/logout").await;
resp.assert_status_ok();
let body = std::str::from_utf8(resp.bytes()).unwrap();
let default = LogoutSuccessPage.to_string();
assert_eq!(body, &default);
})
} }
#[tokio::test] #[test]
async fn post_logout_logged_in() { fn post_logout_logged_in() {
let s = server().await; let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
// log in and prove it // log in and prove it
{ rt.block_on(async {
let body = massage(LOGIN_FORM); let body = massage(LOGIN_FORM);
let resp = s let resp = s
.post("/login") .post("/login")
@ -210,11 +257,11 @@ mod test {
let main_page = s.get("/").await; let main_page = s.get("/").await;
let body = std::str::from_utf8(main_page.bytes()).unwrap(); let body = std::str::from_utf8(main_page.bytes()).unwrap();
assert_eq!(&logged_in, body); assert_eq!(&logged_in, body);
}
let resp = s.post("/logout").await; let resp = s.post("/logout").await;
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let default = LogoutSuccessPage.to_string(); let default = LogoutSuccessPage.to_string();
assert_eq!(body, &default); assert_eq!(body, &default);
})
} }
} }

View file

@ -1,12 +1,12 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use rand::{thread_rng, RngCore}; use rand::{thread_rng, RngCore};
use sqlx::SqlitePool;
use tokio::signal; use tokio::signal;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use what2watch::get_db_pool; use what2watch::get_db_pool;
#[tokio::main] fn main() {
async fn main() {
tracing_subscriber::registry() tracing_subscriber::registry()
.with( .with(
tracing_subscriber::EnvFilter::try_from_default_env() tracing_subscriber::EnvFilter::try_from_default_env()
@ -15,8 +15,18 @@ async fn main() {
.with(tracing_subscriber::fmt::layer()) .with(tracing_subscriber::fmt::layer())
.init(); .init();
let pool = get_db_pool().await; let pool = get_db_pool();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(runner(&pool));
rt.block_on(pool.close());
}
async fn runner(pool: &SqlitePool) {
let secret = { let secret = {
let mut bytes = [0u8; 64]; let mut bytes = [0u8; 64];
let mut rng = thread_rng(); let mut rng = thread_rng();
@ -34,8 +44,6 @@ async fn main() {
.with_graceful_shutdown(shutdown_signal()) .with_graceful_shutdown(shutdown_signal())
.await .await
.unwrap_or_default(); .unwrap_or_default();
pool.close().await;
} }
async fn shutdown_signal() { async fn shutdown_signal() {

View file

@ -223,7 +223,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn post_create_user() { async fn post_create_user() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(GOOD_FORM); let body = massage(GOOD_FORM);
@ -243,7 +243,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn get_create_user() { async fn get_create_user() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let resp = server.get("/signup").await; let resp = server.get("/signup").await;
@ -254,7 +254,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn handle_signup_success() { async fn handle_signup_success() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let user = get_test_user(); let user = get_test_user();
@ -290,7 +290,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn password_mismatch() { async fn password_mismatch() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_MISMATCH_FORM); let body = massage(PASSWORD_MISMATCH_FORM);
@ -313,7 +313,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn password_short() { async fn password_short() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_SHORT_FORM); let body = massage(PASSWORD_SHORT_FORM);
@ -336,7 +336,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn password_long() { async fn password_long() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_LONG_FORM); let body = massage(PASSWORD_LONG_FORM);
@ -363,7 +363,7 @@ mod test {
// min length is 4 // min length is 4
assert_eq!(pw.len(), 4); assert_eq!(pw.len(), 4);
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let form = let form =
format!("username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}"); format!("username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}");
@ -388,7 +388,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn username_short() { async fn username_short() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(USERNAME_SHORT_FORM); let body = massage(USERNAME_SHORT_FORM);
@ -411,7 +411,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn username_long() { async fn username_long() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(USERNAME_LONG_FORM); let body = massage(USERNAME_LONG_FORM);
@ -434,7 +434,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn username_duplicate() { async fn username_duplicate() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(GOOD_FORM); let body = massage(GOOD_FORM);
@ -465,7 +465,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn displayname_long() { async fn displayname_long() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(DISPLAYNAME_LONG_FORM); let body = massage(DISPLAYNAME_LONG_FORM);
@ -488,7 +488,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn handle_signup_success() { async fn handle_signup_success() {
let pool = get_db_pool().await; let pool = get_db_pool();
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let path = format!("/signup_success/nope"); let path = format!("/signup_success/nope");

View file

@ -17,33 +17,40 @@ pub fn get_test_user() -> User {
} }
} }
pub async fn server() -> TestServer { pub fn server() -> TestServer {
let pool = crate::db::get_db_pool().await; let pool = crate::db::get_db_pool();
let secret = [0u8; 64]; let secret = [0u8; 64];
let user = get_test_user(); let rt = tokio::runtime::Builder::new_multi_thread()
sqlx::query(crate::signup::CREATE_QUERY) .enable_all()
.bind(user.id) .build()
.bind(&user.username)
.bind(&user.displayname)
.bind(&user.email)
.bind(&user.pwhash)
.execute(&pool)
.await
.unwrap(); .unwrap();
let r = sqlx::query("select count(*) from users") let user = get_test_user();
.fetch_one(&pool) rt.block_on(async {
.await; sqlx::query(crate::signup::CREATE_QUERY)
assert!(r.is_ok()); .bind(user.id)
.bind(&user.username)
.bind(&user.displayname)
.bind(&user.email)
.bind(&user.pwhash)
.execute(&pool)
.await
.unwrap();
let app = crate::app(pool, &secret).await.into_make_service(); let r = sqlx::query("select count(*) from users")
.fetch_one(&pool)
.await;
assert!(r.is_ok());
let config = TestServerConfig { let app = crate::app(pool, &secret).await.into_make_service();
save_cookies: true,
..Default::default() let config = TestServerConfig {
}; save_cookies: true,
TestServer::new_with_config(app, config).unwrap() ..Default::default()
};
TestServer::new_with_config(app, config).unwrap()
})
} }
pub async fn server_with_pool(pool: &SqlitePool) -> TestServer { pub async fn server_with_pool(pool: &SqlitePool) -> TestServer {