make db acquire sync; tests fail but app runs correctly

This commit is contained in:
Joe Ardent 2023-07-18 17:37:24 -07:00
parent c685dc1a6b
commit 48308aa169
9 changed files with 437 additions and 275 deletions

View file

@ -10,32 +10,47 @@ struct Cli {
pub db_path: OsString,
}
#[tokio::main]
async fn main() {
fn main() {
let cli = Cli::parse();
let path = cli.db_path;
let opts = SqliteConnectOptions::new().filename(&path).read_only(true);
let movie_db = SqlitePoolOptions::new()
.idle_timeout(Duration::from_secs(90))
.connect_with(opts)
.await
.expect("could not open movies db");
let opts = SqliteConnectOptions::new().filename(path).read_only(true);
let movie_db = {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let w2w_db = get_db_pool().await;
rt.block_on(
SqlitePoolOptions::new()
.idle_timeout(Duration::from_secs(90))
.connect_with(opts),
)
.expect("could not open movies db")
};
let w2w_db = get_db_pool();
let start = std::time::Instant::now();
add_omega_watches(&w2w_db, &movie_db).await.unwrap();
let end = std::time::Instant::now();
let rows = {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let dur = (end - start).as_secs_f32();
rt.block_on(async {
add_omega_watches(&w2w_db, &movie_db).await.unwrap();
let rows: i32 = sqlx::query_scalar("select count(*) from watches")
.fetch_one(&w2w_db)
.await
.unwrap();
println!("Added {rows} movies in {dur} seconds");
w2w_db.close().await;
rows
})
};
let end = std::time::Instant::now();
let dur = (end - start).as_secs_f32();
println!("Added {rows} movies in {dur} seconds");
}

View file

@ -7,7 +7,7 @@ use sqlx::{
sqlite::{SqliteConnectOptions, SqlitePoolOptions},
SqlitePool,
};
use tokio::task::JoinSet;
use tokio::{runtime, task::JoinSet};
use tokio_retry::Retry;
use what2watch::{
get_db_pool,
@ -15,8 +15,7 @@ use what2watch::{
DbId, User, WatchQuest,
};
#[tokio::main]
async fn main() {
fn main() {
let cli = Cli::parse();
let path = cli.db_path;
let num_users = cli.users;
@ -31,29 +30,39 @@ async fn main() {
let words: Vec<&str> = words.split('\n').collect();
let opts = SqliteConnectOptions::new().filename(&path).read_only(true);
let movie_db = SqlitePoolOptions::new()
.idle_timeout(Duration::from_secs(3))
.connect_with(opts)
.await
.expect("could not open movies db");
let w2w_db = get_db_pool().await;
let users = &gen_users(num_users, &words, &w2w_db).await;
let movies = &add_omega_watches(&w2w_db, &movie_db).await.unwrap();
let rt = runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let movie_db = rt
.block_on(
SqlitePoolOptions::new()
.idle_timeout(Duration::from_secs(3))
.connect_with(opts),
)
.expect("could not open movies db");
let w2w_db = get_db_pool();
let users = &rt.block_on(gen_users(num_users, &words, &w2w_db));
let movies = &rt.block_on(add_omega_watches(&w2w_db, &movie_db)).unwrap();
let rng = &mut thread_rng();
let normal = Normal::new(mpu, mpu / 10.0).unwrap();
let start = std::time::Instant::now();
rt.block_on(async {
for &user in users {
add_quests(user, movies, &w2w_db, rng, normal).await;
}
let end = std::time::Instant::now();
let rows: i32 = sqlx::query_scalar("select count(*) from watch_quests")
.fetch_one(&w2w_db)
.await
});
let rows: i32 = rt
.block_on(sqlx::query_scalar("select count(*) from watch_quests").fetch_one(&w2w_db))
.unwrap();
w2w_db.close().await;
rt.block_on(w2w_db.close());
let end = std::time::Instant::now();
let dur = (end - start).as_secs_f32();
println!("Added {rows} quests in {dur} seconds");
}

135
src/db.rs
View file

@ -19,7 +19,7 @@ const MIN_CONNS: u32 = 5;
const TIMEOUT: u64 = 11;
const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64);
pub async fn get_db_pool() -> SqlitePool {
pub fn get_db_pool() -> SqlitePool {
let db_filename = {
std::env::var("DATABASE_FILE").unwrap_or_else(|_| {
#[cfg(not(test))]
@ -57,25 +57,34 @@ pub async fn get_db_pool() -> SqlitePool {
.min_connections(MIN_CONNS)
.idle_timeout(Some(Duration::from_secs(30)))
.max_lifetime(Some(Duration::from_secs(3600)))
.connect_with(conn_opts)
.await
.expect("can't connect to database");
.connect_with(conn_opts);
// let the filesystem settle before trying anything
// possibly not effective?
tokio::time::sleep(Duration::from_millis(500)).await;
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let pool = {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(pool).unwrap()
};
let _con = rt.block_on(pool.acquire()).unwrap();
{
let mut m = Migrator::new(std::path::Path::new("./migrations"))
.await
.expect("Should be able to read the migration directory.");
let m = m.set_locking(true);
m.run(&pool)
.await
.expect("Should be able to run the migration.");
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let m = Migrator::new(std::path::Path::new("./migrations"));
let mut m = rt.block_on(m).unwrap();
let m = m.set_locking(false);
rt.block_on(m.run(&pool)).unwrap();
tracing::info!("Ran migrations");
}
@ -120,7 +129,7 @@ mod tests {
#[tokio::test]
async fn it_migrates_the_db() {
let db = super::get_db_pool().await;
let db = super::get_db_pool();
let r = sqlx::query("select count(*) from users")
.fetch_one(&db)
.await;
@ -373,7 +382,13 @@ mod session_store {
use super::*;
async fn test_store() -> SqliteSessionStore {
fn test_store() -> SqliteSessionStore {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let store = rt.block_on(async {
let store = SqliteSessionStore::new("sqlite::memory:")
.await
.expect("building a sqlite :memory: SqliteSessionStore");
@ -382,14 +397,25 @@ mod session_store {
.await
.expect("migrating a brand new :memory: SqliteSessionStore");
store
});
dbg!("got the store");
store
}
#[tokio::test]
async fn creating_a_new_session_with_no_expiry() -> Result {
let store = test_store().await;
#[test]
fn creating_a_new_session_with_no_expiry() -> Result {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let store = test_store();
let mut session = Session::new();
dbg!("new session");
session.insert("key", "value")?;
let cloned = session.clone();
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
@ -411,15 +437,22 @@ mod session_store {
assert!(!loaded_session.is_expired());
Ok(())
})
}
#[tokio::test]
async fn updating_a_session() -> Result {
let store = test_store().await;
#[test]
fn updating_a_session() -> Result {
let store = test_store();
let mut session = Session::new();
let original_id = session.id().to_owned();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
session.insert("key", "value")?;
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
@ -438,15 +471,23 @@ mod session_store {
assert_eq!(original_id, id);
Ok(())
})
}
#[tokio::test]
async fn updating_a_session_extending_expiry() -> Result {
let store = test_store().await;
#[test]
fn updating_a_session_extending_expiry() -> Result {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let store = test_store();
let mut session = Session::new();
session.expire_in(Duration::from_secs(10));
let original_id = session.id().to_owned();
let original_expires = session.expiry().unwrap().clone();
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
@ -468,16 +509,23 @@ mod session_store {
assert_eq!(original_id, id);
Ok(())
})
}
#[tokio::test]
async fn creating_a_new_session_with_expiry() -> Result {
let store = test_store().await;
#[test]
fn creating_a_new_session_with_expiry() -> Result {
let store = test_store();
let mut session = Session::new();
session.expire_in(Duration::from_secs(1));
session.insert("key", "value")?;
let cloned = session.clone();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
@ -503,11 +551,18 @@ mod session_store {
assert_eq!(None, store.load_session(cookie_value).await?);
Ok(())
})
}
#[tokio::test]
async fn destroying_a_single_session() -> Result {
let store = test_store().await;
#[test]
fn destroying_a_single_session() -> Result {
let store = test_store();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
for _ in 0..3i8 {
store.store_session(Session::new()).await?;
}
@ -522,11 +577,18 @@ mod session_store {
// // attempting to destroy the session again is not an error
// assert!(store.destroy_session(session).await.is_ok());
Ok(())
})
}
#[tokio::test]
async fn clearing_the_whole_store() -> Result {
let store = test_store().await;
#[test]
fn clearing_the_whole_store() -> Result {
let store = test_store();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
for _ in 0..3i8 {
store.store_session(Session::new()).await?;
}
@ -536,6 +598,7 @@ mod session_store {
assert_eq!(0, store.count().await?);
Ok(())
})
}
}
}

View file

@ -20,33 +20,46 @@ pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse {
#[cfg(test)]
mod test {
use std::future::IntoFuture;
use axum_test::TestServer;
use tokio::runtime::Runtime;
use crate::db;
#[tokio::test]
async fn slash_is_ok() {
let pool = db::get_db_pool().await;
#[test]
fn slash_is_ok() {
let pool = db::get_db_pool();
let secret = [0u8; 64];
let app = crate::app(pool.clone(), &secret).await.into_make_service();
let rt = Runtime::new().unwrap();
let app = rt
.block_on(crate::app(pool.clone(), &secret))
.into_make_service();
let server = TestServer::new(app).unwrap();
server.get("/").await.assert_status_ok();
rt.block_on(server.get("/").into_future())
.assert_status_ok();
}
#[tokio::test]
async fn not_found_is_303() {
let pool = db::get_db_pool().await;
#[test]
fn not_found_is_303() {
let pool = db::get_db_pool();
let secret = [0u8; 64];
let app = crate::app(pool, &secret).await.into_make_service();
let rt = Runtime::new().unwrap();
let app = rt
.block_on(crate::app(pool.clone(), &secret))
.into_make_service();
let server = TestServer::new(app).unwrap();
assert_eq!(
rt.block_on(
server
.get("/no-actual-route")
.expect_failure()
.await
.into_future()
)
.status_code(),
303
);

View file

@ -158,7 +158,7 @@ mod test {
#[tokio::test]
async fn ensure_omega_user() {
let p = crate::db::get_db_pool().await;
let p = crate::db::get_db_pool();
assert!(!check_omega_exists(&p).await);
ensure_omega(&p).await;
assert!(check_omega_exists(&p).await);

View file

@ -105,6 +105,10 @@ pub async fn post_logout(mut auth: AuthContext) -> impl IntoResponse {
LogoutSuccessPage
}
//-************************************************************************
// tests
//-************************************************************************
#[cfg(test)]
mod test {
use crate::{
@ -114,20 +118,32 @@ mod test {
const LOGIN_FORM: &str = "username=test_user&password=a";
#[tokio::test]
async fn get_login() {
let s = server().await;
#[test]
fn get_login() {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let s = server();
rt.block_on(async {
let resp = s.get("/login").await;
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string();
assert_eq!(body, LoginPage::default().to_string());
})
}
#[tokio::test]
async fn post_login_success() {
let s = server().await;
#[test]
fn post_login_success() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let body = massage(LOGIN_FORM);
rt.block_on(async {
let resp = s
.post("/login")
.expect_failure()
@ -135,15 +151,21 @@ mod test {
.bytes(body)
.await;
assert_eq!(resp.status_code(), 303);
})
}
#[tokio::test]
async fn post_login_bad_user() {
let s = server().await;
#[test]
fn post_login_bad_user() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let form = "username=test_LOSER&password=aaaa";
let body = massage(form);
rt.block_on(async {
let resp = s
.post("/login")
.expect_success()
@ -151,15 +173,21 @@ mod test {
.bytes(body)
.await;
assert_eq!(resp.status_code(), 200);
})
}
#[tokio::test]
async fn post_login_bad_password() {
let s = server().await;
#[test]
fn post_login_bad_password() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let form = "username=test_user&password=bbbb";
let body = massage(form);
rt.block_on(async {
let resp = s
.post("/login")
.expect_success()
@ -167,32 +195,51 @@ mod test {
.bytes(body)
.await;
assert_eq!(resp.status_code(), 200);
})
}
#[tokio::test]
async fn get_logout() {
let s = server().await;
#[test]
fn get_logout() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
let resp = s.get("/logout").await;
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string();
assert_eq!(body, LogoutPage.to_string());
})
}
#[tokio::test]
async fn post_logout_not_logged_in() {
let s = server().await;
#[test]
fn post_logout_not_logged_in() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
let resp = s.post("/logout").await;
resp.assert_status_ok();
let body = std::str::from_utf8(resp.bytes()).unwrap();
let default = LogoutSuccessPage.to_string();
assert_eq!(body, &default);
})
}
#[tokio::test]
async fn post_logout_logged_in() {
let s = server().await;
#[test]
fn post_logout_logged_in() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
// log in and prove it
{
rt.block_on(async {
let body = massage(LOGIN_FORM);
let resp = s
.post("/login")
@ -210,11 +257,11 @@ mod test {
let main_page = s.get("/").await;
let body = std::str::from_utf8(main_page.bytes()).unwrap();
assert_eq!(&logged_in, body);
}
let resp = s.post("/logout").await;
let body = std::str::from_utf8(resp.bytes()).unwrap();
let default = LogoutSuccessPage.to_string();
assert_eq!(body, &default);
})
}
}

View file

@ -1,12 +1,12 @@
use std::net::SocketAddr;
use rand::{thread_rng, RngCore};
use sqlx::SqlitePool;
use tokio::signal;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use what2watch::get_db_pool;
#[tokio::main]
async fn main() {
fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
@ -15,8 +15,18 @@ async fn main() {
.with(tracing_subscriber::fmt::layer())
.init();
let pool = get_db_pool().await;
let pool = get_db_pool();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(runner(&pool));
rt.block_on(pool.close());
}
async fn runner(pool: &SqlitePool) {
let secret = {
let mut bytes = [0u8; 64];
let mut rng = thread_rng();
@ -34,8 +44,6 @@ async fn main() {
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap_or_default();
pool.close().await;
}
async fn shutdown_signal() {

View file

@ -223,7 +223,7 @@ mod test {
#[tokio::test]
async fn post_create_user() {
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let body = massage(GOOD_FORM);
@ -243,7 +243,7 @@ mod test {
#[tokio::test]
async fn get_create_user() {
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let resp = server.get("/signup").await;
@ -254,7 +254,7 @@ mod test {
#[tokio::test]
async fn handle_signup_success() {
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let user = get_test_user();
@ -290,7 +290,7 @@ mod test {
#[tokio::test]
async fn password_mismatch() {
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_MISMATCH_FORM);
@ -313,7 +313,7 @@ mod test {
#[tokio::test]
async fn password_short() {
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_SHORT_FORM);
@ -336,7 +336,7 @@ mod test {
#[tokio::test]
async fn password_long() {
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_LONG_FORM);
@ -363,7 +363,7 @@ mod test {
// min length is 4
assert_eq!(pw.len(), 4);
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let form =
format!("username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}");
@ -388,7 +388,7 @@ mod test {
#[tokio::test]
async fn username_short() {
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let body = massage(USERNAME_SHORT_FORM);
@ -411,7 +411,7 @@ mod test {
#[tokio::test]
async fn username_long() {
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let body = massage(USERNAME_LONG_FORM);
@ -434,7 +434,7 @@ mod test {
#[tokio::test]
async fn username_duplicate() {
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let body = massage(GOOD_FORM);
@ -465,7 +465,7 @@ mod test {
#[tokio::test]
async fn displayname_long() {
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let body = massage(DISPLAYNAME_LONG_FORM);
@ -488,7 +488,7 @@ mod test {
#[tokio::test]
async fn handle_signup_success() {
let pool = get_db_pool().await;
let pool = get_db_pool();
let server = server_with_pool(&pool).await;
let path = format!("/signup_success/nope");

View file

@ -17,11 +17,17 @@ pub fn get_test_user() -> User {
}
}
pub async fn server() -> TestServer {
let pool = crate::db::get_db_pool().await;
pub fn server() -> TestServer {
let pool = crate::db::get_db_pool();
let secret = [0u8; 64];
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let user = get_test_user();
rt.block_on(async {
sqlx::query(crate::signup::CREATE_QUERY)
.bind(user.id)
.bind(&user.username)
@ -44,6 +50,7 @@ pub async fn server() -> TestServer {
..Default::default()
};
TestServer::new_with_config(app, config).unwrap()
})
}
pub async fn server_with_pool(pool: &SqlitePool) -> TestServer {