From 453a126b952128ff8765830c8f8d7eeb11b188e2 Mon Sep 17 00:00:00 2001 From: Joe Ardent Date: Wed, 31 May 2023 15:58:03 -0700 Subject: [PATCH] Add generic handlers tests, remove broken doctests for session store. --- Cargo.lock | 33 +++++++++++ Cargo.toml | 3 + src/db.rs | 128 ++++------------------------------------ src/generic_handlers.rs | 37 +++++++++++- 4 files changed, 84 insertions(+), 117 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 507f70a..6acaf5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -158,6 +158,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "auto-future" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c1e7e457ea78e524f48639f551fd79703ac3f2237f5ecccdf4708f8a75ad373" + [[package]] name = "autocfg" version = "1.1.0" @@ -290,6 +296,23 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-test" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75a01b0885dcea3124d990b24fd5ad9d0735f77d53385161665a0d9fd08a03e6" +dependencies = [ + "anyhow", + "auto-future", + "axum", + "cookie", + "hyper", + "portpicker", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "base64" version = "0.13.1" @@ -1258,6 +1281,15 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "portpicker" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be97d76faf1bfab666e1375477b23fde79eccf0276e9b63b92a39d676a889ba9" +dependencies = [ + "rand", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -2345,6 +2377,7 @@ dependencies = [ "axum", "axum-login", "axum-macros", + "axum-test", "justerror", "password-hash", "rand_core", diff --git a/Cargo.toml b/Cargo.toml index 6d27745..800803f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,3 +26,6 @@ unicode-segmentation = "1" urlencoding = "2" async-session = "3" +[dev-dependencies] +axum-test = "9.0.0" + diff --git a/src/db.rs b/src/db.rs index 9f3c158..c43a77d 100644 --- a/src/db.rs +++ b/src/db.rs @@ -21,12 +21,21 @@ const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u pub async fn get_pool() -> SqlitePool { let db_filename = { std::env::var("DATABASE_FILE").unwrap_or_else(|_| { - let home = - std::env::var("HOME").expect("Could not determine $HOME for finding db file"); - format!("{home}/.witch-watch.db") + #[cfg(not(test))] + { + let home = + std::env::var("HOME").expect("Could not determine $HOME for finding db file"); + format!("{home}/.witch-watch.db") + } + #[cfg(test)] + { + ":memory:".to_string() + } }) }; + dbg!(&db_filename); + let conn_opts = SqliteConnectOptions::new() .foreign_keys(true) .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental) @@ -104,25 +113,6 @@ mod session_store { */ /// sqlx sqlite session store for async-sessions - /// - /// ```rust - /// use witch_watch::session_store::SqliteSessionStore; - /// use async_session::{Session, SessionStore, Result}; - /// use std::time::Duration; - /// - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await?; - /// - /// let mut session = Session::new(); - /// session.insert("key", vec![1,2,3]); - /// - /// let cookie_value = store.store_session(session).await?.unwrap(); - /// let session = store.load_session(cookie_value).await?.unwrap(); - /// assert_eq!(session.get::>("key").unwrap(), vec![1,2,3]); - /// # Ok(()) } - #[derive(Clone, Debug)] pub struct SqliteSessionStore { client: SqlitePool, @@ -134,18 +124,6 @@ mod session_store { /// sqlx::SqlitePool. the default table name for this session /// store will be "async_sessions". To override this, chain this /// with [`with_table_name`](crate::SqliteSessionStore::with_table_name). - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); - /// let store = SqliteSessionStore::from_client(pool) - /// .with_table_name("custom_table_name"); - /// store.migrate().await; - /// # Ok(()) } - /// ``` pub fn from_client(client: SqlitePool) -> Self { Self { client, @@ -163,16 +141,6 @@ mod session_store { /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or /// use /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await; - /// # Ok(()) } - /// ``` pub async fn new(database_url: &str) -> sqlx::Result { Ok(Self::from_client(SqlitePool::connect(database_url).await?)) } @@ -183,16 +151,6 @@ mod session_store { /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or /// use /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?; - /// store.migrate().await; - /// # Ok(()) } - /// ``` pub async fn new_with_table_name( database_url: &str, table_name: &str, @@ -202,26 +160,6 @@ mod session_store { /// Chainable method to add a custom table name. This will panic /// if the table name is not `[a-zA-Z0-9_-]+`. - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await? - /// .with_table_name("custom_name"); - /// store.migrate().await; - /// # Ok(()) } - /// ``` - /// - /// ```should_panic - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await? - /// .with_table_name("johnny (); drop users;"); - /// # Ok(()) } - /// ``` pub fn with_table_name(mut self, table_name: impl AsRef) -> Self { let table_name = table_name.as_ref(); if table_name.is_empty() @@ -244,19 +182,6 @@ mod session_store { /// store initialization. In the future, this may make /// exactly-once modifications to the schema of the session table /// on breaking releases. - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// assert!(store.count().await.is_err()); - /// store.migrate().await?; - /// store.store_session(Session::new()).await?; - /// store.migrate().await?; // calling it a second time is safe - /// assert_eq!(store.count().await?, 1); - /// # Ok(()) } - /// ``` pub async fn migrate(&self) -> sqlx::Result<()> { log::info!("migrating sessions on `{}`", self.table_name); @@ -288,21 +213,6 @@ mod session_store { /// Performs a one-time cleanup task that clears out stale /// (expired) sessions. You may want to call this from cron. - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session}; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await?; - /// let mut session = Session::new(); - /// session.set_expiry(Utc::now() - Duration::seconds(5)); - /// store.store_session(session).await?; - /// assert_eq!(store.count().await?, 1); - /// store.cleanup().await?; - /// assert_eq!(store.count().await?, 0); - /// # Ok(()) } - /// ``` pub async fn cleanup(&self) -> sqlx::Result<()> { let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name( @@ -320,20 +230,6 @@ mod session_store { /// retrieves the number of sessions currently stored, including /// expired sessions - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # use std::time::Duration; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await?; - /// assert_eq!(store.count().await?, 0); - /// store.store_session(Session::new()).await?; - /// assert_eq!(store.count().await?, 1); - /// # Ok(()) } - /// ``` pub async fn count(&self) -> sqlx::Result { let (count,) = sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%")) diff --git a/src/generic_handlers.rs b/src/generic_handlers.rs index 8791922..5d2c89c 100644 --- a/src/generic_handlers.rs +++ b/src/generic_handlers.rs @@ -3,7 +3,7 @@ use axum::response::{IntoResponse, Redirect}; use crate::{templates::Index, AuthContext}; pub async fn handle_slash_redir() -> impl IntoResponse { - Redirect::temporary("/") + Redirect::to("/") } pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse { @@ -17,3 +17,38 @@ pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse { user: auth.current_user, } } + +#[cfg(test)] +mod tests { + use axum_test::TestServer; + + use crate::db; + + #[tokio::test] + async fn slash_is_ok() { + let pool = db::get_pool().await; + let secret = [0u8; 64]; + let app = crate::app(pool.clone(), &secret).await.into_make_service(); + + let server = TestServer::new(app).unwrap(); + + server.get("/").await.assert_status_ok(); + } + + #[tokio::test] + async fn not_found_is_303() { + let pool = db::get_pool().await; + let secret = [0u8; 64]; + let app = crate::app(pool, &secret).await.into_make_service(); + + let server = TestServer::new(app).unwrap(); + assert_eq!( + server + .get("/no-actual-route") + .expect_failure() + .await + .status_code(), + 303 + ); + } +}