diff --git a/Cargo.lock b/Cargo.lock index 507f70a..6acaf5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -158,6 +158,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "auto-future" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c1e7e457ea78e524f48639f551fd79703ac3f2237f5ecccdf4708f8a75ad373" + [[package]] name = "autocfg" version = "1.1.0" @@ -290,6 +296,23 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-test" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75a01b0885dcea3124d990b24fd5ad9d0735f77d53385161665a0d9fd08a03e6" +dependencies = [ + "anyhow", + "auto-future", + "axum", + "cookie", + "hyper", + "portpicker", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "base64" version = "0.13.1" @@ -1258,6 +1281,15 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "portpicker" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be97d76faf1bfab666e1375477b23fde79eccf0276e9b63b92a39d676a889ba9" +dependencies = [ + "rand", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -2345,6 +2377,7 @@ dependencies = [ "axum", "axum-login", "axum-macros", + "axum-test", "justerror", "password-hash", "rand_core", diff --git a/Cargo.toml b/Cargo.toml index 6d27745..800803f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,3 +26,6 @@ unicode-segmentation = "1" urlencoding = "2" async-session = "3" +[dev-dependencies] +axum-test = "9.0.0" + diff --git a/src/db.rs b/src/db.rs index 9f3c158..7358ff1 100644 --- a/src/db.rs +++ b/src/db.rs @@ -7,6 +7,7 @@ use axum_login::{ }; use session_store::SqliteSessionStore; use sqlx::{ + migrate::Migrator, sqlite::{SqliteConnectOptions, SqlitePoolOptions}, SqlitePool, }; @@ -15,30 +16,62 @@ use uuid::Uuid; use crate::User; const MAX_CONNS: u32 = 100; -const TIMEOUT: u64 = 5; +const MIN_CONNS: u32 = 10; +const TIMEOUT: u64 = 11; const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64); pub async fn get_pool() -> SqlitePool { let db_filename = { std::env::var("DATABASE_FILE").unwrap_or_else(|_| { - let home = - std::env::var("HOME").expect("Could not determine $HOME for finding db file"); - format!("{home}/.witch-watch.db") + #[cfg(not(test))] + { + let home = + std::env::var("HOME").expect("Could not determine $HOME for finding db file"); + format!("{home}/.witch-watch.db") + } + #[cfg(test)] + { + use rand_core::RngCore; + let mut rng = rand_core::OsRng; + let id = rng.next_u64(); + format!("file:testdb-{id}?mode=memory&cache=shared") + } }) }; + tracing::info!("Connecting to DB at {db_filename}"); + let conn_opts = SqliteConnectOptions::new() .foreign_keys(true) .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental) .filename(&db_filename) - .busy_timeout(Duration::from_secs(TIMEOUT)); + .busy_timeout(Duration::from_secs(TIMEOUT)) + .create_if_missing(true); - // setup connection pool - SqlitePoolOptions::new() + let pool = SqlitePoolOptions::new() .max_connections(MAX_CONNS) + .min_connections(MIN_CONNS) + .idle_timeout(Some(Duration::from_secs(10))) + .max_lifetime(Some(Duration::from_secs(3600))) .connect_with(conn_opts) .await - .expect("can't connect to database") + .expect("can't connect to database"); + + { + let mut m = Migrator::new(std::path::Path::new("./migrations")) + .await + .expect("Should be able to read the migration directory."); + + let m = m.set_locking(true); + + m.run(&pool) + .await + .expect("Should be able to run the migration."); + + tracing::info!("Ran migrations"); + } + + pool } pub async fn session_layer(pool: SqlitePool, secret: &[u8]) -> SessionLayer { @@ -71,6 +104,26 @@ pub async fn auth_layer( AuthLayer::new(store, secret) } +//-************************************************************************ +// Tests for `db` module. +//-************************************************************************ +#[cfg(test)] +mod tests { + + #[tokio::test] + async fn it_migrates_the_db() { + let db = super::get_pool().await; + let r = sqlx::query("select count(*) from witches") + .fetch_one(&db) + .await; + assert!(r.is_ok()); + } +} + +//-************************************************************************ +// End public interface. +//-************************************************************************ + //-************************************************************************ // Session store sub-module, not a public lib. //-************************************************************************ @@ -104,25 +157,6 @@ mod session_store { */ /// sqlx sqlite session store for async-sessions - /// - /// ```rust - /// use witch_watch::session_store::SqliteSessionStore; - /// use async_session::{Session, SessionStore, Result}; - /// use std::time::Duration; - /// - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await?; - /// - /// let mut session = Session::new(); - /// session.insert("key", vec![1,2,3]); - /// - /// let cookie_value = store.store_session(session).await?.unwrap(); - /// let session = store.load_session(cookie_value).await?.unwrap(); - /// assert_eq!(session.get::>("key").unwrap(), vec![1,2,3]); - /// # Ok(()) } - #[derive(Clone, Debug)] pub struct SqliteSessionStore { client: SqlitePool, @@ -134,18 +168,6 @@ mod session_store { /// sqlx::SqlitePool. the default table name for this session /// store will be "async_sessions". To override this, chain this /// with [`with_table_name`](crate::SqliteSessionStore::with_table_name). - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); - /// let store = SqliteSessionStore::from_client(pool) - /// .with_table_name("custom_table_name"); - /// store.migrate().await; - /// # Ok(()) } - /// ``` pub fn from_client(client: SqlitePool) -> Self { Self { client, @@ -163,16 +185,6 @@ mod session_store { /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or /// use /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await; - /// # Ok(()) } - /// ``` pub async fn new(database_url: &str) -> sqlx::Result { Ok(Self::from_client(SqlitePool::connect(database_url).await?)) } @@ -183,16 +195,6 @@ mod session_store { /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or /// use /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?; - /// store.migrate().await; - /// # Ok(()) } - /// ``` pub async fn new_with_table_name( database_url: &str, table_name: &str, @@ -202,26 +204,6 @@ mod session_store { /// Chainable method to add a custom table name. This will panic /// if the table name is not `[a-zA-Z0-9_-]+`. - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await? - /// .with_table_name("custom_name"); - /// store.migrate().await; - /// # Ok(()) } - /// ``` - /// - /// ```should_panic - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await? - /// .with_table_name("johnny (); drop users;"); - /// # Ok(()) } - /// ``` pub fn with_table_name(mut self, table_name: impl AsRef) -> Self { let table_name = table_name.as_ref(); if table_name.is_empty() @@ -244,19 +226,6 @@ mod session_store { /// store initialization. In the future, this may make /// exactly-once modifications to the schema of the session table /// on breaking releases. - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// assert!(store.count().await.is_err()); - /// store.migrate().await?; - /// store.store_session(Session::new()).await?; - /// store.migrate().await?; // calling it a second time is safe - /// assert_eq!(store.count().await?, 1); - /// # Ok(()) } - /// ``` pub async fn migrate(&self) -> sqlx::Result<()> { log::info!("migrating sessions on `{}`", self.table_name); @@ -288,21 +257,6 @@ mod session_store { /// Performs a one-time cleanup task that clears out stale /// (expired) sessions. You may want to call this from cron. - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session}; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await?; - /// let mut session = Session::new(); - /// session.set_expiry(Utc::now() - Duration::seconds(5)); - /// store.store_session(session).await?; - /// assert_eq!(store.count().await?, 1); - /// store.cleanup().await?; - /// assert_eq!(store.count().await?, 0); - /// # Ok(()) } - /// ``` pub async fn cleanup(&self) -> sqlx::Result<()> { let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name( @@ -320,20 +274,6 @@ mod session_store { /// retrieves the number of sessions currently stored, including /// expired sessions - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # use std::time::Duration; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await?; - /// assert_eq!(store.count().await?, 0); - /// store.store_session(Session::new()).await?; - /// assert_eq!(store.count().await?, 1); - /// # Ok(()) } - /// ``` pub async fn count(&self) -> sqlx::Result { let (count,) = sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%")) diff --git a/src/generic_handlers.rs b/src/generic_handlers.rs index 8791922..09e0cd1 100644 --- a/src/generic_handlers.rs +++ b/src/generic_handlers.rs @@ -3,7 +3,7 @@ use axum::response::{IntoResponse, Redirect}; use crate::{templates::Index, AuthContext}; pub async fn handle_slash_redir() -> impl IntoResponse { - Redirect::temporary("/") + Redirect::to("/") } pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse { @@ -17,3 +17,38 @@ pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse { user: auth.current_user, } } + +#[cfg(test)] +mod test { + use axum_test::TestServer; + + use crate::db; + + #[tokio::test] + async fn slash_is_ok() { + let pool = db::get_pool().await; + let secret = [0u8; 64]; + let app = crate::app(pool.clone(), &secret).await.into_make_service(); + + let server = TestServer::new(app).unwrap(); + + server.get("/").await.assert_status_ok(); + } + + #[tokio::test] + async fn not_found_is_303() { + let pool = db::get_pool().await; + let secret = [0u8; 64]; + let app = crate::app(pool, &secret).await.into_make_service(); + + let server = TestServer::new(app).unwrap(); + assert_eq!( + server + .get("/no-actual-route") + .expect_failure() + .await + .status_code(), + 303 + ); + } +} diff --git a/src/login.rs b/src/login.rs index 3530840..c89c00b 100644 --- a/src/login.rs +++ b/src/login.rs @@ -39,11 +39,12 @@ pub enum LoginErrorKind { impl IntoResponse for LoginError { fn into_response(self) -> Response { match self.0 { - LoginErrorKind::Unknown | LoginErrorKind::Internal => ( + LoginErrorKind::Internal => ( StatusCode::INTERNAL_SERVER_ERROR, "An unknown error occurred; you cursed, brah?", ) .into_response(), + LoginErrorKind::Unknown => (StatusCode::OK, "Not successful.").into_response(), _ => (StatusCode::OK, format!("{self}")).into_response(), } } @@ -79,7 +80,7 @@ pub async fn post_login( .await .map_err(|_| LoginErrorKind::Internal)?; - Ok(Redirect::temporary("/")) + Ok(Redirect::to("/")) } _ => Err(LoginErrorKind::BadPassword.into()), } @@ -99,3 +100,103 @@ pub async fn post_logout(mut auth: AuthContext) -> impl IntoResponse { } LogoutPost } + +#[cfg(test)] +mod test { + use std::time::Duration; + + use axum::body::Bytes; + use axum_test::TestServer; + + use crate::{ + db, + signup::create_user, + templates::{LoginGet, LogoutGet, LogoutPost}, + }; + + async fn tserver() -> TestServer { + let pool = db::get_pool().await; + let secret = [0u8; 64]; + + tokio::time::sleep(Duration::from_secs(2)).await; + + let _user = create_user( + "test_user", + &Some("Test User".to_string()), + &Some("mail@email".to_string()), + "aaaa".as_bytes(), + &pool, + ) + .await + .unwrap(); + + let r = sqlx::query("select count(*) from witches") + .fetch_one(&pool) + .await; + assert!(r.is_ok()); + + let app = crate::app(pool, &secret).await.into_make_service(); + + TestServer::new(app).unwrap() + } + + #[tokio::test] + async fn get_login() { + let s = tserver().await; + let resp = s.get("/login").await; + let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); + assert_eq!(body, LoginGet::default().to_string()); + } + + #[tokio::test] + async fn post_login_success() { + let s = tserver().await; + + let form = "username=test_user&password=aaaa".to_string(); + let bytes = form.as_bytes(); + let body = Bytes::copy_from_slice(bytes); + + let resp = s + .post("/login") + .expect_failure() + .content_type("application/x-www-form-urlencoded") + .bytes(body) + .await; + assert_eq!(resp.status_code(), 303); + } + + #[tokio::test] + async fn post_login_bad_user() { + let s = tserver().await; + + let form = "username=test_LOSER&password=aaaa".to_string(); + let bytes = form.as_bytes(); + let body = Bytes::copy_from_slice(bytes); + + let resp = s + .post("/login") + .expect_success() + .content_type("application/x-www-form-urlencoded") + .bytes(body) + .await; + assert_eq!(resp.status_code(), 200); + } + + #[tokio::test] + async fn get_logout() { + let s = tserver().await; + let resp = s.get("/logout").await; + let body = std::str::from_utf8(resp.bytes()).unwrap().to_string(); + assert_eq!(body, LogoutGet.to_string()); + } + + #[tokio::test] + async fn post_logout() { + let s = tserver().await; + let resp = s.post("/logout").await; + resp.assert_status_ok(); + let body = std::str::from_utf8(resp.bytes()).unwrap(); + let default = LogoutPost.to_string(); + assert_eq!(body, &default); + } +} diff --git a/src/signup.rs b/src/signup.rs index f32a47f..0fc696e 100644 --- a/src/signup.rs +++ b/src/signup.rs @@ -78,8 +78,11 @@ pub async fn post_create_user( let email = &signup.email; let password = &signup.password; let verify = &signup.pw_verify; - let username = username.trim(); + let username = urlencoding::decode(username) + .map_err(|_| CreateUserErrorKind::BadUsername)? + .to_string(); + let username = username.trim(); let name_len = username.graphemes(true).size_hint().1.unwrap(); // we are not ascii exclusivists around here if !(1..=20).contains(&name_len) { @@ -165,7 +168,7 @@ pub async fn handle_signup_success( // private fns //-************************************************************************ -async fn create_user( +pub(crate) async fn create_user( username: &str, displayname: &Option, email: &Option, @@ -181,14 +184,21 @@ async fn create_user( .to_string(); let id = Uuid::new_v4(); - let res = sqlx::query(CREATE_QUERY) + let query = sqlx::query(CREATE_QUERY) .bind(id) .bind(username) .bind(displayname) .bind(email) - .bind(&pwhash) - .execute(pool) - .await; + .bind(&pwhash); + + let res = { + let txn = pool.begin().await.expect("Could not beign transaction"); + let r = query.execute(pool).await; + txn.commit() + .await + .expect("Should be able to commit transaction"); + r + }; match res { Ok(_) => { diff --git a/src/templates.rs b/src/templates.rs index a40b25d..45ed268 100644 --- a/src/templates.rs +++ b/src/templates.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use crate::User; -#[derive(Debug, Default, Template, Deserialize, Serialize)] +#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)] #[template(path = "signup.html")] pub struct CreateUser { pub username: String, @@ -13,29 +13,29 @@ pub struct CreateUser { pub pw_verify: String, } -#[derive(Debug, Default, Template, Deserialize, Serialize)] +#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)] #[template(path = "login_post.html")] pub struct LoginPost { pub username: String, pub password: String, } -#[derive(Debug, Default, Template, Deserialize, Serialize)] +#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)] #[template(path = "login_get.html")] pub struct LoginGet { pub username: String, pub password: String, } -#[derive(Debug, Default, Template, Deserialize, Serialize)] +#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)] #[template(path = "logout_get.html")] pub struct LogoutGet; -#[derive(Debug, Default, Template, Deserialize, Serialize)] +#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)] #[template(path = "logout_post.html")] pub struct LogoutPost; -#[derive(Debug, Default, Template, Deserialize, Serialize)] +#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)] #[template(path = "index.html")] pub struct Index { pub user: Option,