use std::time::Duration; use async_session::SessionStore; use axum_login::{ axum_sessions::{PersistencePolicy, SessionLayer}, AuthLayer, SqliteStore, SqlxStore, }; use session_store::SqliteSessionStore; use sqlx::{ migrate::Migrator, sqlite::{SqliteConnectOptions, SqlitePoolOptions}, SqlitePool, }; use crate::{db_id::DbId, User}; const MAX_CONNS: u32 = 200; const MIN_CONNS: u32 = 5; const TIMEOUT: u64 = 11; const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64); pub async fn get_db_pool() -> SqlitePool { let db_filename = { std::env::var("DATABASE_FILE").unwrap_or_else(|_| { #[cfg(not(test))] { let home = std::env::var("HOME").expect("Could not determine $HOME for finding db file"); format!("{home}/.witch-watch.db") } #[cfg(test)] { use rand_core::RngCore; let mut rng = rand_core::OsRng; let id = rng.next_u64(); // see https://www.sqlite.org/inmemorydb.html for meaning of the string; // it allows each separate test to have its own dedicated memory-backed db that // will live as long as the whole process format!("file:testdb-{id}?mode=memory&cache=shared") } }) }; tracing::info!("Connecting to DB at {db_filename}"); let conn_opts = SqliteConnectOptions::new() .foreign_keys(true) .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental) .filename(&db_filename) .busy_timeout(Duration::from_secs(TIMEOUT)) .create_if_missing(true); let pool = SqlitePoolOptions::new() .max_connections(MAX_CONNS) .min_connections(MIN_CONNS) .idle_timeout(Some(Duration::from_secs(30))) .max_lifetime(Some(Duration::from_secs(3600))) .connect_with(conn_opts) .await .expect("can't connect to database"); // let the filesystem settle before trying anything // possibly not effective? tokio::time::sleep(Duration::from_millis(500)).await; { let mut m = Migrator::new(std::path::Path::new("./migrations")) .await .expect("Should be able to read the migration directory."); let m = m.set_locking(true); m.run(&pool) .await .expect("Should be able to run the migration."); tracing::info!("Ran migrations"); } pool } pub async fn session_layer(pool: SqlitePool, secret: &[u8]) -> SessionLayer { let store = session_store::SqliteSessionStore::from_client(pool); store .migrate() .await .expect("Calling `migrate()` should be reliable, is the DB gone?"); // since the secret is new every time the server starts, old sessions won't be // valid anymore; if there were ever more than one service host or there were // managed secrets, this would need to go away. store .clear_store() .await .unwrap_or_else(|e| tracing::error!("Could not delete old sessions; got error: {e}")); SessionLayer::new(store, secret) .with_secure(true) .with_session_ttl(Some(SESSION_TTL)) .with_persistence_policy(PersistencePolicy::ExistingOnly) } pub async fn auth_layer( pool: SqlitePool, secret: &[u8], ) -> AuthLayer, DbId, User> { const QUERY: &str = "select * from witches where id = $1"; let store = SqliteStore::::new(pool).with_query(QUERY); AuthLayer::new(store, secret) } //-************************************************************************ // Tests for `db` module. //-************************************************************************ #[cfg(test)] mod tests { #[tokio::test] async fn it_migrates_the_db() { let db = super::get_db_pool().await; let r = sqlx::query("select count(*) from witches") .fetch_one(&db) .await; assert!(r.is_ok()); } } //-************************************************************************ // End public interface. //-************************************************************************ //-************************************************************************ // Session store sub-module, not a public lib. //-************************************************************************ #[allow(dead_code)] mod session_store { use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session}; use sqlx::{pool::PoolConnection, Sqlite}; use super::*; // NOTE! This code was straight stolen from // https://github.com/jbr/async-sqlx-session/blob/30d00bed44ab2034082698f098eba48b21600f36/src/sqlite.rs // and used under the terms of the MIT license: /* Copyright 2022 Jacob Rothstein Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /// sqlx sqlite session store for async-sessions #[derive(Clone, Debug)] pub struct SqliteSessionStore { client: SqlitePool, table_name: String, } impl SqliteSessionStore { /// constructs a new SqliteSessionStore from an existing /// sqlx::SqlitePool. the default table name for this session /// store will be "async_sessions". To override this, chain this /// with [`with_table_name`](crate::SqliteSessionStore::with_table_name). pub fn from_client(client: SqlitePool) -> Self { Self { client, table_name: "async_sessions".into(), } } /// Constructs a new SqliteSessionStore from a sqlite: database url. /// note that this documentation uses the special `:memory:` /// sqlite database for convenient testing, but a real /// application would use a path like /// `sqlite:///path/to/database.db`. The default table name for /// this session store will be "async_sessions". To /// override this, either chain with /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or /// use /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) pub async fn new(database_url: &str) -> sqlx::Result { Ok(Self::from_client(SqlitePool::connect(database_url).await?)) } /// constructs a new SqliteSessionStore from a sqlite: database url. the /// default table name for this session store will be /// "async_sessions". To override this, either chain with /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or /// use /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) pub async fn new_with_table_name( database_url: &str, table_name: &str, ) -> sqlx::Result { Ok(Self::new(database_url).await?.with_table_name(table_name)) } /// Chainable method to add a custom table name. This will panic /// if the table name is not `[a-zA-Z0-9_-]+`. pub fn with_table_name(mut self, table_name: impl AsRef) -> Self { let table_name = table_name.as_ref(); if table_name.is_empty() || !table_name .chars() .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') { panic!( "table name must be [a-zA-Z0-9_-]+, but {} was not", table_name ); } self.table_name = table_name.to_owned(); self } /// Creates a session table if it does not already exist. If it /// does, this will noop, making it safe to call repeatedly on /// store initialization. In the future, this may make /// exactly-once modifications to the schema of the session table /// on breaking releases. pub async fn migrate(&self) -> sqlx::Result<()> { log::info!("migrating sessions on `{}`", self.table_name); let mut conn = self.client.acquire().await?; sqlx::query(&self.substitute_table_name( r#" CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% ( id TEXT PRIMARY KEY NOT NULL, expires INTEGER NULL, session TEXT NOT NULL ) "#, )) .execute(&mut conn) .await?; Ok(()) } // private utility function because sqlite does not support // parametrized table names fn substitute_table_name(&self, query: &str) -> String { query.replace("%%TABLE_NAME%%", &self.table_name) } /// retrieve a connection from the pool async fn connection(&self) -> sqlx::Result> { self.client.acquire().await } /// Performs a one-time cleanup task that clears out stale /// (expired) sessions. You may want to call this from cron. pub async fn cleanup(&self) -> sqlx::Result<()> { let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name( r#" DELETE FROM %%TABLE_NAME%% WHERE expires < ? "#, )) .bind(Utc::now().timestamp()) .execute(&mut connection) .await?; Ok(()) } /// retrieves the number of sessions currently stored, including /// expired sessions pub async fn count(&self) -> sqlx::Result { let (count,) = sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%")) .fetch_one(&mut self.connection().await?) .await?; Ok(count) } } #[async_trait] impl SessionStore for SqliteSessionStore { async fn load_session(&self, cookie_value: String) -> Result> { let id = Session::id_from_cookie_value(&cookie_value)?; let mut connection = self.connection().await?; let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name( r#" SELECT session FROM %%TABLE_NAME%% WHERE id = ? AND (expires IS NULL OR expires > ?) "#, )) .bind(&id) .bind(Utc::now().timestamp()) .fetch_optional(&mut connection) .await?; Ok(result .map(|(session,)| serde_json::from_str(&session)) .transpose()?) } async fn store_session(&self, session: Session) -> Result> { let id = session.id(); let string = serde_json::to_string(&session)?; let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name( r#" INSERT INTO %%TABLE_NAME%% (id, session, expires) VALUES (?, ?, ?) ON CONFLICT(id) DO UPDATE SET expires = excluded.expires, session = excluded.session "#, )) .bind(id) .bind(&string) .bind(session.expiry().map(|expiry| expiry.timestamp())) .execute(&mut connection) .await?; Ok(session.into_cookie_value()) } async fn destroy_session(&self, session: Session) -> Result { let id = session.id(); let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name( r#" DELETE FROM %%TABLE_NAME%% WHERE id = ? "#, )) .bind(id) .execute(&mut connection) .await?; Ok(()) } async fn clear_store(&self) -> Result { let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name( r#" DELETE FROM %%TABLE_NAME%% "#, )) .execute(&mut connection) .await?; Ok(()) } } #[cfg(test)] mod tests { use std::time::Duration; use super::*; async fn test_store() -> SqliteSessionStore { let store = SqliteSessionStore::new("sqlite::memory:") .await .expect("building a sqlite :memory: SqliteSessionStore"); store .migrate() .await .expect("migrating a brand new :memory: SqliteSessionStore"); store } #[tokio::test] async fn creating_a_new_session_with_no_expiry() -> Result { let store = test_store().await; let mut session = Session::new(); session.insert("key", "value")?; let cloned = session.clone(); let cookie_value = store.store_session(session).await?.unwrap(); let (id, expires, serialized, count): (String, Option, String, i64) = sqlx::query_as("select id, expires, session, count(*) from async_sessions") .fetch_one(&mut store.connection().await?) .await?; assert_eq!(1, count); assert_eq!(id, cloned.id()); assert_eq!(expires, None); let deserialized_session: Session = serde_json::from_str(&serialized)?; assert_eq!(cloned.id(), deserialized_session.id()); assert_eq!("value", &deserialized_session.get::("key").unwrap()); let loaded_session = store.load_session(cookie_value).await?.unwrap(); assert_eq!(cloned.id(), loaded_session.id()); assert_eq!("value", &loaded_session.get::("key").unwrap()); assert!(!loaded_session.is_expired()); Ok(()) } #[tokio::test] async fn updating_a_session() -> Result { let store = test_store().await; let mut session = Session::new(); let original_id = session.id().to_owned(); session.insert("key", "value")?; let cookie_value = store.store_session(session).await?.unwrap(); let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); session.insert("key", "other value")?; assert_eq!(None, store.store_session(session).await?); let session = store.load_session(cookie_value.clone()).await?.unwrap(); assert_eq!(session.get::("key").unwrap(), "other value"); let (id, count): (String, i64) = sqlx::query_as("select id, count(*) from async_sessions") .fetch_one(&mut store.connection().await?) .await?; assert_eq!(1, count); assert_eq!(original_id, id); Ok(()) } #[tokio::test] async fn updating_a_session_extending_expiry() -> Result { let store = test_store().await; let mut session = Session::new(); session.expire_in(Duration::from_secs(10)); let original_id = session.id().to_owned(); let original_expires = session.expiry().unwrap().clone(); let cookie_value = store.store_session(session).await?.unwrap(); let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); assert_eq!(session.expiry().unwrap(), &original_expires); session.expire_in(Duration::from_secs(20)); let new_expires = session.expiry().unwrap().clone(); store.store_session(session).await?; let session = store.load_session(cookie_value.clone()).await?.unwrap(); assert_eq!(session.expiry().unwrap(), &new_expires); let (id, expires, count): (String, i64, i64) = sqlx::query_as("select id, expires, count(*) from async_sessions") .fetch_one(&mut store.connection().await?) .await?; assert_eq!(1, count); assert_eq!(expires, new_expires.timestamp()); assert_eq!(original_id, id); Ok(()) } #[tokio::test] async fn creating_a_new_session_with_expiry() -> Result { let store = test_store().await; let mut session = Session::new(); session.expire_in(Duration::from_secs(1)); session.insert("key", "value")?; let cloned = session.clone(); let cookie_value = store.store_session(session).await?.unwrap(); let (id, expires, serialized, count): (String, Option, String, i64) = sqlx::query_as("select id, expires, session, count(*) from async_sessions") .fetch_one(&mut store.connection().await?) .await?; assert_eq!(1, count); assert_eq!(id, cloned.id()); assert!(expires.unwrap() > Utc::now().timestamp()); let deserialized_session: Session = serde_json::from_str(&serialized)?; assert_eq!(cloned.id(), deserialized_session.id()); assert_eq!("value", &deserialized_session.get::("key").unwrap()); let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap(); assert_eq!(cloned.id(), loaded_session.id()); assert_eq!("value", &loaded_session.get::("key").unwrap()); assert!(!loaded_session.is_expired()); tokio::time::sleep(Duration::from_secs(1)).await; assert_eq!(None, store.load_session(cookie_value).await?); Ok(()) } #[tokio::test] async fn destroying_a_single_session() -> Result { let store = test_store().await; for _ in 0..3i8 { store.store_session(Session::new()).await?; } let cookie = store.store_session(Session::new()).await?.unwrap(); assert_eq!(4, store.count().await?); let session = store.load_session(cookie.clone()).await?.unwrap(); store.destroy_session(session.clone()).await.unwrap(); assert_eq!(None, store.load_session(cookie).await?); assert_eq!(3, store.count().await?); // // attempting to destroy the session again is not an error // assert!(store.destroy_session(session).await.is_ok()); Ok(()) } #[tokio::test] async fn clearing_the_whole_store() -> Result { let store = test_store().await; for _ in 0..3i8 { store.store_session(Session::new()).await?; } assert_eq!(3, store.count().await?); store.clear_store().await.unwrap(); assert_eq!(0, store.count().await?); Ok(()) } } }