From f50abaa4a6bc34f84d5bbe62cf2d144a66992516 Mon Sep 17 00:00:00 2001
From: Joe Ardent
Date: Mon, 22 May 2023 16:57:08 -0700
Subject: [PATCH 01/12] re-org some handlers, handle '/'.
---
src/generic_handlers.rs | 7 ++++++
src/handlers.rs | 55 -----------------------------------------
src/lib.rs | 2 +-
src/main.rs | 3 +++
src/users.rs | 8 +++---
5 files changed, 15 insertions(+), 60 deletions(-)
create mode 100644 src/generic_handlers.rs
delete mode 100644 src/handlers.rs
diff --git a/src/generic_handlers.rs b/src/generic_handlers.rs
new file mode 100644
index 0000000..7ad4bcd
--- /dev/null
+++ b/src/generic_handlers.rs
@@ -0,0 +1,7 @@
+use axum::response::{IntoResponse, Redirect};
+
+pub async fn handle_slash_redir() -> impl IntoResponse {
+ Redirect::temporary("/")
+}
+
+pub async fn handle_slash() -> impl IntoResponse {}
diff --git a/src/handlers.rs b/src/handlers.rs
deleted file mode 100644
index 5ba9bd4..0000000
--- a/src/handlers.rs
+++ /dev/null
@@ -1,55 +0,0 @@
-use axum::{
- async_trait,
- extract::{FromRef, FromRequestParts, State},
- http::{request::Parts, StatusCode},
-};
-use sqlx::SqlitePool;
-
-pub async fn using_connection_pool_extractor(
- State(pool): State,
-) -> Result {
- sqlx::query_scalar("select 'hello world from sqlite get'")
- .fetch_one(&pool)
- .await
- .map_err(internal_error)
-}
-
-// we can also write a custom extractor that grabs a connection from the pool
-// which setup is appropriate depends on your application
-pub struct DatabaseConnection(sqlx::pool::PoolConnection);
-
-#[async_trait]
-impl FromRequestParts for DatabaseConnection
-where
- SqlitePool: FromRef,
- S: Send + Sync,
-{
- type Rejection = (StatusCode, String);
-
- async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result {
- let pool = SqlitePool::from_ref(state);
-
- let conn = pool.acquire().await.map_err(internal_error)?;
-
- Ok(Self(conn))
- }
-}
-
-pub async fn using_connection_extractor(
- DatabaseConnection(conn): DatabaseConnection,
-) -> Result {
- let mut conn = conn;
- sqlx::query_scalar("select 'hello world from sqlite post'")
- .fetch_one(&mut conn)
- .await
- .map_err(internal_error)
-}
-
-/// Utility function for mapping any error into a `500 Internal Server Error`
-/// response.
-fn internal_error(err: E) -> (StatusCode, String)
-where
- E: std::error::Error,
-{
- (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
-}
diff --git a/src/lib.rs b/src/lib.rs
index f97526d..bf90918 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,6 +2,6 @@
extern crate justerror;
pub mod db;
-pub mod handlers;
+pub mod generic_handlers;
pub(crate) mod templates;
pub mod users;
diff --git a/src/main.rs b/src/main.rs
index 51e79c6..f99d315 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -4,6 +4,7 @@ use axum::{routing::get, Router};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use witch_watch::{
db,
+ generic_handlers::{handle_slash, handle_slash_redir},
users::{get_create_user, handle_signup_success, post_create_user},
};
@@ -21,11 +22,13 @@ async fn main() {
// build our application with some routes
let app = Router::new()
+ .route("/", get(handle_slash).post(handle_slash))
.route("/signup", get(get_create_user).post(post_create_user))
.route(
"/signup_success/:id",
get(handle_signup_success).post(handle_signup_success),
)
+ .fallback(handle_slash_redir)
.with_state(pool);
tracing::debug!("binding to 0.0.0.0:3000");
diff --git a/src/users.rs b/src/users.rs
index a140739..e6f067d 100644
--- a/src/users.rs
+++ b/src/users.rs
@@ -145,8 +145,8 @@ pub async fn handle_signup_success(
State(pool): State,
) -> Response {
let user: User = {
- let id = id;
- let id = Uuid::try_parse(&id).unwrap_or_default();
+ let id = id.trim();
+ let id = Uuid::try_parse(id).unwrap_or_default();
let id_bytes = id.to_bytes_le();
sqlx::query_as(ID_QUERY)
.bind(id_bytes.as_slice())
@@ -157,8 +157,8 @@ pub async fn handle_signup_success(
let mut resp = CreateUserSuccess(user.clone()).into_response();
- if user.username.is_empty() {
- // redirect to front page if we got here without a valid witch header
+ if user.username.is_empty() || id.is_empty() {
+ // redirect to front page if we got here without a valid witch ID
*resp.status_mut() = StatusCode::TEMPORARY_REDIRECT;
resp.headers_mut().insert("Location", "/".parse().unwrap());
}
From 8237715066fd92943cd7e73de15290bb7ae2e490 Mon Sep 17 00:00:00 2001
From: Joe Ardent
Date: Mon, 22 May 2023 17:18:13 -0700
Subject: [PATCH 02/12] simplify uuid/db handling
---
src/lib.rs | 10 ++++++++++
src/users.rs | 9 +++------
2 files changed, 13 insertions(+), 6 deletions(-)
diff --git a/src/lib.rs b/src/lib.rs
index bf90918..1f45553 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,3 +5,13 @@ pub mod db;
pub mod generic_handlers;
pub(crate) mod templates;
pub mod users;
+
+pub trait ToBlob {
+ fn blob(&self) -> &[u8];
+}
+
+impl ToBlob for uuid::Uuid {
+ fn blob(&self) -> &[u8] {
+ self.as_bytes().as_slice()
+ }
+}
diff --git a/src/users.rs b/src/users.rs
index e6f067d..fa6fed1 100644
--- a/src/users.rs
+++ b/src/users.rs
@@ -14,7 +14,7 @@ use sqlx::{sqlite::SqliteRow, Row, SqlitePool};
use unicode_segmentation::UnicodeSegmentation;
use uuid::Uuid;
-use crate::templates::CreateUser;
+use crate::{templates::CreateUser, ToBlob};
const CREATE_QUERY: &str =
"insert into witches (id, username, displayname, email, pwhash) values ($1, $2, $3, $4, $5)";
@@ -147,9 +147,8 @@ pub async fn handle_signup_success(
let user: User = {
let id = id.trim();
let id = Uuid::try_parse(id).unwrap_or_default();
- let id_bytes = id.to_bytes_le();
sqlx::query_as(ID_QUERY)
- .bind(id_bytes.as_slice())
+ .bind(id.blob())
.fetch_one(&pool)
.await
.unwrap_or_default()
@@ -181,10 +180,8 @@ async fn create_user(
.to_string();
let id = Uuid::new_v4();
- let id_bytes = id.to_bytes_le();
- let id_bytes = id_bytes.as_slice();
let res = sqlx::query(CREATE_QUERY)
- .bind(id_bytes)
+ .bind(id.blob())
.bind(username)
.bind(displayname)
.bind(email)
From 0d6c9932d6d41478865bc861e2218b931767a546 Mon Sep 17 00:00:00 2001
From: Joe Ardent
Date: Wed, 24 May 2023 16:39:13 -0700
Subject: [PATCH 03/12] Add SqliteSessionStore module.
Stolen with very minor mods from https://github.com/jbr/async-sqlx-session
---
src/main.rs | 13 ++
src/session_store.rs | 507 +++++++++++++++++++++++++++++++++++++++++++
2 files changed, 520 insertions(+)
create mode 100644 src/session_store.rs
diff --git a/src/main.rs b/src/main.rs
index f99d315..5e83589 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,10 +1,13 @@
use std::net::SocketAddr;
use axum::{routing::get, Router};
+use axum_login::axum_sessions::SessionLayer;
+use rand_core::{OsRng, RngCore};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use witch_watch::{
db,
generic_handlers::{handle_slash, handle_slash_redir},
+ session_store::SqliteSessionStore,
users::{get_create_user, handle_signup_success, post_create_user},
};
@@ -19,6 +22,15 @@ async fn main() {
.init();
let pool = db::get_pool().await;
+ let store = SqliteSessionStore::from_client(pool.clone());
+ store.migrate().await.expect("Could not migrate session DB");
+ let secret = {
+ let mut bytes = [0u8; 128];
+ let mut rng = OsRng;
+ rng.fill_bytes(&mut bytes);
+ bytes
+ };
+ let session_layer = SessionLayer::new(store, &secret).with_secure(true);
// build our application with some routes
let app = Router::new()
@@ -29,6 +41,7 @@ async fn main() {
get(handle_signup_success).post(handle_signup_success),
)
.fallback(handle_slash_redir)
+ .layer(session_layer)
.with_state(pool);
tracing::debug!("binding to 0.0.0.0:3000");
diff --git a/src/session_store.rs b/src/session_store.rs
new file mode 100644
index 0000000..b97768b
--- /dev/null
+++ b/src/session_store.rs
@@ -0,0 +1,507 @@
+use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore};
+use sqlx::{pool::PoolConnection, sqlite::SqlitePool, Sqlite};
+
+// NOTE! This code was straight stolen from
+// https://github.com/jbr/async-sqlx-session/blob/30d00bed44ab2034082698f098eba48b21600f36/src/sqlite.rs
+// and used under the terms of the MIT license:
+
+/*
+Copyright 2022 Jacob Rothstein
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
+associated documentation files (the “Software”), to deal in the Software without restriction,
+including without limitation the rights to use, copy, modify, merge, publish, distribute,
+sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial
+portions of the Software.
+
+THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
+NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
+OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+*/
+
+/// sqlx sqlite session store for async-sessions
+///
+/// ```rust
+/// use witch_watch::session_store::SqliteSessionStore;
+/// use async_session::{Session, SessionStore, Result};
+/// use std::time::Duration;
+///
+/// # #[tokio::main]
+/// # async fn main() -> Result {
+/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
+/// store.migrate().await?;
+///
+/// let mut session = Session::new();
+/// session.insert("key", vec![1,2,3]);
+///
+/// let cookie_value = store.store_session(session).await?.unwrap();
+/// let session = store.load_session(cookie_value).await?.unwrap();
+/// assert_eq!(session.get::>("key").unwrap(), vec![1,2,3]);
+/// # Ok(()) }
+
+#[derive(Clone, Debug)]
+pub struct SqliteSessionStore {
+ client: SqlitePool,
+ table_name: String,
+}
+
+impl SqliteSessionStore {
+ /// constructs a new SqliteSessionStore from an existing
+ /// sqlx::SqlitePool. the default table name for this session
+ /// store will be "async_sessions". To override this, chain this
+ /// with [`with_table_name`](crate::SqliteSessionStore::with_table_name).
+ ///
+ /// ```rust
+ /// # use witch_watch::session_store::SqliteSessionStore;
+ /// # use async_session::Result;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result {
+ /// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
+ /// let store = SqliteSessionStore::from_client(pool)
+ /// .with_table_name("custom_table_name");
+ /// store.migrate().await;
+ /// # Ok(()) }
+ /// ```
+ pub fn from_client(client: SqlitePool) -> Self {
+ Self {
+ client,
+ table_name: "async_sessions".into(),
+ }
+ }
+
+ /// Constructs a new SqliteSessionStore from a sqlite: database url. note
+ /// that this documentation uses the special `:memory:` sqlite
+ /// database for convenient testing, but a real application would
+ /// use a path like `sqlite:///path/to/database.db`. The default
+ /// table name for this session store will be "async_sessions". To
+ /// override this, either chain with
+ /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
+ /// use
+ /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
+ ///
+ /// ```rust
+ /// # use witch_watch::session_store::SqliteSessionStore;
+ /// # use async_session::Result;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result {
+ /// let store = SqliteSessionStore::new("sqlite::memory:").await?;
+ /// store.migrate().await;
+ /// # Ok(()) }
+ /// ```
+ pub async fn new(database_url: &str) -> sqlx::Result {
+ Ok(Self::from_client(SqlitePool::connect(database_url).await?))
+ }
+
+ /// constructs a new SqliteSessionStore from a sqlite: database url. the
+ /// default table name for this session store will be
+ /// "async_sessions". To override this, either chain with
+ /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
+ /// use
+ /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
+ ///
+ /// ```rust
+ /// # use witch_watch::session_store::SqliteSessionStore;
+ /// # use async_session::Result;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result {
+ /// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?;
+ /// store.migrate().await;
+ /// # Ok(()) }
+ /// ```
+ pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result {
+ Ok(Self::new(database_url).await?.with_table_name(table_name))
+ }
+
+ /// Chainable method to add a custom table name. This will panic
+ /// if the table name is not `[a-zA-Z0-9_-]+`.
+ /// ```rust
+ /// # use witch_watch::session_store::SqliteSessionStore;
+ /// # use async_session::Result;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result {
+ /// let store = SqliteSessionStore::new("sqlite::memory:").await?
+ /// .with_table_name("custom_name");
+ /// store.migrate().await;
+ /// # Ok(()) }
+ /// ```
+ ///
+ /// ```should_panic
+ /// # use witch_watch::session_store::SqliteSessionStore;
+ /// # use async_session::Result;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result {
+ /// let store = SqliteSessionStore::new("sqlite::memory:").await?
+ /// .with_table_name("johnny (); drop users;");
+ /// # Ok(()) }
+ /// ```
+ pub fn with_table_name(mut self, table_name: impl AsRef) -> Self {
+ let table_name = table_name.as_ref();
+ if table_name.is_empty()
+ || !table_name
+ .chars()
+ .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
+ {
+ panic!(
+ "table name must be [a-zA-Z0-9_-]+, but {} was not",
+ table_name
+ );
+ }
+
+ self.table_name = table_name.to_owned();
+ self
+ }
+
+ /// Creates a session table if it does not already exist. If it
+ /// does, this will noop, making it safe to call repeatedly on
+ /// store initialization. In the future, this may make
+ /// exactly-once modifications to the schema of the session table
+ /// on breaking releases.
+ /// ```rust
+ /// # use witch_watch::session_store::SqliteSessionStore;
+ /// # use async_session::{Result, SessionStore, Session};
+ /// # #[tokio::main]
+ /// # async fn main() -> Result {
+ /// let store = SqliteSessionStore::new("sqlite::memory:").await?;
+ /// assert!(store.count().await.is_err());
+ /// store.migrate().await?;
+ /// store.store_session(Session::new()).await?;
+ /// store.migrate().await?; // calling it a second time is safe
+ /// assert_eq!(store.count().await?, 1);
+ /// # Ok(()) }
+ /// ```
+ pub async fn migrate(&self) -> sqlx::Result<()> {
+ log::info!("migrating sessions on `{}`", self.table_name);
+
+ let mut conn = self.client.acquire().await?;
+ sqlx::query(&self.substitute_table_name(
+ r#"
+ CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
+ id TEXT PRIMARY KEY NOT NULL,
+ expires INTEGER NULL,
+ session TEXT NOT NULL
+ )
+ "#,
+ ))
+ .execute(&mut conn)
+ .await?;
+ Ok(())
+ }
+
+ // private utility function because sqlite does not support
+ // parametrized table names
+ fn substitute_table_name(&self, query: &str) -> String {
+ query.replace("%%TABLE_NAME%%", &self.table_name)
+ }
+
+ /// retrieve a connection from the pool
+ async fn connection(&self) -> sqlx::Result> {
+ self.client.acquire().await
+ }
+
+ /// Performs a one-time cleanup task that clears out stale
+ /// (expired) sessions. You may want to call this from cron.
+ /// ```rust
+ /// # use witch_watch::session_store::SqliteSessionStore;
+ /// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session};
+ /// # #[tokio::main]
+ /// # async fn main() -> Result {
+ /// let store = SqliteSessionStore::new("sqlite::memory:").await?;
+ /// store.migrate().await?;
+ /// let mut session = Session::new();
+ /// session.set_expiry(Utc::now() - Duration::seconds(5));
+ /// store.store_session(session).await?;
+ /// assert_eq!(store.count().await?, 1);
+ /// store.cleanup().await?;
+ /// assert_eq!(store.count().await?, 0);
+ /// # Ok(()) }
+ /// ```
+ pub async fn cleanup(&self) -> sqlx::Result<()> {
+ let mut connection = self.connection().await?;
+ sqlx::query(&self.substitute_table_name(
+ r#"
+ DELETE FROM %%TABLE_NAME%%
+ WHERE expires < ?
+ "#,
+ ))
+ .bind(Utc::now().timestamp())
+ .execute(&mut connection)
+ .await?;
+
+ Ok(())
+ }
+
+ /// retrieves the number of sessions currently stored, including
+ /// expired sessions
+ ///
+ /// ```rust
+ /// # use witch_watch::session_store::SqliteSessionStore;
+ /// # use async_session::{Result, SessionStore, Session};
+ /// # use std::time::Duration;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result {
+ /// let store = SqliteSessionStore::new("sqlite::memory:").await?;
+ /// store.migrate().await?;
+ /// assert_eq!(store.count().await?, 0);
+ /// store.store_session(Session::new()).await?;
+ /// assert_eq!(store.count().await?, 1);
+ /// # Ok(()) }
+ /// ```
+ pub async fn count(&self) -> sqlx::Result {
+ let (count,) =
+ sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
+ .fetch_one(&mut self.connection().await?)
+ .await?;
+
+ Ok(count)
+ }
+}
+
+#[async_trait]
+impl SessionStore for SqliteSessionStore {
+ async fn load_session(&self, cookie_value: String) -> Result
+
+{% endblock %}
From 5de4b9994c586fc508bbdb5edd5bba8de841ed42 Mon Sep 17 00:00:00 2001
From: Joe Ardent
Date: Mon, 29 May 2023 11:13:12 -0700
Subject: [PATCH 12/12] stub out logout
---
src/login.rs | 8 ++++++++
src/main.rs | 3 ++-
2 files changed, 10 insertions(+), 1 deletion(-)
diff --git a/src/login.rs b/src/login.rs
index 4923679..d9d3603 100644
--- a/src/login.rs
+++ b/src/login.rs
@@ -97,3 +97,11 @@ pub async fn post_login(
pub async fn get_login() -> impl IntoResponse {
LoginGet::default()
}
+
+pub async fn get_logout() -> impl IntoResponse {
+ todo!()
+}
+
+pub async fn post_logout() -> impl IntoResponse {
+ todo!()
+}
diff --git a/src/main.rs b/src/main.rs
index 57342f4..39903e5 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -7,7 +7,7 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use witch_watch::{
db,
generic_handlers::{handle_slash, handle_slash_redir},
- login::{get_login, post_login},
+ login::{get_login, get_logout, post_login, post_logout},
session_store::SqliteSessionStore,
signup::{get_create_user, handle_signup_success, post_create_user},
User,
@@ -52,6 +52,7 @@ async fn main() {
get(handle_signup_success).post(handle_signup_success),
)
.route("/login", get(get_login).post(post_login))
+ .route("/logout", get(get_logout).post(post_logout))
.fallback(handle_slash_redir)
.layer(auth_layer)
.layer(session_layer)