No actual tests added.

This commit is contained in:
Joe Ardent 2023-05-30 18:46:32 -07:00
commit ac0af7970d
6 changed files with 607 additions and 542 deletions

560
src/db.rs
View File

@ -1,12 +1,22 @@
use std::time::Duration; use std::time::Duration;
use async_session::SessionStore;
use axum_login::{
axum_sessions::{PersistencePolicy, SessionLayer},
AuthLayer, SqliteStore, SqlxStore,
};
use session_store::SqliteSessionStore;
use sqlx::{ use sqlx::{
sqlite::{SqliteConnectOptions, SqlitePoolOptions}, sqlite::{SqliteConnectOptions, SqlitePoolOptions},
SqlitePool, SqlitePool,
}; };
use uuid::Uuid;
use crate::User;
const MAX_CONNS: u32 = 100; const MAX_CONNS: u32 = 100;
const TIMEOUT: u64 = 5; const TIMEOUT: u64 = 5;
const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64);
pub async fn get_pool() -> SqlitePool { pub async fn get_pool() -> SqlitePool {
let db_filename = { let db_filename = {
@ -30,3 +40,553 @@ pub async fn get_pool() -> SqlitePool {
.await .await
.expect("can't connect to database") .expect("can't connect to database")
} }
pub async fn session_layer(pool: SqlitePool, secret: &[u8]) -> SessionLayer<SqliteSessionStore> {
let store = session_store::SqliteSessionStore::from_client(pool);
store
.migrate()
.await
.expect("Calling `migrate()` should be reliable, is the DB gone?");
// since the secret is new every time the server starts, old sessions won't be
// valid anymore; if there were ever more than one service host or there were
// managed secrets, this would need to go away.
store
.clear_store()
.await
.unwrap_or_else(|e| tracing::error!("Could not delete old sessions; got error: {e}"));
SessionLayer::new(store, secret)
.with_secure(true)
.with_session_ttl(Some(SESSION_TTL))
.with_persistence_policy(PersistencePolicy::ExistingOnly)
}
pub async fn auth_layer(
pool: SqlitePool,
secret: &[u8],
) -> AuthLayer<SqlxStore<SqlitePool, User>, Uuid, User> {
const QUERY: &str = "select * from witches where id = $1";
let store = SqliteStore::<User>::new(pool).with_query(QUERY);
AuthLayer::new(store, secret)
}
//-************************************************************************
// Session store sub-module, not a public lib.
//-************************************************************************
mod session_store {
use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session};
use sqlx::{pool::PoolConnection, Sqlite};
use super::*;
// NOTE! This code was straight stolen from
// https://github.com/jbr/async-sqlx-session/blob/30d00bed44ab2034082698f098eba48b21600f36/src/sqlite.rs
// and used under the terms of the MIT license:
/*
Copyright 2022 Jacob Rothstein
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the Software), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial
portions of the Software.
THE SOFTWARE IS PROVIDED AS IS, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
/// sqlx sqlite session store for async-sessions
///
/// ```rust
/// use witch_watch::session_store::SqliteSessionStore;
/// use async_session::{Session, SessionStore, Result};
/// use std::time::Duration;
///
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await?;
///
/// let mut session = Session::new();
/// session.insert("key", vec![1,2,3]);
///
/// let cookie_value = store.store_session(session).await?.unwrap();
/// let session = store.load_session(cookie_value).await?.unwrap();
/// assert_eq!(session.get::<Vec<i8>>("key").unwrap(), vec![1,2,3]);
/// # Ok(()) }
#[derive(Clone, Debug)]
pub struct SqliteSessionStore {
client: SqlitePool,
table_name: String,
}
impl SqliteSessionStore {
/// constructs a new SqliteSessionStore from an existing
/// sqlx::SqlitePool. the default table name for this session
/// store will be "async_sessions". To override this, chain this
/// with [`with_table_name`](crate::SqliteSessionStore::with_table_name).
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
/// let store = SqliteSessionStore::from_client(pool)
/// .with_table_name("custom_table_name");
/// store.migrate().await;
/// # Ok(()) }
/// ```
pub fn from_client(client: SqlitePool) -> Self {
Self {
client,
table_name: "async_sessions".into(),
}
}
/// Constructs a new SqliteSessionStore from a sqlite: database url.
/// note that this documentation uses the special `:memory:`
/// sqlite database for convenient testing, but a real
/// application would use a path like
/// `sqlite:///path/to/database.db`. The default table name for
/// this session store will be "async_sessions". To
/// override this, either chain with
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
/// use
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await;
/// # Ok(()) }
/// ```
pub async fn new(database_url: &str) -> sqlx::Result<Self> {
Ok(Self::from_client(SqlitePool::connect(database_url).await?))
}
/// constructs a new SqliteSessionStore from a sqlite: database url. the
/// default table name for this session store will be
/// "async_sessions". To override this, either chain with
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
/// use
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?;
/// store.migrate().await;
/// # Ok(()) }
/// ```
pub async fn new_with_table_name(
database_url: &str,
table_name: &str,
) -> sqlx::Result<Self> {
Ok(Self::new(database_url).await?.with_table_name(table_name))
}
/// Chainable method to add a custom table name. This will panic
/// if the table name is not `[a-zA-Z0-9_-]+`.
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?
/// .with_table_name("custom_name");
/// store.migrate().await;
/// # Ok(()) }
/// ```
///
/// ```should_panic
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?
/// .with_table_name("johnny (); drop users;");
/// # Ok(()) }
/// ```
pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
let table_name = table_name.as_ref();
if table_name.is_empty()
|| !table_name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
{
panic!(
"table name must be [a-zA-Z0-9_-]+, but {} was not",
table_name
);
}
self.table_name = table_name.to_owned();
self
}
/// Creates a session table if it does not already exist. If it
/// does, this will noop, making it safe to call repeatedly on
/// store initialization. In the future, this may make
/// exactly-once modifications to the schema of the session table
/// on breaking releases.
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::{Result, SessionStore, Session};
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// assert!(store.count().await.is_err());
/// store.migrate().await?;
/// store.store_session(Session::new()).await?;
/// store.migrate().await?; // calling it a second time is safe
/// assert_eq!(store.count().await?, 1);
/// # Ok(()) }
/// ```
pub async fn migrate(&self) -> sqlx::Result<()> {
log::info!("migrating sessions on `{}`", self.table_name);
let mut conn = self.client.acquire().await?;
sqlx::query(&self.substitute_table_name(
r#"
CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
id TEXT PRIMARY KEY NOT NULL,
expires INTEGER NULL,
session TEXT NOT NULL
)
"#,
))
.execute(&mut conn)
.await?;
Ok(())
}
// private utility function because sqlite does not support
// parametrized table names
fn substitute_table_name(&self, query: &str) -> String {
query.replace("%%TABLE_NAME%%", &self.table_name)
}
/// retrieve a connection from the pool
async fn connection(&self) -> sqlx::Result<PoolConnection<Sqlite>> {
self.client.acquire().await
}
/// Performs a one-time cleanup task that clears out stale
/// (expired) sessions. You may want to call this from cron.
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session};
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await?;
/// let mut session = Session::new();
/// session.set_expiry(Utc::now() - Duration::seconds(5));
/// store.store_session(session).await?;
/// assert_eq!(store.count().await?, 1);
/// store.cleanup().await?;
/// assert_eq!(store.count().await?, 0);
/// # Ok(()) }
/// ```
pub async fn cleanup(&self) -> sqlx::Result<()> {
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
DELETE FROM %%TABLE_NAME%%
WHERE expires < ?
"#,
))
.bind(Utc::now().timestamp())
.execute(&mut connection)
.await?;
Ok(())
}
/// retrieves the number of sessions currently stored, including
/// expired sessions
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::{Result, SessionStore, Session};
/// # use std::time::Duration;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await?;
/// assert_eq!(store.count().await?, 0);
/// store.store_session(Session::new()).await?;
/// assert_eq!(store.count().await?, 1);
/// # Ok(()) }
/// ```
pub async fn count(&self) -> sqlx::Result<i32> {
let (count,) =
sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
.fetch_one(&mut self.connection().await?)
.await?;
Ok(count)
}
}
#[async_trait]
impl SessionStore for SqliteSessionStore {
async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
let id = Session::id_from_cookie_value(&cookie_value)?;
let mut connection = self.connection().await?;
let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name(
r#"
SELECT session FROM %%TABLE_NAME%%
WHERE id = ? AND (expires IS NULL OR expires > ?)
"#,
))
.bind(&id)
.bind(Utc::now().timestamp())
.fetch_optional(&mut connection)
.await?;
Ok(result
.map(|(session,)| serde_json::from_str(&session))
.transpose()?)
}
async fn store_session(&self, session: Session) -> Result<Option<String>> {
let id = session.id();
let string = serde_json::to_string(&session)?;
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
INSERT INTO %%TABLE_NAME%%
(id, session, expires) VALUES (?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
expires = excluded.expires,
session = excluded.session
"#,
))
.bind(id)
.bind(&string)
.bind(session.expiry().map(|expiry| expiry.timestamp()))
.execute(&mut connection)
.await?;
Ok(session.into_cookie_value())
}
async fn destroy_session(&self, session: Session) -> Result {
let id = session.id();
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
DELETE FROM %%TABLE_NAME%% WHERE id = ?
"#,
))
.bind(id)
.execute(&mut connection)
.await?;
Ok(())
}
async fn clear_store(&self) -> Result {
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
DELETE FROM %%TABLE_NAME%%
"#,
))
.execute(&mut connection)
.await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
async fn test_store() -> SqliteSessionStore {
let store = SqliteSessionStore::new("sqlite::memory:")
.await
.expect("building a sqlite :memory: SqliteSessionStore");
store
.migrate()
.await
.expect("migrating a brand new :memory: SqliteSessionStore");
store
}
#[tokio::test]
async fn creating_a_new_session_with_no_expiry() -> Result {
let store = test_store().await;
let mut session = Session::new();
session.insert("key", "value")?;
let cloned = session.clone();
let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
sqlx::query_as("select id, expires, session, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(id, cloned.id());
assert_eq!(expires, None);
let deserialized_session: Session = serde_json::from_str(&serialized)?;
assert_eq!(cloned.id(), deserialized_session.id());
assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
let loaded_session = store.load_session(cookie_value).await?.unwrap();
assert_eq!(cloned.id(), loaded_session.id());
assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
assert!(!loaded_session.is_expired());
Ok(())
}
#[tokio::test]
async fn updating_a_session() -> Result {
let store = test_store().await;
let mut session = Session::new();
let original_id = session.id().to_owned();
session.insert("key", "value")?;
let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
session.insert("key", "other value")?;
assert_eq!(None, store.store_session(session).await?);
let session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(session.get::<String>("key").unwrap(), "other value");
let (id, count): (String, i64) =
sqlx::query_as("select id, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(original_id, id);
Ok(())
}
#[tokio::test]
async fn updating_a_session_extending_expiry() -> Result {
let store = test_store().await;
let mut session = Session::new();
session.expire_in(Duration::from_secs(10));
let original_id = session.id().to_owned();
let original_expires = session.expiry().unwrap().clone();
let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(session.expiry().unwrap(), &original_expires);
session.expire_in(Duration::from_secs(20));
let new_expires = session.expiry().unwrap().clone();
store.store_session(session).await?;
let session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(session.expiry().unwrap(), &new_expires);
let (id, expires, count): (String, i64, i64) =
sqlx::query_as("select id, expires, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(expires, new_expires.timestamp());
assert_eq!(original_id, id);
Ok(())
}
#[tokio::test]
async fn creating_a_new_session_with_expiry() -> Result {
let store = test_store().await;
let mut session = Session::new();
session.expire_in(Duration::from_secs(1));
session.insert("key", "value")?;
let cloned = session.clone();
let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
sqlx::query_as("select id, expires, session, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(id, cloned.id());
assert!(expires.unwrap() > Utc::now().timestamp());
let deserialized_session: Session = serde_json::from_str(&serialized)?;
assert_eq!(cloned.id(), deserialized_session.id());
assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(cloned.id(), loaded_session.id());
assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
assert!(!loaded_session.is_expired());
tokio::time::sleep(Duration::from_secs(1)).await;
assert_eq!(None, store.load_session(cookie_value).await?);
Ok(())
}
#[tokio::test]
async fn destroying_a_single_session() -> Result {
let store = test_store().await;
for _ in 0..3i8 {
store.store_session(Session::new()).await?;
}
let cookie = store.store_session(Session::new()).await?.unwrap();
assert_eq!(4, store.count().await?);
let session = store.load_session(cookie.clone()).await?.unwrap();
store.destroy_session(session.clone()).await.unwrap();
assert_eq!(None, store.load_session(cookie).await?);
assert_eq!(3, store.count().await?);
// // attempting to destroy the session again is not an error
// assert!(store.destroy_session(session).await.is_ok());
Ok(())
}
#[tokio::test]
async fn clearing_the_whole_store() -> Result {
let store = test_store().await;
for _ in 0..3i8 {
store.store_session(Session::new()).await?;
}
assert_eq!(3, store.count().await?);
store.clear_store().await.unwrap();
assert_eq!(0, store.count().await?);
Ok(())
}
}
}

View File

@ -8,7 +8,6 @@ use uuid::Uuid;
pub mod db; pub mod db;
pub mod generic_handlers; pub mod generic_handlers;
pub mod login; pub mod login;
pub mod session_store;
pub mod signup; pub mod signup;
pub(crate) mod templates; pub(crate) mod templates;
pub mod users; pub mod users;

View File

@ -20,8 +20,6 @@ use crate::{
// Constants // Constants
//-************************************************************************ //-************************************************************************
const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1";
//-************************************************************************ //-************************************************************************
// Login error and success types // Login error and success types
//-************************************************************************ //-************************************************************************
@ -68,7 +66,7 @@ pub async fn post_login(
let pw = form_decode(&login.password, LoginErrorKind::BadPassword)?; let pw = form_decode(&login.password, LoginErrorKind::BadPassword)?;
let pw = pw.trim(); let pw = pw.trim();
let user = User::get(username, &pool) let user = User::try_get(username, &pool)
.await .await
.map_err(|_| LoginErrorKind::Unknown)?; .map_err(|_| LoginErrorKind::Unknown)?;
@ -81,13 +79,6 @@ pub async fn post_login(
.await .await
.map_err(|_| LoginErrorKind::Internal)?; .map_err(|_| LoginErrorKind::Internal)?;
// update last_seen; maybe this is ok to fail?
sqlx::query(LAST_SEEN_QUERY)
.bind(user.id)
.execute(&pool)
.await
.map_err(|_| LoginErrorKind::Internal)?;
Ok(Redirect::temporary("/")) Ok(Redirect::temporary("/"))
} }
_ => Err(LoginErrorKind::BadPassword.into()), _ => Err(LoginErrorKind::BadPassword.into()),

View File

@ -1,19 +1,14 @@
use std::{net::SocketAddr, time::Duration}; use std::net::SocketAddr;
use axum::{routing::get, Router}; use axum::{middleware, routing::get, Router};
use axum_login::{
axum_sessions::{PersistencePolicy, SessionLayer},
AuthLayer, SqliteStore,
};
use rand_core::{OsRng, RngCore}; use rand_core::{OsRng, RngCore};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use witch_watch::{ use witch_watch::{
db, db,
generic_handlers::{handle_slash, handle_slash_redir}, generic_handlers::{handle_slash, handle_slash_redir},
login::{get_login, get_logout, post_login, post_logout}, login::{get_login, get_logout, post_login, post_logout},
session_store::SqliteSessionStore,
signup::{get_create_user, handle_signup_success, post_create_user}, signup::{get_create_user, handle_signup_success, post_create_user},
User, users,
}; };
#[tokio::main] #[tokio::main]
@ -35,20 +30,8 @@ async fn main() {
bytes bytes
}; };
let session_layer = { let session_layer = db::session_layer(pool.clone(), &secret).await;
let store = SqliteSessionStore::from_client(pool.clone()); let auth_layer = db::auth_layer(pool.clone(), &secret).await;
store.migrate().await.expect("Could not migrate session DB");
SessionLayer::new(store, &secret)
.with_secure(true)
.with_persistence_policy(PersistencePolicy::ExistingOnly)
.with_session_ttl(Some(Duration::from_secs(3600 * 24 * 366)))
};
let auth_layer = {
const QUERY: &str = "select * from witches where id = $1";
let store = SqliteStore::<User>::new(pool.clone()).with_query(QUERY);
AuthLayer::new(store, &secret)
};
let app = Router::new() let app = Router::new()
.route("/", get(handle_slash).post(handle_slash)) .route("/", get(handle_slash).post(handle_slash))
@ -60,12 +43,17 @@ async fn main() {
.route("/login", get(get_login).post(post_login)) .route("/login", get(get_login).post(post_login))
.route("/logout", get(get_logout).post(post_logout)) .route("/logout", get(get_logout).post(post_logout))
.fallback(handle_slash_redir) .fallback(handle_slash_redir)
.route_layer(middleware::from_fn_with_state(
pool.clone(),
users::handle_update_last_seen,
))
.layer(auth_layer) .layer(auth_layer)
.layer(session_layer) .layer(session_layer)
.with_state(pool); .with_state(pool);
let addr = ([127, 0, 0, 1], 3000);
tracing::debug!("binding to 0.0.0.0:3000"); tracing::debug!("binding to 0.0.0.0:3000");
axum::Server::bind(&SocketAddr::from(([0, 0, 0, 0], 3000))) axum::Server::bind(&SocketAddr::from(addr))
.serve(app.into_make_service()) .serve(app.into_make_service())
.await .await
.unwrap(); .unwrap();

View File

@ -1,507 +0,0 @@
use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore};
use sqlx::{pool::PoolConnection, sqlite::SqlitePool, Sqlite};
// NOTE! This code was straight stolen from
// https://github.com/jbr/async-sqlx-session/blob/30d00bed44ab2034082698f098eba48b21600f36/src/sqlite.rs
// and used under the terms of the MIT license:
/*
Copyright 2022 Jacob Rothstein
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the Software), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial
portions of the Software.
THE SOFTWARE IS PROVIDED AS IS, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
/// sqlx sqlite session store for async-sessions
///
/// ```rust
/// use witch_watch::session_store::SqliteSessionStore;
/// use async_session::{Session, SessionStore, Result};
/// use std::time::Duration;
///
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await?;
///
/// let mut session = Session::new();
/// session.insert("key", vec![1,2,3]);
///
/// let cookie_value = store.store_session(session).await?.unwrap();
/// let session = store.load_session(cookie_value).await?.unwrap();
/// assert_eq!(session.get::<Vec<i8>>("key").unwrap(), vec![1,2,3]);
/// # Ok(()) }
#[derive(Clone, Debug)]
pub struct SqliteSessionStore {
client: SqlitePool,
table_name: String,
}
impl SqliteSessionStore {
/// constructs a new SqliteSessionStore from an existing
/// sqlx::SqlitePool. the default table name for this session
/// store will be "async_sessions". To override this, chain this
/// with [`with_table_name`](crate::SqliteSessionStore::with_table_name).
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
/// let store = SqliteSessionStore::from_client(pool)
/// .with_table_name("custom_table_name");
/// store.migrate().await;
/// # Ok(()) }
/// ```
pub fn from_client(client: SqlitePool) -> Self {
Self {
client,
table_name: "async_sessions".into(),
}
}
/// Constructs a new SqliteSessionStore from a sqlite: database url. note
/// that this documentation uses the special `:memory:` sqlite
/// database for convenient testing, but a real application would
/// use a path like `sqlite:///path/to/database.db`. The default
/// table name for this session store will be "async_sessions". To
/// override this, either chain with
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
/// use
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await;
/// # Ok(()) }
/// ```
pub async fn new(database_url: &str) -> sqlx::Result<Self> {
Ok(Self::from_client(SqlitePool::connect(database_url).await?))
}
/// constructs a new SqliteSessionStore from a sqlite: database url. the
/// default table name for this session store will be
/// "async_sessions". To override this, either chain with
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
/// use
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?;
/// store.migrate().await;
/// # Ok(()) }
/// ```
pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result<Self> {
Ok(Self::new(database_url).await?.with_table_name(table_name))
}
/// Chainable method to add a custom table name. This will panic
/// if the table name is not `[a-zA-Z0-9_-]+`.
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?
/// .with_table_name("custom_name");
/// store.migrate().await;
/// # Ok(()) }
/// ```
///
/// ```should_panic
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?
/// .with_table_name("johnny (); drop users;");
/// # Ok(()) }
/// ```
pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
let table_name = table_name.as_ref();
if table_name.is_empty()
|| !table_name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
{
panic!(
"table name must be [a-zA-Z0-9_-]+, but {} was not",
table_name
);
}
self.table_name = table_name.to_owned();
self
}
/// Creates a session table if it does not already exist. If it
/// does, this will noop, making it safe to call repeatedly on
/// store initialization. In the future, this may make
/// exactly-once modifications to the schema of the session table
/// on breaking releases.
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::{Result, SessionStore, Session};
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// assert!(store.count().await.is_err());
/// store.migrate().await?;
/// store.store_session(Session::new()).await?;
/// store.migrate().await?; // calling it a second time is safe
/// assert_eq!(store.count().await?, 1);
/// # Ok(()) }
/// ```
pub async fn migrate(&self) -> sqlx::Result<()> {
log::info!("migrating sessions on `{}`", self.table_name);
let mut conn = self.client.acquire().await?;
sqlx::query(&self.substitute_table_name(
r#"
CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
id TEXT PRIMARY KEY NOT NULL,
expires INTEGER NULL,
session TEXT NOT NULL
)
"#,
))
.execute(&mut conn)
.await?;
Ok(())
}
// private utility function because sqlite does not support
// parametrized table names
fn substitute_table_name(&self, query: &str) -> String {
query.replace("%%TABLE_NAME%%", &self.table_name)
}
/// retrieve a connection from the pool
async fn connection(&self) -> sqlx::Result<PoolConnection<Sqlite>> {
self.client.acquire().await
}
/// Performs a one-time cleanup task that clears out stale
/// (expired) sessions. You may want to call this from cron.
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session};
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await?;
/// let mut session = Session::new();
/// session.set_expiry(Utc::now() - Duration::seconds(5));
/// store.store_session(session).await?;
/// assert_eq!(store.count().await?, 1);
/// store.cleanup().await?;
/// assert_eq!(store.count().await?, 0);
/// # Ok(()) }
/// ```
pub async fn cleanup(&self) -> sqlx::Result<()> {
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
DELETE FROM %%TABLE_NAME%%
WHERE expires < ?
"#,
))
.bind(Utc::now().timestamp())
.execute(&mut connection)
.await?;
Ok(())
}
/// retrieves the number of sessions currently stored, including
/// expired sessions
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::{Result, SessionStore, Session};
/// # use std::time::Duration;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await?;
/// assert_eq!(store.count().await?, 0);
/// store.store_session(Session::new()).await?;
/// assert_eq!(store.count().await?, 1);
/// # Ok(()) }
/// ```
pub async fn count(&self) -> sqlx::Result<i32> {
let (count,) =
sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
.fetch_one(&mut self.connection().await?)
.await?;
Ok(count)
}
}
#[async_trait]
impl SessionStore for SqliteSessionStore {
async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
let id = Session::id_from_cookie_value(&cookie_value)?;
let mut connection = self.connection().await?;
let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name(
r#"
SELECT session FROM %%TABLE_NAME%%
WHERE id = ? AND (expires IS NULL OR expires > ?)
"#,
))
.bind(&id)
.bind(Utc::now().timestamp())
.fetch_optional(&mut connection)
.await?;
Ok(result
.map(|(session,)| serde_json::from_str(&session))
.transpose()?)
}
async fn store_session(&self, session: Session) -> Result<Option<String>> {
let id = session.id();
let string = serde_json::to_string(&session)?;
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
INSERT INTO %%TABLE_NAME%%
(id, session, expires) VALUES (?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
expires = excluded.expires,
session = excluded.session
"#,
))
.bind(id)
.bind(&string)
.bind(session.expiry().map(|expiry| expiry.timestamp()))
.execute(&mut connection)
.await?;
Ok(session.into_cookie_value())
}
async fn destroy_session(&self, session: Session) -> Result {
let id = session.id();
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
DELETE FROM %%TABLE_NAME%% WHERE id = ?
"#,
))
.bind(id)
.execute(&mut connection)
.await?;
Ok(())
}
async fn clear_store(&self) -> Result {
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
DELETE FROM %%TABLE_NAME%%
"#,
))
.execute(&mut connection)
.await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
async fn test_store() -> SqliteSessionStore {
let store = SqliteSessionStore::new("sqlite::memory:")
.await
.expect("building a sqlite :memory: SqliteSessionStore");
store
.migrate()
.await
.expect("migrating a brand new :memory: SqliteSessionStore");
store
}
#[tokio::test]
async fn creating_a_new_session_with_no_expiry() -> Result {
let store = test_store().await;
let mut session = Session::new();
session.insert("key", "value")?;
let cloned = session.clone();
let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
sqlx::query_as("select id, expires, session, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(id, cloned.id());
assert_eq!(expires, None);
let deserialized_session: Session = serde_json::from_str(&serialized)?;
assert_eq!(cloned.id(), deserialized_session.id());
assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
let loaded_session = store.load_session(cookie_value).await?.unwrap();
assert_eq!(cloned.id(), loaded_session.id());
assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
assert!(!loaded_session.is_expired());
Ok(())
}
#[tokio::test]
async fn updating_a_session() -> Result {
let store = test_store().await;
let mut session = Session::new();
let original_id = session.id().to_owned();
session.insert("key", "value")?;
let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
session.insert("key", "other value")?;
assert_eq!(None, store.store_session(session).await?);
let session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(session.get::<String>("key").unwrap(), "other value");
let (id, count): (String, i64) = sqlx::query_as("select id, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(original_id, id);
Ok(())
}
#[tokio::test]
async fn updating_a_session_extending_expiry() -> Result {
let store = test_store().await;
let mut session = Session::new();
session.expire_in(Duration::from_secs(10));
let original_id = session.id().to_owned();
let original_expires = session.expiry().unwrap().clone();
let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(session.expiry().unwrap(), &original_expires);
session.expire_in(Duration::from_secs(20));
let new_expires = session.expiry().unwrap().clone();
store.store_session(session).await?;
let session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(session.expiry().unwrap(), &new_expires);
let (id, expires, count): (String, i64, i64) =
sqlx::query_as("select id, expires, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(expires, new_expires.timestamp());
assert_eq!(original_id, id);
Ok(())
}
#[tokio::test]
async fn creating_a_new_session_with_expiry() -> Result {
let store = test_store().await;
let mut session = Session::new();
session.expire_in(Duration::from_secs(1));
session.insert("key", "value")?;
let cloned = session.clone();
let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
sqlx::query_as("select id, expires, session, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(id, cloned.id());
assert!(expires.unwrap() > Utc::now().timestamp());
let deserialized_session: Session = serde_json::from_str(&serialized)?;
assert_eq!(cloned.id(), deserialized_session.id());
assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(cloned.id(), loaded_session.id());
assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
assert!(!loaded_session.is_expired());
tokio::time::sleep(Duration::from_secs(1)).await;
assert_eq!(None, store.load_session(cookie_value).await?);
Ok(())
}
#[tokio::test]
async fn destroying_a_single_session() -> Result {
let store = test_store().await;
for _ in 0..3i8 {
store.store_session(Session::new()).await?;
}
let cookie = store.store_session(Session::new()).await?.unwrap();
assert_eq!(4, store.count().await?);
let session = store.load_session(cookie.clone()).await?.unwrap();
store.destroy_session(session.clone()).await.unwrap();
assert_eq!(None, store.load_session(cookie).await?);
assert_eq!(3, store.count().await?);
// // attempting to destroy the session again is not an error
// assert!(store.destroy_session(session).await.is_ok());
Ok(())
}
#[tokio::test]
async fn clearing_the_whole_store() -> Result {
let store = test_store().await;
for _ in 0..3i8 {
store.store_session(Session::new()).await?;
}
assert_eq!(3, store.count().await?);
store.clear_store().await.unwrap();
assert_eq!(0, store.count().await?);
Ok(())
}
}

View File

@ -1,11 +1,15 @@
use std::fmt::Display; use std::fmt::Display;
use axum::{extract::State, http::Request, middleware::Next, response::IntoResponse};
use axum_login::{secrecy::SecretVec, AuthUser}; use axum_login::{secrecy::SecretVec, AuthUser};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::SqlitePool; use sqlx::SqlitePool;
use uuid::Uuid; use uuid::Uuid;
use crate::AuthContext;
const USERNAME_QUERY: &str = "select * from witches where username = $1"; const USERNAME_QUERY: &str = "select * from witches where username = $1";
const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1";
#[derive(Debug, Default, Clone, PartialEq, Eq, sqlx::FromRow, Serialize, Deserialize)] #[derive(Debug, Default, Clone, PartialEq, Eq, sqlx::FromRow, Serialize, Deserialize)]
pub struct User { pub struct User {
@ -41,10 +45,40 @@ impl AuthUser<Uuid> for User {
} }
impl User { impl User {
pub async fn get(username: &str, db: &SqlitePool) -> Result<User, impl std::error::Error> { pub async fn try_get(username: &str, db: &SqlitePool) -> Result<User, impl std::error::Error> {
sqlx::query_as(USERNAME_QUERY) sqlx::query_as(USERNAME_QUERY)
.bind(username) .bind(username)
.fetch_one(db) .fetch_one(db)
.await .await
} }
pub async fn update_last_seen(&self, pool: &SqlitePool) {
match sqlx::query(LAST_SEEN_QUERY)
.bind(self.id)
.execute(pool)
.await
{
Ok(_) => {}
Err(e) => {
let id = self.id.as_simple();
tracing::error!("Could not update last_seen for user {id}; got {e:?}");
}
}
}
}
//-************************************************************************
// User-specific middleware
//-************************************************************************
pub async fn handle_update_last_seen<BodyT>(
State(pool): State<SqlitePool>,
auth: AuthContext,
request: Request<BodyT>,
next: Next<BodyT>,
) -> impl IntoResponse {
if let Some(user) = auth.current_user {
user.update_last_seen(&pool).await;
}
next.run(request).await
} }