Fix the tests.

The test server had to be created in an async context or Hyper wouldn't run.
This commit is contained in:
Joe Ardent 2023-07-21 15:15:47 -07:00
parent 48308aa169
commit 359a732a84
8 changed files with 419 additions and 454 deletions

View file

@ -126,14 +126,18 @@ pub async fn auth_layer(
//-************************************************************************ //-************************************************************************
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use tokio::runtime::Runtime;
#[tokio::test] #[test]
async fn it_migrates_the_db() { fn it_migrates_the_db() {
let rt = Runtime::new().unwrap();
let db = super::get_db_pool(); let db = super::get_db_pool();
rt.block_on(async {
let r = sqlx::query("select count(*) from users") let r = sqlx::query("select count(*) from users")
.fetch_one(&db) .fetch_one(&db)
.await; .await;
assert!(r.is_ok()); assert!(r.is_ok());
});
} }
} }
@ -382,13 +386,7 @@ mod session_store {
use super::*; use super::*;
fn test_store() -> SqliteSessionStore { async fn test_store() -> SqliteSessionStore {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let store = rt.block_on(async {
let store = SqliteSessionStore::new("sqlite::memory:") let store = SqliteSessionStore::new("sqlite::memory:")
.await .await
.expect("building a sqlite :memory: SqliteSessionStore"); .expect("building a sqlite :memory: SqliteSessionStore");
@ -397,25 +395,15 @@ mod session_store {
.await .await
.expect("migrating a brand new :memory: SqliteSessionStore"); .expect("migrating a brand new :memory: SqliteSessionStore");
store store
});
dbg!("got the store");
store
} }
#[test] #[tokio::test]
fn creating_a_new_session_with_no_expiry() -> Result { async fn creating_a_new_session_with_no_expiry() -> Result {
let rt = tokio::runtime::Builder::new_multi_thread() let store = test_store().await;
.enable_all()
.build()
.unwrap();
let store = test_store();
let mut session = Session::new(); let mut session = Session::new();
dbg!("new session");
session.insert("key", "value")?; session.insert("key", "value")?;
let cloned = session.clone(); let cloned = session.clone();
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap(); let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) = let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
@ -437,22 +425,16 @@ mod session_store {
assert!(!loaded_session.is_expired()); assert!(!loaded_session.is_expired());
Ok(()) Ok(())
})
} }
#[test] #[tokio::test]
fn updating_a_session() -> Result { async fn updating_a_session() -> Result {
let store = test_store(); let store = test_store().await;
let mut session = Session::new(); let mut session = Session::new();
let original_id = session.id().to_owned(); let original_id = session.id().to_owned();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
session.insert("key", "value")?; session.insert("key", "value")?;
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap(); let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
@ -471,23 +453,16 @@ mod session_store {
assert_eq!(original_id, id); assert_eq!(original_id, id);
Ok(()) Ok(())
})
} }
#[test] #[tokio::test]
fn updating_a_session_extending_expiry() -> Result { async fn updating_a_session_extending_expiry() -> Result {
let rt = tokio::runtime::Builder::new_multi_thread() let store = test_store().await;
.enable_all()
.build()
.unwrap();
let store = test_store();
let mut session = Session::new(); let mut session = Session::new();
session.expire_in(Duration::from_secs(10)); session.expire_in(Duration::from_secs(10));
let original_id = session.id().to_owned(); let original_id = session.id().to_owned();
let original_expires = session.expiry().unwrap().clone(); let original_expires = session.expiry().unwrap().clone();
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap(); let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
@ -509,23 +484,16 @@ mod session_store {
assert_eq!(original_id, id); assert_eq!(original_id, id);
Ok(()) Ok(())
})
} }
#[test] #[tokio::test]
fn creating_a_new_session_with_expiry() -> Result { async fn creating_a_new_session_with_expiry() -> Result {
let store = test_store(); let store = test_store().await;
let mut session = Session::new(); let mut session = Session::new();
session.expire_in(Duration::from_secs(1)); session.expire_in(Duration::from_secs(1));
session.insert("key", "value")?; session.insert("key", "value")?;
let cloned = session.clone(); let cloned = session.clone();
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
let cookie_value = store.store_session(session).await?.unwrap(); let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) = let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
@ -551,18 +519,11 @@ mod session_store {
assert_eq!(None, store.load_session(cookie_value).await?); assert_eq!(None, store.load_session(cookie_value).await?);
Ok(()) Ok(())
})
} }
#[test] #[tokio::test]
fn destroying_a_single_session() -> Result { async fn destroying_a_single_session() -> Result {
let store = test_store(); let store = test_store().await;
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
for _ in 0..3i8 { for _ in 0..3i8 {
store.store_session(Session::new()).await?; store.store_session(Session::new()).await?;
} }
@ -577,18 +538,11 @@ mod session_store {
// // attempting to destroy the session again is not an error // // attempting to destroy the session again is not an error
// assert!(store.destroy_session(session).await.is_ok()); // assert!(store.destroy_session(session).await.is_ok());
Ok(()) Ok(())
})
} }
#[test] #[tokio::test]
fn clearing_the_whole_store() -> Result { async fn clearing_the_whole_store() -> Result {
let store = test_store(); let store = test_store().await;
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
for _ in 0..3i8 { for _ in 0..3i8 {
store.store_session(Session::new()).await?; store.store_session(Session::new()).await?;
} }
@ -598,7 +552,6 @@ mod session_store {
assert_eq!(0, store.count().await?); assert_eq!(0, store.count().await?);
Ok(()) Ok(())
})
} }
} }
} }

View file

@ -20,46 +20,33 @@ pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::future::IntoFuture;
use axum_test::TestServer;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use crate::db; use crate::{get_db_pool, test_utils::server_with_pool};
#[test] #[test]
fn slash_is_ok() { fn slash_is_ok() {
let pool = db::get_db_pool(); let db = get_db_pool();
let secret = [0u8; 64];
let rt = Runtime::new().unwrap(); let rt = Runtime::new().unwrap();
let app = rt rt.block_on(async {
.block_on(crate::app(pool.clone(), &secret)) let server = server_with_pool(&db).await;
.into_make_service(); server.get("/").await
})
let server = TestServer::new(app).unwrap();
rt.block_on(server.get("/").into_future())
.assert_status_ok(); .assert_status_ok();
} }
#[test] #[test]
fn not_found_is_303() { fn not_found_is_303() {
let pool = db::get_db_pool();
let secret = [0u8; 64];
let rt = Runtime::new().unwrap(); let rt = Runtime::new().unwrap();
let app = rt
.block_on(crate::app(pool.clone(), &secret))
.into_make_service();
let server = TestServer::new(app).unwrap(); let db = get_db_pool();
assert_eq!( assert_eq!(
rt.block_on( rt.block_on(async {
server let server = server_with_pool(&db).await;
.get("/no-actual-route") server.get("/no-actual-route").expect_failure().await
.expect_failure() })
.into_future()
)
.status_code(), .status_code(),
303 303
); );

View file

@ -154,13 +154,18 @@ async fn check_omega_exists(db_pool: &SqlitePool) -> bool {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use tokio::runtime::Runtime;
use super::*; use super::*;
#[tokio::test] #[test]
async fn ensure_omega_user() { fn ensure_omega_user() {
let p = crate::db::get_db_pool(); let p = crate::db::get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
assert!(!check_omega_exists(&p).await); assert!(!check_omega_exists(&p).await);
ensure_omega(&p).await; ensure_omega(&p).await;
assert!(check_omega_exists(&p).await); assert!(check_omega_exists(&p).await);
});
} }
} }

View file

@ -1,3 +1,5 @@
use axum::routing::IntoMakeService;
use sqlx::SqlitePool;
#[macro_use] #[macro_use]
extern crate justerror; extern crate justerror;
@ -9,10 +11,11 @@ pub mod test_utils;
pub use db::get_db_pool; pub use db::get_db_pool;
pub use db_id::DbId; pub use db_id::DbId;
pub mod import_utils; pub mod import_utils;
pub use users::User; pub use users::User;
pub use watches::{ShowKind, Watch, WatchQuest}; pub use watches::{ShowKind, Watch, WatchQuest};
pub type WWRouter = axum::Router<SqlitePool>;
// everything else is private to the crate // everything else is private to the crate
mod db; mod db;
mod db_id; mod db_id;
@ -32,11 +35,8 @@ use watches::templates::*;
type AuthContext = axum_login::extractors::AuthContext<DbId, User, axum_login::SqliteStore<User>>; type AuthContext = axum_login::extractors::AuthContext<DbId, User, axum_login::SqliteStore<User>>;
/// Returns the router to be used as a service or test object, you do you. /// Returns the router to be used as a service or test object, you do you.
pub async fn app(db_pool: sqlx::SqlitePool, session_secret: &[u8]) -> axum::Router { pub async fn app(db_pool: sqlx::SqlitePool, secret: &[u8]) -> IntoMakeService<axum::Router> {
use axum::{middleware, routing::get}; use axum::{middleware, routing::get};
let session_layer = db::session_layer(db_pool.clone(), session_secret).await;
let auth_layer = db::auth_layer(db_pool.clone(), session_secret).await;
// don't bother bringing handlers into the whole crate namespace // don't bother bringing handlers into the whole crate namespace
use generic_handlers::{handle_slash, handle_slash_redir}; use generic_handlers::{handle_slash, handle_slash_redir};
use login::{get_login, get_logout, post_login, post_logout}; use login::{get_login, get_logout, post_login, post_logout};
@ -46,6 +46,12 @@ pub async fn app(db_pool: sqlx::SqlitePool, session_secret: &[u8]) -> axum::Rout
post_add_watch_quest, post_add_watch_quest,
}; };
let (session_layer, auth_layer) = {
let session_layer = db::session_layer(db_pool.clone(), secret).await;
let auth_layer = db::auth_layer(db_pool.clone(), secret).await;
(session_layer, auth_layer)
};
axum::Router::new() axum::Router::new()
.route("/", get(handle_slash).post(handle_slash)) .route("/", get(handle_slash).post(handle_slash))
.route("/signup", get(get_create_user).post(post_create_user)) .route("/signup", get(get_create_user).post(post_create_user))
@ -69,6 +75,7 @@ pub async fn app(db_pool: sqlx::SqlitePool, session_secret: &[u8]) -> axum::Rout
.layer(auth_layer) .layer(auth_layer)
.layer(session_layer) .layer(session_layer)
.with_state(db_pool) .with_state(db_pool)
.into_make_service()
} }
//-************************************************************************ //-************************************************************************

View file

@ -112,8 +112,9 @@ pub async fn post_logout(mut auth: AuthContext) -> impl IntoResponse {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::{ use crate::{
get_db_pool,
templates::{LoginPage, LogoutPage, LogoutSuccessPage, MainPage}, templates::{LoginPage, LogoutPage, LogoutSuccessPage, MainPage},
test_utils::{get_test_user, massage, server, FORM_CONTENT_TYPE}, test_utils::{get_test_user, massage, server_with_pool, FORM_CONTENT_TYPE},
}; };
const LOGIN_FORM: &str = "username=test_user&password=a"; const LOGIN_FORM: &str = "username=test_user&password=a";
@ -125,8 +126,9 @@ mod test {
.build() .build()
.unwrap(); .unwrap();
let s = server(); let db = get_db_pool();
rt.block_on(async { rt.block_on(async {
let s = server_with_pool(&db).await;
let resp = s.get("/login").await; let resp = s.get("/login").await;
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); let body = std::str::from_utf8(resp.bytes()).unwrap().to_string();
assert_eq!(body, LoginPage::default().to_string()); assert_eq!(body, LoginPage::default().to_string());
@ -135,7 +137,6 @@ mod test {
#[test] #[test]
fn post_login_success() { fn post_login_success() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread() let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all() .enable_all()
.build() .build()
@ -143,7 +144,9 @@ mod test {
let body = massage(LOGIN_FORM); let body = massage(LOGIN_FORM);
let db = get_db_pool();
rt.block_on(async { rt.block_on(async {
let s = server_with_pool(&db).await;
let resp = s let resp = s
.post("/login") .post("/login")
.expect_failure() .expect_failure()
@ -156,7 +159,6 @@ mod test {
#[test] #[test]
fn post_login_bad_user() { fn post_login_bad_user() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread() let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all() .enable_all()
.build() .build()
@ -165,7 +167,9 @@ mod test {
let form = "username=test_LOSER&password=aaaa"; let form = "username=test_LOSER&password=aaaa";
let body = massage(form); let body = massage(form);
let db = get_db_pool();
rt.block_on(async { rt.block_on(async {
let s = server_with_pool(&db).await;
let resp = s let resp = s
.post("/login") .post("/login")
.expect_success() .expect_success()
@ -178,7 +182,6 @@ mod test {
#[test] #[test]
fn post_login_bad_password() { fn post_login_bad_password() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread() let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all() .enable_all()
.build() .build()
@ -187,7 +190,9 @@ mod test {
let form = "username=test_user&password=bbbb"; let form = "username=test_user&password=bbbb";
let body = massage(form); let body = massage(form);
let db = get_db_pool();
rt.block_on(async { rt.block_on(async {
let s = server_with_pool(&db).await;
let resp = s let resp = s
.post("/login") .post("/login")
.expect_success() .expect_success()
@ -200,13 +205,14 @@ mod test {
#[test] #[test]
fn get_logout() { fn get_logout() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread() let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all() .enable_all()
.build() .build()
.unwrap(); .unwrap();
let db = get_db_pool();
rt.block_on(async { rt.block_on(async {
let s = server_with_pool(&db).await;
let resp = s.get("/logout").await; let resp = s.get("/logout").await;
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); let body = std::str::from_utf8(resp.bytes()).unwrap().to_string();
assert_eq!(body, LogoutPage.to_string()); assert_eq!(body, LogoutPage.to_string());
@ -215,13 +221,14 @@ mod test {
#[test] #[test]
fn post_logout_not_logged_in() { fn post_logout_not_logged_in() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread() let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all() .enable_all()
.build() .build()
.unwrap(); .unwrap();
let db = get_db_pool();
rt.block_on(async { rt.block_on(async {
let s = server_with_pool(&db).await;
let resp = s.post("/logout").await; let resp = s.post("/logout").await;
resp.assert_status_ok(); resp.assert_status_ok();
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
@ -232,14 +239,15 @@ mod test {
#[test] #[test]
fn post_logout_logged_in() { fn post_logout_logged_in() {
let s = server();
let rt = tokio::runtime::Builder::new_multi_thread() let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all() .enable_all()
.build() .build()
.unwrap(); .unwrap();
// log in and prove it // log in and prove it
let db = get_db_pool();
rt.block_on(async { rt.block_on(async {
let s = server_with_pool(&db).await;
let body = massage(LOGIN_FORM); let body = massage(LOGIN_FORM);
let resp = s let resp = s
.post("/login") .post("/login")

View file

@ -1,7 +1,6 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use rand::{thread_rng, RngCore}; use rand::{thread_rng, RngCore};
use sqlx::SqlitePool;
use tokio::signal; use tokio::signal;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use what2watch::get_db_pool; use what2watch::get_db_pool;
@ -22,11 +21,6 @@ fn main() {
.build() .build()
.unwrap(); .unwrap();
rt.block_on(runner(&pool));
rt.block_on(pool.close());
}
async fn runner(pool: &SqlitePool) {
let secret = { let secret = {
let mut bytes = [0u8; 64]; let mut bytes = [0u8; 64];
let mut rng = thread_rng(); let mut rng = thread_rng();
@ -34,16 +28,20 @@ async fn runner(pool: &SqlitePool) {
bytes bytes
}; };
let app = what2watch::app(pool.clone(), &secret).await; let app = rt.block_on(what2watch::app(pool.clone(), &secret));
rt.block_on(async {
let addr: SocketAddr = ([0, 0, 0, 0], 3000).into(); let addr: SocketAddr = ([0, 0, 0, 0], 3000).into();
tracing::debug!("binding to {addr:?}"); tracing::debug!("binding to {addr:?}");
axum::Server::bind(&addr) axum::Server::bind(&addr)
.serve(app.into_make_service()) .serve(app)
.with_graceful_shutdown(shutdown_signal()) .with_graceful_shutdown(shutdown_signal())
.await .await
.unwrap_or_default(); .unwrap_or_default();
});
rt.block_on(pool.close());
} }
async fn shutdown_signal() { async fn shutdown_signal() {

View file

@ -211,19 +211,22 @@ pub(crate) async fn create_user(
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use axum::http::StatusCode; use axum::http::StatusCode;
use tokio::runtime::Runtime;
use crate::{ use crate::{
db::get_db_pool, db::get_db_pool,
templates::{SignupPage, SignupSuccessPage}, templates::{SignupPage, SignupSuccessPage},
test_utils::{get_test_user, insert_user, massage, server_with_pool, FORM_CONTENT_TYPE}, test_utils::{get_test_user, massage, server_with_pool, FORM_CONTENT_TYPE},
User, User,
}; };
const GOOD_FORM: &str = "username=test_user&displayname=Test+User&password=aaaa&pw_verify=aaaa"; const GOOD_FORM: &str = "username=good_user&displayname=Test+User&password=aaaa&pw_verify=aaaa";
#[tokio::test] #[test]
async fn post_create_user() { fn post_create_user() {
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(GOOD_FORM); let body = massage(GOOD_FORM);
@ -237,28 +240,33 @@ mod test {
assert_eq!(StatusCode::SEE_OTHER, resp.status_code()); assert_eq!(StatusCode::SEE_OTHER, resp.status_code());
// get the new user from the db // get the new user from the db
let user = User::try_get("test_user", &pool).await; let user = User::try_get("good_user", &pool).await;
assert!(user.is_ok()); assert!(user.is_ok());
});
} }
#[tokio::test] #[test]
async fn get_create_user() { fn get_create_user() {
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let resp = server.get("/signup").await; let resp = server.get("/signup").await;
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let expected = SignupPage::default().to_string(); let expected = SignupPage::default().to_string();
assert_eq!(&expected, body); assert_eq!(&expected, body);
});
} }
#[tokio::test] #[test]
async fn handle_signup_success() { fn handle_signup_success() {
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let user = get_test_user(); let user = get_test_user();
insert_user(&user, &pool).await;
let id = user.id.0.to_string(); let id = user.id.0.to_string();
let path = format!("/signup_success/{id}"); let path = format!("/signup_success/{id}");
@ -267,6 +275,7 @@ mod test {
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let expected = SignupSuccessPage(user).to_string(); let expected = SignupSuccessPage(user).to_string();
assert_eq!(&expected, body); assert_eq!(&expected, body);
});
} }
//-************************************************************************ //-************************************************************************
@ -278,19 +287,21 @@ mod test {
// various ways to fuck up signup // various ways to fuck up signup
const PASSWORD_MISMATCH_FORM: &str = const PASSWORD_MISMATCH_FORM: &str =
"username=test_user&displayname=Test+User&password=aaaa&pw_verify=bbbb"; "username=bad_user&displayname=Test+User&password=aaaa&pw_verify=bbbb";
const PASSWORD_SHORT_FORM: &str = const PASSWORD_SHORT_FORM: &str =
"username=test_user&displayname=Test+User&password=a&pw_verify=a"; "username=bad_user&displayname=Test+User&password=a&pw_verify=a";
const PASSWORD_LONG_FORM: &str = "username=test_user&displayname=Test+User&password=sphinx+of+black+qwartz+judge+my+vow+etc+etc+yadd+yadda&pw_verify=sphinx+of+black+qwartz+judge+my+vow+etc+etc+yadd+yadda"; const PASSWORD_LONG_FORM: &str = "username=bad_user&displayname=Test+User&password=sphinx+of+black+qwartz+judge+my+vow+etc+etc+yadd+yadda&pw_verify=sphinx+of+black+qwartz+judge+my+vow+etc+etc+yadd+yadda";
const USERNAME_SHORT_FORM: &str = const USERNAME_SHORT_FORM: &str =
"username=&displayname=Test+User&password=aaaa&pw_verify=aaaa"; "username=&displayname=Test+User&password=aaaa&pw_verify=aaaa";
const USERNAME_LONG_FORM: &str = const USERNAME_LONG_FORM: &str =
"username=test_user12345678901234567890&displayname=Test+User&password=aaaa&pw_verify=aaaa"; "username=bad_user12345678901234567890&displayname=Test+User&password=aaaa&pw_verify=aaaa";
const DISPLAYNAME_LONG_FORM: &str = "username=test_user&displayname=Since+time+immemorial%2C+display+names+have+been+subject+to+a+number+of+conventions%2C+restrictions%2C+usages%2C+and+even+incentives.+Have+we+finally+gone+too+far%3F+In+this+essay%2C+&password=aaaa&pw_verify=aaaa"; const DISPLAYNAME_LONG_FORM: &str = "username=test_user&displayname=Since+time+immemorial%2C+display+names+have+been+subject+to+a+number+of+conventions%2C+restrictions%2C+usages%2C+and+even+incentives.+Have+we+finally+gone+too+far%3F+In+this+essay%2C+&password=aaaa&pw_verify=aaaa";
#[tokio::test] #[test]
async fn password_mismatch() { fn password_mismatch() {
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_MISMATCH_FORM); let body = massage(PASSWORD_MISMATCH_FORM);
@ -303,17 +314,20 @@ mod test {
.await; .await;
// no user in db // no user in db
let user = User::try_get("test_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_err());
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::PasswordMismatch).to_string(); let expected = CreateUserError(CreateUserErrorKind::PasswordMismatch).to_string();
assert_eq!(&expected, body); assert_eq!(&expected, body);
});
} }
#[tokio::test] #[test]
async fn password_short() { fn password_short() {
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_SHORT_FORM); let body = massage(PASSWORD_SHORT_FORM);
@ -326,17 +340,20 @@ mod test {
.await; .await;
// no user in db // no user in db
let user = User::try_get("test_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_err());
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
assert_eq!(&expected, body); assert_eq!(&expected, body);
});
} }
#[tokio::test] #[test]
async fn password_long() { fn password_long() {
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(PASSWORD_LONG_FORM); let body = massage(PASSWORD_LONG_FORM);
@ -349,24 +366,29 @@ mod test {
.await; .await;
// no user in db // no user in db
let user = User::try_get("test_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_err());
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
assert_eq!(&expected, body); assert_eq!(&expected, body);
});
} }
#[tokio::test] #[test]
async fn multibyte_password_too_short() { fn multibyte_password_too_short() {
let pw = "🤡"; let pw = "🤡";
// min length is 4 // min length is 4
assert_eq!(pw.len(), 4); assert_eq!(pw.len(), 4);
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let form = let form = format!(
format!("username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}"); "username=test_user&displayname=Test+User&password={pw}&pw_verify={pw}"
);
let body = massage(&form); let body = massage(&form);
let resp = server let resp = server
@ -378,17 +400,20 @@ mod test {
.await; .await;
// no user in db // no user in db
let user = User::try_get("test_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_err());
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadPassword).to_string();
assert_eq!(&expected, body); assert_eq!(&expected, body);
});
} }
#[tokio::test] #[test]
async fn username_short() { fn username_short() {
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(USERNAME_SHORT_FORM); let body = massage(USERNAME_SHORT_FORM);
@ -401,17 +426,20 @@ mod test {
.await; .await;
// no user in db // no user in db
let user = User::try_get("test_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_err());
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string();
assert_eq!(&expected, body); assert_eq!(&expected, body);
});
} }
#[tokio::test] #[test]
async fn username_long() { fn username_long() {
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(USERNAME_LONG_FORM); let body = massage(USERNAME_LONG_FORM);
@ -424,17 +452,20 @@ mod test {
.await; .await;
// no user in db // no user in db
let user = User::try_get("test_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_err());
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadUsername).to_string();
assert_eq!(&expected, body); assert_eq!(&expected, body);
});
} }
#[tokio::test] #[test]
async fn username_duplicate() { fn username_duplicate() {
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(GOOD_FORM); let body = massage(GOOD_FORM);
@ -446,7 +477,7 @@ mod test {
.await; .await;
// get the new user from the db // get the new user from the db
let user = User::try_get("test_user", &pool).await; let user = User::try_get("good_user", &pool).await;
assert!(user.is_ok()); assert!(user.is_ok());
// now try again // now try again
@ -461,11 +492,14 @@ mod test {
let expected = CreateUserError(CreateUserErrorKind::AlreadyExists).to_string(); let expected = CreateUserError(CreateUserErrorKind::AlreadyExists).to_string();
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
assert_eq!(&expected, body); assert_eq!(&expected, body);
});
} }
#[tokio::test] #[test]
async fn displayname_long() { fn displayname_long() {
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let body = massage(DISPLAYNAME_LONG_FORM); let body = massage(DISPLAYNAME_LONG_FORM);
@ -478,23 +512,28 @@ mod test {
.await; .await;
// no user in db // no user in db
let user = User::try_get("test_user", &pool).await; let user = User::try_get("bad_user", &pool).await;
assert!(user.is_err()); assert!(user.is_err());
let body = std::str::from_utf8(resp.bytes()).unwrap(); let body = std::str::from_utf8(resp.bytes()).unwrap();
let expected = CreateUserError(CreateUserErrorKind::BadDisplayname).to_string(); let expected = CreateUserError(CreateUserErrorKind::BadDisplayname).to_string();
assert_eq!(&expected, body); assert_eq!(&expected, body);
});
} }
#[tokio::test] #[test]
async fn handle_signup_success() { fn handle_signup_success() {
let pool = get_db_pool(); let pool = get_db_pool();
let rt = Runtime::new().unwrap();
rt.block_on(async {
let server = server_with_pool(&pool).await; let server = server_with_pool(&pool).await;
let path = format!("/signup_success/nope"); let path = format!("/signup_success/nope");
let resp = server.get(&path).expect_failure().await; let resp = server.get(&path).expect_failure().await;
assert_eq!(resp.status_code(), StatusCode::SEE_OTHER); assert_eq!(resp.status_code(), StatusCode::SEE_OTHER);
});
} }
} }
} }

View file

@ -17,51 +17,19 @@ pub fn get_test_user() -> User {
} }
} }
pub fn server() -> TestServer {
let pool = crate::db::get_db_pool();
let secret = [0u8; 64];
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let user = get_test_user();
rt.block_on(async {
sqlx::query(crate::signup::CREATE_QUERY)
.bind(user.id)
.bind(&user.username)
.bind(&user.displayname)
.bind(&user.email)
.bind(&user.pwhash)
.execute(&pool)
.await
.unwrap();
let r = sqlx::query("select count(*) from users")
.fetch_one(&pool)
.await;
assert!(r.is_ok());
let app = crate::app(pool, &secret).await.into_make_service();
let config = TestServerConfig {
save_cookies: true,
..Default::default()
};
TestServer::new_with_config(app, config).unwrap()
})
}
pub async fn server_with_pool(pool: &SqlitePool) -> TestServer { pub async fn server_with_pool(pool: &SqlitePool) -> TestServer {
let secret = [0u8; 64]; let secret = [0u8; 64];
let r = sqlx::query("select count(*) from users") let user = get_test_user();
.fetch_one(pool)
.await;
assert!(r.is_ok());
let app = crate::app(pool.clone(), &secret).await.into_make_service(); insert_user(&user, pool).await;
let r: i32 = sqlx::query_scalar("select count(*) from users")
.fetch_one(pool)
.await
.unwrap_or_default();
assert!(r == 1);
let app = crate::app(pool.clone(), &secret).await;
let config = TestServerConfig { let config = TestServerConfig {
save_cookies: true, save_cookies: true,