diff --git a/src/db.rs b/src/db.rs index e3ccc45..9f3c158 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,12 +1,22 @@ use std::time::Duration; +use async_session::SessionStore; +use axum_login::{ + axum_sessions::{PersistencePolicy, SessionLayer}, + AuthLayer, SqliteStore, SqlxStore, +}; +use session_store::SqliteSessionStore; use sqlx::{ sqlite::{SqliteConnectOptions, SqlitePoolOptions}, SqlitePool, }; +use uuid::Uuid; + +use crate::User; const MAX_CONNS: u32 = 100; const TIMEOUT: u64 = 5; +const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64); pub async fn get_pool() -> SqlitePool { let db_filename = { @@ -30,3 +40,553 @@ pub async fn get_pool() -> SqlitePool { .await .expect("can't connect to database") } + +pub async fn session_layer(pool: SqlitePool, secret: &[u8]) -> SessionLayer { + let store = session_store::SqliteSessionStore::from_client(pool); + store + .migrate() + .await + .expect("Calling `migrate()` should be reliable, is the DB gone?"); + + // since the secret is new every time the server starts, old sessions won't be + // valid anymore; if there were ever more than one service host or there were + // managed secrets, this would need to go away. + store + .clear_store() + .await + .unwrap_or_else(|e| tracing::error!("Could not delete old sessions; got error: {e}")); + + SessionLayer::new(store, secret) + .with_secure(true) + .with_session_ttl(Some(SESSION_TTL)) + .with_persistence_policy(PersistencePolicy::ExistingOnly) +} + +pub async fn auth_layer( + pool: SqlitePool, + secret: &[u8], +) -> AuthLayer, Uuid, User> { + const QUERY: &str = "select * from witches where id = $1"; + let store = SqliteStore::::new(pool).with_query(QUERY); + AuthLayer::new(store, secret) +} + +//-************************************************************************ +// Session store sub-module, not a public lib. +//-************************************************************************ +mod session_store { + use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session}; + use sqlx::{pool::PoolConnection, Sqlite}; + + use super::*; + + // NOTE! This code was straight stolen from + // https://github.com/jbr/async-sqlx-session/blob/30d00bed44ab2034082698f098eba48b21600f36/src/sqlite.rs + // and used under the terms of the MIT license: + + /* + Copyright 2022 Jacob Rothstein + + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and + associated documentation files (the “Software”), to deal in the Software without restriction, + including without limitation the rights to use, copy, modify, merge, publish, distribute, + sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all copies or substantial + portions of the Software. + + THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT + NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES + OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + + /// sqlx sqlite session store for async-sessions + /// + /// ```rust + /// use witch_watch::session_store::SqliteSessionStore; + /// use async_session::{Session, SessionStore, Result}; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() -> Result { + /// let store = SqliteSessionStore::new("sqlite::memory:").await?; + /// store.migrate().await?; + /// + /// let mut session = Session::new(); + /// session.insert("key", vec![1,2,3]); + /// + /// let cookie_value = store.store_session(session).await?.unwrap(); + /// let session = store.load_session(cookie_value).await?.unwrap(); + /// assert_eq!(session.get::>("key").unwrap(), vec![1,2,3]); + /// # Ok(()) } + + #[derive(Clone, Debug)] + pub struct SqliteSessionStore { + client: SqlitePool, + table_name: String, + } + + impl SqliteSessionStore { + /// constructs a new SqliteSessionStore from an existing + /// sqlx::SqlitePool. the default table name for this session + /// store will be "async_sessions". To override this, chain this + /// with [`with_table_name`](crate::SqliteSessionStore::with_table_name). + /// + /// ```rust + /// # use witch_watch::session_store::SqliteSessionStore; + /// # use async_session::Result; + /// # #[tokio::main] + /// # async fn main() -> Result { + /// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); + /// let store = SqliteSessionStore::from_client(pool) + /// .with_table_name("custom_table_name"); + /// store.migrate().await; + /// # Ok(()) } + /// ``` + pub fn from_client(client: SqlitePool) -> Self { + Self { + client, + table_name: "async_sessions".into(), + } + } + + /// Constructs a new SqliteSessionStore from a sqlite: database url. + /// note that this documentation uses the special `:memory:` + /// sqlite database for convenient testing, but a real + /// application would use a path like + /// `sqlite:///path/to/database.db`. The default table name for + /// this session store will be "async_sessions". To + /// override this, either chain with + /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or + /// use + /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) + /// + /// ```rust + /// # use witch_watch::session_store::SqliteSessionStore; + /// # use async_session::Result; + /// # #[tokio::main] + /// # async fn main() -> Result { + /// let store = SqliteSessionStore::new("sqlite::memory:").await?; + /// store.migrate().await; + /// # Ok(()) } + /// ``` + pub async fn new(database_url: &str) -> sqlx::Result { + Ok(Self::from_client(SqlitePool::connect(database_url).await?)) + } + + /// constructs a new SqliteSessionStore from a sqlite: database url. the + /// default table name for this session store will be + /// "async_sessions". To override this, either chain with + /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or + /// use + /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) + /// + /// ```rust + /// # use witch_watch::session_store::SqliteSessionStore; + /// # use async_session::Result; + /// # #[tokio::main] + /// # async fn main() -> Result { + /// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?; + /// store.migrate().await; + /// # Ok(()) } + /// ``` + pub async fn new_with_table_name( + database_url: &str, + table_name: &str, + ) -> sqlx::Result { + Ok(Self::new(database_url).await?.with_table_name(table_name)) + } + + /// Chainable method to add a custom table name. This will panic + /// if the table name is not `[a-zA-Z0-9_-]+`. + /// ```rust + /// # use witch_watch::session_store::SqliteSessionStore; + /// # use async_session::Result; + /// # #[tokio::main] + /// # async fn main() -> Result { + /// let store = SqliteSessionStore::new("sqlite::memory:").await? + /// .with_table_name("custom_name"); + /// store.migrate().await; + /// # Ok(()) } + /// ``` + /// + /// ```should_panic + /// # use witch_watch::session_store::SqliteSessionStore; + /// # use async_session::Result; + /// # #[tokio::main] + /// # async fn main() -> Result { + /// let store = SqliteSessionStore::new("sqlite::memory:").await? + /// .with_table_name("johnny (); drop users;"); + /// # Ok(()) } + /// ``` + pub fn with_table_name(mut self, table_name: impl AsRef) -> Self { + let table_name = table_name.as_ref(); + if table_name.is_empty() + || !table_name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') + { + panic!( + "table name must be [a-zA-Z0-9_-]+, but {} was not", + table_name + ); + } + + self.table_name = table_name.to_owned(); + self + } + + /// Creates a session table if it does not already exist. If it + /// does, this will noop, making it safe to call repeatedly on + /// store initialization. In the future, this may make + /// exactly-once modifications to the schema of the session table + /// on breaking releases. + /// ```rust + /// # use witch_watch::session_store::SqliteSessionStore; + /// # use async_session::{Result, SessionStore, Session}; + /// # #[tokio::main] + /// # async fn main() -> Result { + /// let store = SqliteSessionStore::new("sqlite::memory:").await?; + /// assert!(store.count().await.is_err()); + /// store.migrate().await?; + /// store.store_session(Session::new()).await?; + /// store.migrate().await?; // calling it a second time is safe + /// assert_eq!(store.count().await?, 1); + /// # Ok(()) } + /// ``` + pub async fn migrate(&self) -> sqlx::Result<()> { + log::info!("migrating sessions on `{}`", self.table_name); + + let mut conn = self.client.acquire().await?; + sqlx::query(&self.substitute_table_name( + r#" + CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% ( + id TEXT PRIMARY KEY NOT NULL, + expires INTEGER NULL, + session TEXT NOT NULL + ) + "#, + )) + .execute(&mut conn) + .await?; + Ok(()) + } + + // private utility function because sqlite does not support + // parametrized table names + fn substitute_table_name(&self, query: &str) -> String { + query.replace("%%TABLE_NAME%%", &self.table_name) + } + + /// retrieve a connection from the pool + async fn connection(&self) -> sqlx::Result> { + self.client.acquire().await + } + + /// Performs a one-time cleanup task that clears out stale + /// (expired) sessions. You may want to call this from cron. + /// ```rust + /// # use witch_watch::session_store::SqliteSessionStore; + /// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session}; + /// # #[tokio::main] + /// # async fn main() -> Result { + /// let store = SqliteSessionStore::new("sqlite::memory:").await?; + /// store.migrate().await?; + /// let mut session = Session::new(); + /// session.set_expiry(Utc::now() - Duration::seconds(5)); + /// store.store_session(session).await?; + /// assert_eq!(store.count().await?, 1); + /// store.cleanup().await?; + /// assert_eq!(store.count().await?, 0); + /// # Ok(()) } + /// ``` + pub async fn cleanup(&self) -> sqlx::Result<()> { + let mut connection = self.connection().await?; + sqlx::query(&self.substitute_table_name( + r#" + DELETE FROM %%TABLE_NAME%% + WHERE expires < ? + "#, + )) + .bind(Utc::now().timestamp()) + .execute(&mut connection) + .await?; + + Ok(()) + } + + /// retrieves the number of sessions currently stored, including + /// expired sessions + /// + /// ```rust + /// # use witch_watch::session_store::SqliteSessionStore; + /// # use async_session::{Result, SessionStore, Session}; + /// # use std::time::Duration; + /// # #[tokio::main] + /// # async fn main() -> Result { + /// let store = SqliteSessionStore::new("sqlite::memory:").await?; + /// store.migrate().await?; + /// assert_eq!(store.count().await?, 0); + /// store.store_session(Session::new()).await?; + /// assert_eq!(store.count().await?, 1); + /// # Ok(()) } + /// ``` + pub async fn count(&self) -> sqlx::Result { + let (count,) = + sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%")) + .fetch_one(&mut self.connection().await?) + .await?; + + Ok(count) + } + } + + #[async_trait] + impl SessionStore for SqliteSessionStore { + async fn load_session(&self, cookie_value: String) -> Result> { + let id = Session::id_from_cookie_value(&cookie_value)?; + let mut connection = self.connection().await?; + + let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name( + r#" + SELECT session FROM %%TABLE_NAME%% + WHERE id = ? AND (expires IS NULL OR expires > ?) + "#, + )) + .bind(&id) + .bind(Utc::now().timestamp()) + .fetch_optional(&mut connection) + .await?; + + Ok(result + .map(|(session,)| serde_json::from_str(&session)) + .transpose()?) + } + + async fn store_session(&self, session: Session) -> Result> { + let id = session.id(); + let string = serde_json::to_string(&session)?; + let mut connection = self.connection().await?; + + sqlx::query(&self.substitute_table_name( + r#" + INSERT INTO %%TABLE_NAME%% + (id, session, expires) VALUES (?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + expires = excluded.expires, + session = excluded.session + "#, + )) + .bind(id) + .bind(&string) + .bind(session.expiry().map(|expiry| expiry.timestamp())) + .execute(&mut connection) + .await?; + + Ok(session.into_cookie_value()) + } + + async fn destroy_session(&self, session: Session) -> Result { + let id = session.id(); + let mut connection = self.connection().await?; + sqlx::query(&self.substitute_table_name( + r#" + DELETE FROM %%TABLE_NAME%% WHERE id = ? + "#, + )) + .bind(id) + .execute(&mut connection) + .await?; + + Ok(()) + } + + async fn clear_store(&self) -> Result { + let mut connection = self.connection().await?; + sqlx::query(&self.substitute_table_name( + r#" + DELETE FROM %%TABLE_NAME%% + "#, + )) + .execute(&mut connection) + .await?; + + Ok(()) + } + } + + #[cfg(test)] + mod tests { + use std::time::Duration; + + use super::*; + + async fn test_store() -> SqliteSessionStore { + let store = SqliteSessionStore::new("sqlite::memory:") + .await + .expect("building a sqlite :memory: SqliteSessionStore"); + store + .migrate() + .await + .expect("migrating a brand new :memory: SqliteSessionStore"); + store + } + + #[tokio::test] + async fn creating_a_new_session_with_no_expiry() -> Result { + let store = test_store().await; + let mut session = Session::new(); + session.insert("key", "value")?; + let cloned = session.clone(); + let cookie_value = store.store_session(session).await?.unwrap(); + + let (id, expires, serialized, count): (String, Option, String, i64) = + sqlx::query_as("select id, expires, session, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; + + assert_eq!(1, count); + assert_eq!(id, cloned.id()); + assert_eq!(expires, None); + + let deserialized_session: Session = serde_json::from_str(&serialized)?; + assert_eq!(cloned.id(), deserialized_session.id()); + assert_eq!("value", &deserialized_session.get::("key").unwrap()); + + let loaded_session = store.load_session(cookie_value).await?.unwrap(); + assert_eq!(cloned.id(), loaded_session.id()); + assert_eq!("value", &loaded_session.get::("key").unwrap()); + + assert!(!loaded_session.is_expired()); + Ok(()) + } + + #[tokio::test] + async fn updating_a_session() -> Result { + let store = test_store().await; + let mut session = Session::new(); + let original_id = session.id().to_owned(); + + session.insert("key", "value")?; + let cookie_value = store.store_session(session).await?.unwrap(); + + let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); + session.insert("key", "other value")?; + assert_eq!(None, store.store_session(session).await?); + + let session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(session.get::("key").unwrap(), "other value"); + + let (id, count): (String, i64) = + sqlx::query_as("select id, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; + + assert_eq!(1, count); + assert_eq!(original_id, id); + + Ok(()) + } + + #[tokio::test] + async fn updating_a_session_extending_expiry() -> Result { + let store = test_store().await; + let mut session = Session::new(); + session.expire_in(Duration::from_secs(10)); + let original_id = session.id().to_owned(); + let original_expires = session.expiry().unwrap().clone(); + let cookie_value = store.store_session(session).await?.unwrap(); + + let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(session.expiry().unwrap(), &original_expires); + session.expire_in(Duration::from_secs(20)); + let new_expires = session.expiry().unwrap().clone(); + store.store_session(session).await?; + + let session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(session.expiry().unwrap(), &new_expires); + + let (id, expires, count): (String, i64, i64) = + sqlx::query_as("select id, expires, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; + + assert_eq!(1, count); + assert_eq!(expires, new_expires.timestamp()); + assert_eq!(original_id, id); + + Ok(()) + } + + #[tokio::test] + async fn creating_a_new_session_with_expiry() -> Result { + let store = test_store().await; + let mut session = Session::new(); + session.expire_in(Duration::from_secs(1)); + session.insert("key", "value")?; + let cloned = session.clone(); + + let cookie_value = store.store_session(session).await?.unwrap(); + + let (id, expires, serialized, count): (String, Option, String, i64) = + sqlx::query_as("select id, expires, session, count(*) from async_sessions") + .fetch_one(&mut store.connection().await?) + .await?; + + assert_eq!(1, count); + assert_eq!(id, cloned.id()); + assert!(expires.unwrap() > Utc::now().timestamp()); + + let deserialized_session: Session = serde_json::from_str(&serialized)?; + assert_eq!(cloned.id(), deserialized_session.id()); + assert_eq!("value", &deserialized_session.get::("key").unwrap()); + + let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert_eq!(cloned.id(), loaded_session.id()); + assert_eq!("value", &loaded_session.get::("key").unwrap()); + + assert!(!loaded_session.is_expired()); + + tokio::time::sleep(Duration::from_secs(1)).await; + assert_eq!(None, store.load_session(cookie_value).await?); + + Ok(()) + } + + #[tokio::test] + async fn destroying_a_single_session() -> Result { + let store = test_store().await; + for _ in 0..3i8 { + store.store_session(Session::new()).await?; + } + + let cookie = store.store_session(Session::new()).await?.unwrap(); + assert_eq!(4, store.count().await?); + let session = store.load_session(cookie.clone()).await?.unwrap(); + store.destroy_session(session.clone()).await.unwrap(); + assert_eq!(None, store.load_session(cookie).await?); + assert_eq!(3, store.count().await?); + + // // attempting to destroy the session again is not an error + // assert!(store.destroy_session(session).await.is_ok()); + Ok(()) + } + + #[tokio::test] + async fn clearing_the_whole_store() -> Result { + let store = test_store().await; + for _ in 0..3i8 { + store.store_session(Session::new()).await?; + } + + assert_eq!(3, store.count().await?); + store.clear_store().await.unwrap(); + assert_eq!(0, store.count().await?); + + Ok(()) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index bc43291..eca0693 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,6 @@ use uuid::Uuid; pub mod db; pub mod generic_handlers; pub mod login; -pub mod session_store; pub mod signup; pub(crate) mod templates; pub mod users; diff --git a/src/login.rs b/src/login.rs index 225b216..3530840 100644 --- a/src/login.rs +++ b/src/login.rs @@ -20,8 +20,6 @@ use crate::{ // Constants //-************************************************************************ -const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1"; - //-************************************************************************ // Login error and success types //-************************************************************************ @@ -68,7 +66,7 @@ pub async fn post_login( let pw = form_decode(&login.password, LoginErrorKind::BadPassword)?; let pw = pw.trim(); - let user = User::get(username, &pool) + let user = User::try_get(username, &pool) .await .map_err(|_| LoginErrorKind::Unknown)?; @@ -81,13 +79,6 @@ pub async fn post_login( .await .map_err(|_| LoginErrorKind::Internal)?; - // update last_seen; maybe this is ok to fail? - sqlx::query(LAST_SEEN_QUERY) - .bind(user.id) - .execute(&pool) - .await - .map_err(|_| LoginErrorKind::Internal)?; - Ok(Redirect::temporary("/")) } _ => Err(LoginErrorKind::BadPassword.into()), diff --git a/src/main.rs b/src/main.rs index 35ab102..3ee3e4b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,19 +1,14 @@ -use std::{net::SocketAddr, time::Duration}; +use std::net::SocketAddr; -use axum::{routing::get, Router}; -use axum_login::{ - axum_sessions::{PersistencePolicy, SessionLayer}, - AuthLayer, SqliteStore, -}; +use axum::{middleware, routing::get, Router}; use rand_core::{OsRng, RngCore}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use witch_watch::{ db, generic_handlers::{handle_slash, handle_slash_redir}, login::{get_login, get_logout, post_login, post_logout}, - session_store::SqliteSessionStore, signup::{get_create_user, handle_signup_success, post_create_user}, - User, + users, }; #[tokio::main] @@ -35,20 +30,8 @@ async fn main() { bytes }; - let session_layer = { - let store = SqliteSessionStore::from_client(pool.clone()); - store.migrate().await.expect("Could not migrate session DB"); - SessionLayer::new(store, &secret) - .with_secure(true) - .with_persistence_policy(PersistencePolicy::ExistingOnly) - .with_session_ttl(Some(Duration::from_secs(3600 * 24 * 366))) - }; - - let auth_layer = { - const QUERY: &str = "select * from witches where id = $1"; - let store = SqliteStore::::new(pool.clone()).with_query(QUERY); - AuthLayer::new(store, &secret) - }; + let session_layer = db::session_layer(pool.clone(), &secret).await; + let auth_layer = db::auth_layer(pool.clone(), &secret).await; let app = Router::new() .route("/", get(handle_slash).post(handle_slash)) @@ -60,12 +43,17 @@ async fn main() { .route("/login", get(get_login).post(post_login)) .route("/logout", get(get_logout).post(post_logout)) .fallback(handle_slash_redir) + .route_layer(middleware::from_fn_with_state( + pool.clone(), + users::handle_update_last_seen, + )) .layer(auth_layer) .layer(session_layer) .with_state(pool); + let addr = ([127, 0, 0, 1], 3000); tracing::debug!("binding to 0.0.0.0:3000"); - axum::Server::bind(&SocketAddr::from(([0, 0, 0, 0], 3000))) + axum::Server::bind(&SocketAddr::from(addr)) .serve(app.into_make_service()) .await .unwrap(); diff --git a/src/session_store.rs b/src/session_store.rs deleted file mode 100644 index b97768b..0000000 --- a/src/session_store.rs +++ /dev/null @@ -1,507 +0,0 @@ -use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore}; -use sqlx::{pool::PoolConnection, sqlite::SqlitePool, Sqlite}; - -// NOTE! This code was straight stolen from -// https://github.com/jbr/async-sqlx-session/blob/30d00bed44ab2034082698f098eba48b21600f36/src/sqlite.rs -// and used under the terms of the MIT license: - -/* -Copyright 2022 Jacob Rothstein - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and -associated documentation files (the “Software”), to deal in the Software without restriction, -including without limitation the rights to use, copy, modify, merge, publish, distribute, -sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial -portions of the Software. - -THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT -NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES -OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -*/ - -/// sqlx sqlite session store for async-sessions -/// -/// ```rust -/// use witch_watch::session_store::SqliteSessionStore; -/// use async_session::{Session, SessionStore, Result}; -/// use std::time::Duration; -/// -/// # #[tokio::main] -/// # async fn main() -> Result { -/// let store = SqliteSessionStore::new("sqlite::memory:").await?; -/// store.migrate().await?; -/// -/// let mut session = Session::new(); -/// session.insert("key", vec![1,2,3]); -/// -/// let cookie_value = store.store_session(session).await?.unwrap(); -/// let session = store.load_session(cookie_value).await?.unwrap(); -/// assert_eq!(session.get::>("key").unwrap(), vec![1,2,3]); -/// # Ok(()) } - -#[derive(Clone, Debug)] -pub struct SqliteSessionStore { - client: SqlitePool, - table_name: String, -} - -impl SqliteSessionStore { - /// constructs a new SqliteSessionStore from an existing - /// sqlx::SqlitePool. the default table name for this session - /// store will be "async_sessions". To override this, chain this - /// with [`with_table_name`](crate::SqliteSessionStore::with_table_name). - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); - /// let store = SqliteSessionStore::from_client(pool) - /// .with_table_name("custom_table_name"); - /// store.migrate().await; - /// # Ok(()) } - /// ``` - pub fn from_client(client: SqlitePool) -> Self { - Self { - client, - table_name: "async_sessions".into(), - } - } - - /// Constructs a new SqliteSessionStore from a sqlite: database url. note - /// that this documentation uses the special `:memory:` sqlite - /// database for convenient testing, but a real application would - /// use a path like `sqlite:///path/to/database.db`. The default - /// table name for this session store will be "async_sessions". To - /// override this, either chain with - /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or - /// use - /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await; - /// # Ok(()) } - /// ``` - pub async fn new(database_url: &str) -> sqlx::Result { - Ok(Self::from_client(SqlitePool::connect(database_url).await?)) - } - - /// constructs a new SqliteSessionStore from a sqlite: database url. the - /// default table name for this session store will be - /// "async_sessions". To override this, either chain with - /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or - /// use - /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?; - /// store.migrate().await; - /// # Ok(()) } - /// ``` - pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result { - Ok(Self::new(database_url).await?.with_table_name(table_name)) - } - - /// Chainable method to add a custom table name. This will panic - /// if the table name is not `[a-zA-Z0-9_-]+`. - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await? - /// .with_table_name("custom_name"); - /// store.migrate().await; - /// # Ok(()) } - /// ``` - /// - /// ```should_panic - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::Result; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await? - /// .with_table_name("johnny (); drop users;"); - /// # Ok(()) } - /// ``` - pub fn with_table_name(mut self, table_name: impl AsRef) -> Self { - let table_name = table_name.as_ref(); - if table_name.is_empty() - || !table_name - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') - { - panic!( - "table name must be [a-zA-Z0-9_-]+, but {} was not", - table_name - ); - } - - self.table_name = table_name.to_owned(); - self - } - - /// Creates a session table if it does not already exist. If it - /// does, this will noop, making it safe to call repeatedly on - /// store initialization. In the future, this may make - /// exactly-once modifications to the schema of the session table - /// on breaking releases. - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// assert!(store.count().await.is_err()); - /// store.migrate().await?; - /// store.store_session(Session::new()).await?; - /// store.migrate().await?; // calling it a second time is safe - /// assert_eq!(store.count().await?, 1); - /// # Ok(()) } - /// ``` - pub async fn migrate(&self) -> sqlx::Result<()> { - log::info!("migrating sessions on `{}`", self.table_name); - - let mut conn = self.client.acquire().await?; - sqlx::query(&self.substitute_table_name( - r#" - CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% ( - id TEXT PRIMARY KEY NOT NULL, - expires INTEGER NULL, - session TEXT NOT NULL - ) - "#, - )) - .execute(&mut conn) - .await?; - Ok(()) - } - - // private utility function because sqlite does not support - // parametrized table names - fn substitute_table_name(&self, query: &str) -> String { - query.replace("%%TABLE_NAME%%", &self.table_name) - } - - /// retrieve a connection from the pool - async fn connection(&self) -> sqlx::Result> { - self.client.acquire().await - } - - /// Performs a one-time cleanup task that clears out stale - /// (expired) sessions. You may want to call this from cron. - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session}; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await?; - /// let mut session = Session::new(); - /// session.set_expiry(Utc::now() - Duration::seconds(5)); - /// store.store_session(session).await?; - /// assert_eq!(store.count().await?, 1); - /// store.cleanup().await?; - /// assert_eq!(store.count().await?, 0); - /// # Ok(()) } - /// ``` - pub async fn cleanup(&self) -> sqlx::Result<()> { - let mut connection = self.connection().await?; - sqlx::query(&self.substitute_table_name( - r#" - DELETE FROM %%TABLE_NAME%% - WHERE expires < ? - "#, - )) - .bind(Utc::now().timestamp()) - .execute(&mut connection) - .await?; - - Ok(()) - } - - /// retrieves the number of sessions currently stored, including - /// expired sessions - /// - /// ```rust - /// # use witch_watch::session_store::SqliteSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # use std::time::Duration; - /// # #[tokio::main] - /// # async fn main() -> Result { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await?; - /// assert_eq!(store.count().await?, 0); - /// store.store_session(Session::new()).await?; - /// assert_eq!(store.count().await?, 1); - /// # Ok(()) } - /// ``` - pub async fn count(&self) -> sqlx::Result { - let (count,) = - sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%")) - .fetch_one(&mut self.connection().await?) - .await?; - - Ok(count) - } -} - -#[async_trait] -impl SessionStore for SqliteSessionStore { - async fn load_session(&self, cookie_value: String) -> Result> { - let id = Session::id_from_cookie_value(&cookie_value)?; - let mut connection = self.connection().await?; - - let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name( - r#" - SELECT session FROM %%TABLE_NAME%% - WHERE id = ? AND (expires IS NULL OR expires > ?) - "#, - )) - .bind(&id) - .bind(Utc::now().timestamp()) - .fetch_optional(&mut connection) - .await?; - - Ok(result - .map(|(session,)| serde_json::from_str(&session)) - .transpose()?) - } - - async fn store_session(&self, session: Session) -> Result> { - let id = session.id(); - let string = serde_json::to_string(&session)?; - let mut connection = self.connection().await?; - - sqlx::query(&self.substitute_table_name( - r#" - INSERT INTO %%TABLE_NAME%% - (id, session, expires) VALUES (?, ?, ?) - ON CONFLICT(id) DO UPDATE SET - expires = excluded.expires, - session = excluded.session - "#, - )) - .bind(id) - .bind(&string) - .bind(session.expiry().map(|expiry| expiry.timestamp())) - .execute(&mut connection) - .await?; - - Ok(session.into_cookie_value()) - } - - async fn destroy_session(&self, session: Session) -> Result { - let id = session.id(); - let mut connection = self.connection().await?; - sqlx::query(&self.substitute_table_name( - r#" - DELETE FROM %%TABLE_NAME%% WHERE id = ? - "#, - )) - .bind(id) - .execute(&mut connection) - .await?; - - Ok(()) - } - - async fn clear_store(&self) -> Result { - let mut connection = self.connection().await?; - sqlx::query(&self.substitute_table_name( - r#" - DELETE FROM %%TABLE_NAME%% - "#, - )) - .execute(&mut connection) - .await?; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use std::time::Duration; - - use super::*; - - async fn test_store() -> SqliteSessionStore { - let store = SqliteSessionStore::new("sqlite::memory:") - .await - .expect("building a sqlite :memory: SqliteSessionStore"); - store - .migrate() - .await - .expect("migrating a brand new :memory: SqliteSessionStore"); - store - } - - #[tokio::test] - async fn creating_a_new_session_with_no_expiry() -> Result { - let store = test_store().await; - let mut session = Session::new(); - session.insert("key", "value")?; - let cloned = session.clone(); - let cookie_value = store.store_session(session).await?.unwrap(); - - let (id, expires, serialized, count): (String, Option, String, i64) = - sqlx::query_as("select id, expires, session, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; - - assert_eq!(1, count); - assert_eq!(id, cloned.id()); - assert_eq!(expires, None); - - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); - - let loaded_session = store.load_session(cookie_value).await?.unwrap(); - assert_eq!(cloned.id(), loaded_session.id()); - assert_eq!("value", &loaded_session.get::("key").unwrap()); - - assert!(!loaded_session.is_expired()); - Ok(()) - } - - #[tokio::test] - async fn updating_a_session() -> Result { - let store = test_store().await; - let mut session = Session::new(); - let original_id = session.id().to_owned(); - - session.insert("key", "value")?; - let cookie_value = store.store_session(session).await?.unwrap(); - - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); - session.insert("key", "other value")?; - assert_eq!(None, store.store_session(session).await?); - - let session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.get::("key").unwrap(), "other value"); - - let (id, count): (String, i64) = sqlx::query_as("select id, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; - - assert_eq!(1, count); - assert_eq!(original_id, id); - - Ok(()) - } - - #[tokio::test] - async fn updating_a_session_extending_expiry() -> Result { - let store = test_store().await; - let mut session = Session::new(); - session.expire_in(Duration::from_secs(10)); - let original_id = session.id().to_owned(); - let original_expires = session.expiry().unwrap().clone(); - let cookie_value = store.store_session(session).await?.unwrap(); - - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &original_expires); - session.expire_in(Duration::from_secs(20)); - let new_expires = session.expiry().unwrap().clone(); - store.store_session(session).await?; - - let session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &new_expires); - - let (id, expires, count): (String, i64, i64) = - sqlx::query_as("select id, expires, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; - - assert_eq!(1, count); - assert_eq!(expires, new_expires.timestamp()); - assert_eq!(original_id, id); - - Ok(()) - } - - #[tokio::test] - async fn creating_a_new_session_with_expiry() -> Result { - let store = test_store().await; - let mut session = Session::new(); - session.expire_in(Duration::from_secs(1)); - session.insert("key", "value")?; - let cloned = session.clone(); - - let cookie_value = store.store_session(session).await?.unwrap(); - - let (id, expires, serialized, count): (String, Option, String, i64) = - sqlx::query_as("select id, expires, session, count(*) from async_sessions") - .fetch_one(&mut store.connection().await?) - .await?; - - assert_eq!(1, count); - assert_eq!(id, cloned.id()); - assert!(expires.unwrap() > Utc::now().timestamp()); - - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); - - let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(cloned.id(), loaded_session.id()); - assert_eq!("value", &loaded_session.get::("key").unwrap()); - - assert!(!loaded_session.is_expired()); - - tokio::time::sleep(Duration::from_secs(1)).await; - assert_eq!(None, store.load_session(cookie_value).await?); - - Ok(()) - } - - #[tokio::test] - async fn destroying_a_single_session() -> Result { - let store = test_store().await; - for _ in 0..3i8 { - store.store_session(Session::new()).await?; - } - - let cookie = store.store_session(Session::new()).await?.unwrap(); - assert_eq!(4, store.count().await?); - let session = store.load_session(cookie.clone()).await?.unwrap(); - store.destroy_session(session.clone()).await.unwrap(); - assert_eq!(None, store.load_session(cookie).await?); - assert_eq!(3, store.count().await?); - - // // attempting to destroy the session again is not an error - // assert!(store.destroy_session(session).await.is_ok()); - Ok(()) - } - - #[tokio::test] - async fn clearing_the_whole_store() -> Result { - let store = test_store().await; - for _ in 0..3i8 { - store.store_session(Session::new()).await?; - } - - assert_eq!(3, store.count().await?); - store.clear_store().await.unwrap(); - assert_eq!(0, store.count().await?); - - Ok(()) - } -} diff --git a/src/users.rs b/src/users.rs index baeaf2f..e54f04b 100644 --- a/src/users.rs +++ b/src/users.rs @@ -1,11 +1,15 @@ use std::fmt::Display; +use axum::{extract::State, http::Request, middleware::Next, response::IntoResponse}; use axum_login::{secrecy::SecretVec, AuthUser}; use serde::{Deserialize, Serialize}; use sqlx::SqlitePool; use uuid::Uuid; +use crate::AuthContext; + const USERNAME_QUERY: &str = "select * from witches where username = $1"; +const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1"; #[derive(Debug, Default, Clone, PartialEq, Eq, sqlx::FromRow, Serialize, Deserialize)] pub struct User { @@ -41,10 +45,40 @@ impl AuthUser for User { } impl User { - pub async fn get(username: &str, db: &SqlitePool) -> Result { + pub async fn try_get(username: &str, db: &SqlitePool) -> Result { sqlx::query_as(USERNAME_QUERY) .bind(username) .fetch_one(db) .await } + + pub async fn update_last_seen(&self, pool: &SqlitePool) { + match sqlx::query(LAST_SEEN_QUERY) + .bind(self.id) + .execute(pool) + .await + { + Ok(_) => {} + Err(e) => { + let id = self.id.as_simple(); + tracing::error!("Could not update last_seen for user {id}; got {e:?}"); + } + } + } +} + +//-************************************************************************ +// User-specific middleware +//-************************************************************************ + +pub async fn handle_update_last_seen( + State(pool): State, + auth: AuthContext, + request: Request, + next: Next, +) -> impl IntoResponse { + if let Some(user) = auth.current_user { + user.update_last_seen(&pool).await; + } + next.run(request).await }