Merge branch 'login'

This commit is contained in:
Joe Ardent 2023-05-29 11:13:32 -07:00
commit 113982ba27
16 changed files with 1098 additions and 925 deletions

710
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -4,23 +4,25 @@ version = "0.0.1"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
axum = { version = "0.6", features = ["macros", "tracing"] } axum = { version = "0.6", features = ["macros", "headers"] }
askama = { version = "0.12", features = ["with-axum"] } askama = { version = "0.12", features = ["with-axum"] }
askama_axum = "0.3" askama_axum = "0.3"
axum-macros = "0.3" axum-macros = "0.3"
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full", "tracing"], default-features = false }
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tower = { version = "0.4", features = ["util", "timeout"] } tower = { version = "0.4", features = ["util", "timeout"], default-features = false }
tower-http = { version = "0.4", features = ["add-extension", "trace"] } tower-http = { version = "0.4", features = ["add-extension", "trace"] }
uuid = { version = "1.3", features = ["serde", "v4"] } uuid = { version = "1", features = ["serde", "v4"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
sqlx = { version = "0.5.10", features = ["runtime-tokio-rustls", "any", "sqlite", "chrono", "time", "uuid"] } sqlx = { version = "0.6", default-features = false, features = ["runtime-tokio-rustls", "any", "sqlite", "chrono", "time", "uuid"] }
argon2 = "0.5" argon2 = "0.5"
rand_core = { version = "0.6", features = ["getrandom"] } rand_core = { version = "0.6", features = ["getrandom"] }
thiserror = "1.0.40" thiserror = "1"
justerror = "1.1.0" justerror = "1"
password-hash = { version = "0.5.0", features = ["std", "getrandom"] } password-hash = { version = "0.5", features = ["std", "getrandom"] }
axum-login = { version = "0.5.0", features = ["sqlite", "sqlx"] } axum-login = { version = "0.5", features = ["sqlite", "sqlx"] }
unicode-segmentation = "1.10.1" unicode-segmentation = "1"
urlencoding = "2.1.2" urlencoding = "2"
async-session = "3"

View File

@ -20,12 +20,12 @@ pub async fn get_pool() -> SqlitePool {
let conn_opts = SqliteConnectOptions::new() let conn_opts = SqliteConnectOptions::new()
.foreign_keys(true) .foreign_keys(true)
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental) .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
.filename(&db_filename); .filename(&db_filename)
.busy_timeout(Duration::from_secs(TIMEOUT));
// setup connection pool // setup connection pool
SqlitePoolOptions::new() SqlitePoolOptions::new()
.max_connections(MAX_CONNS) .max_connections(MAX_CONNS)
.connect_timeout(Duration::from_secs(TIMEOUT))
.connect_with(conn_opts) .connect_with(conn_opts)
.await .await
.expect("can't connect to database") .expect("can't connect to database")

View File

@ -1,54 +0,0 @@
use axum::{extract::Form, response::Html};
use serde::Deserialize;
pub(crate) async fn show_form() -> Html<&'static str> {
Html(
r#"
<!doctype html>
<html>
<head></head>
<body>
<form action="/" method="post">
<label for="name">
Enter your name:
<input type="text" name="name">
</label>
<label>
Enter your email:
<input type="text" name="email">
</label>
<input type="submit" value="Subscribe!">
</form>
</body>
</html>
"#,
)
}
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
pub(crate) struct Input {
name: String,
email: String,
}
pub(crate) async fn accept_form(Form(input): Form<Input>) -> Html<String> {
let Input { name, email: _ } = input;
let html = format!(
r#"
<!doctype html>
<html>
<head></head>
<body>
<p>Hi, {}</p>
</body>
</html>
"#,
name
);
Html(html)
}

15
src/generic_handlers.rs Normal file
View File

@ -0,0 +1,15 @@
use axum::response::{IntoResponse, Redirect};
use crate::AuthContext;
pub async fn handle_slash_redir() -> impl IntoResponse {
Redirect::temporary("/")
}
pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse {
if let Some(user) = auth.current_user {
tracing::debug!("Logged in as: {user}");
} else {
tracing::debug!("Not logged in.")
}
}

View File

@ -1,55 +0,0 @@
use axum::{
async_trait,
extract::{FromRef, FromRequestParts, State},
http::{request::Parts, StatusCode},
};
use sqlx::SqlitePool;
pub async fn using_connection_pool_extractor(
State(pool): State<SqlitePool>,
) -> Result<String, (StatusCode, String)> {
sqlx::query_scalar("select 'hello world from sqlite get'")
.fetch_one(&pool)
.await
.map_err(internal_error)
}
// we can also write a custom extractor that grabs a connection from the pool
// which setup is appropriate depends on your application
pub struct DatabaseConnection(sqlx::pool::PoolConnection<sqlx::Sqlite>);
#[async_trait]
impl<S> FromRequestParts<S> for DatabaseConnection
where
SqlitePool: FromRef<S>,
S: Send + Sync,
{
type Rejection = (StatusCode, String);
async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let pool = SqlitePool::from_ref(state);
let conn = pool.acquire().await.map_err(internal_error)?;
Ok(Self(conn))
}
}
pub async fn using_connection_extractor(
DatabaseConnection(conn): DatabaseConnection,
) -> Result<String, (StatusCode, String)> {
let mut conn = conn;
sqlx::query_scalar("select 'hello world from sqlite post'")
.fetch_one(&mut conn)
.await
.map_err(internal_error)
}
/// Utility function for mapping any error into a `500 Internal Server Error`
/// response.
fn internal_error<E>(err: E) -> (StatusCode, String)
where
E: std::error::Error,
{
(StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
}

View File

@ -1,7 +1,17 @@
#[macro_use] #[macro_use]
extern crate justerror; extern crate justerror;
use axum_login::SqliteStore;
pub use users::User;
use uuid::Uuid;
pub mod db; pub mod db;
pub mod handlers; pub mod generic_handlers;
pub mod login;
pub mod session_store;
pub mod signup;
pub(crate) mod templates; pub(crate) mod templates;
pub mod users; pub mod users;
pub(crate) mod util;
pub type AuthContext = axum_login::extractors::AuthContext<Uuid, User, SqliteStore<User>>;

107
src/login.rs Normal file
View File

@ -0,0 +1,107 @@
use argon2::{
password_hash::{PasswordHash, PasswordVerifier},
Argon2,
};
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Redirect, Response},
Form,
};
use sqlx::SqlitePool;
use crate::{
templates::{LoginGet, LoginPost},
util::form_decode,
AuthContext, User,
};
//-************************************************************************
// Constants
//-************************************************************************
const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1";
//-************************************************************************
// Login error and success types
//-************************************************************************
#[Error]
pub struct LoginError(#[from] LoginErrorKind);
#[Error]
#[non_exhaustive]
pub enum LoginErrorKind {
Internal,
BadPassword,
BadUsername,
Unknown,
}
impl IntoResponse for LoginError {
fn into_response(self) -> Response {
match self.0 {
LoginErrorKind::Unknown => (
StatusCode::INTERNAL_SERVER_ERROR,
"An unknown error occurred; you cursed, brah?",
)
.into_response(),
_ => (StatusCode::BAD_REQUEST, format!("{self}")).into_response(),
}
}
}
//-************************************************************************
// Login handlers
//-************************************************************************
/// Handle login queries
#[axum::debug_handler]
pub async fn post_login(
mut auth: AuthContext,
State(pool): State<SqlitePool>,
Form(login): Form<LoginPost>,
) -> Result<impl IntoResponse, LoginError> {
let username = form_decode(&login.username, LoginErrorKind::BadUsername)?;
let username = username.trim();
let pw = form_decode(&login.password, LoginErrorKind::BadPassword)?;
let pw = pw.trim();
let user = User::get(username, &pool)
.await
.map_err(|_| LoginErrorKind::Unknown)?;
let verifier = Argon2::default();
let hash = PasswordHash::new(&user.pwhash).map_err(|_| LoginErrorKind::Internal)?;
match verifier.verify_password(pw.as_bytes(), &hash) {
Ok(_) => {
// log them in and set a session cookie
auth.login(&user)
.await
.map_err(|_| LoginErrorKind::Internal)?;
// update last_seen; maybe this is ok to fail?
sqlx::query(LAST_SEEN_QUERY)
.bind(user.id)
.execute(&pool)
.await
.map_err(|_| LoginErrorKind::Internal)?;
Ok(Redirect::temporary("/"))
}
_ => Err(LoginErrorKind::BadPassword.into()),
}
}
pub async fn get_login() -> impl IntoResponse {
LoginGet::default()
}
pub async fn get_logout() -> impl IntoResponse {
todo!()
}
pub async fn post_logout() -> impl IntoResponse {
todo!()
}

View File

@ -1,10 +1,16 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use axum::{routing::get, Router}; use axum::{routing::get, Router};
use axum_login::{axum_sessions::SessionLayer, AuthLayer, SqliteStore};
use rand_core::{OsRng, RngCore};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use witch_watch::{ use witch_watch::{
db, db,
users::{get_create_user, handle_signup_success, post_create_user}, generic_handlers::{handle_slash, handle_slash_redir},
login::{get_login, get_logout, post_login, post_logout},
session_store::SqliteSessionStore,
signup::{get_create_user, handle_signup_success, post_create_user},
User,
}; };
#[tokio::main] #[tokio::main]
@ -19,13 +25,37 @@ async fn main() {
let pool = db::get_pool().await; let pool = db::get_pool().await;
// build our application with some routes let secret = {
let mut bytes = [0u8; 128];
let mut rng = OsRng;
rng.fill_bytes(&mut bytes);
bytes
};
let session_layer = {
let store = SqliteSessionStore::from_client(pool.clone());
store.migrate().await.expect("Could not migrate session DB");
SessionLayer::new(store, &secret).with_secure(true)
};
let auth_layer = {
const QUERY: &str = "select * from witches where id = $1";
let store = SqliteStore::<User>::new(pool.clone()).with_query(QUERY);
AuthLayer::new(store, &secret)
};
let app = Router::new() let app = Router::new()
.route("/", get(handle_slash).post(handle_slash))
.route("/signup", get(get_create_user).post(post_create_user)) .route("/signup", get(get_create_user).post(post_create_user))
.route( .route(
"/signup_success/:id", "/signup_success/:id",
get(handle_signup_success).post(handle_signup_success), get(handle_signup_success).post(handle_signup_success),
) )
.route("/login", get(get_login).post(post_login))
.route("/logout", get(get_logout).post(post_logout))
.fallback(handle_slash_redir)
.layer(auth_layer)
.layer(session_layer)
.with_state(pool); .with_state(pool);
tracing::debug!("binding to 0.0.0.0:3000"); tracing::debug!("binding to 0.0.0.0:3000");

507
src/session_store.rs Normal file
View File

@ -0,0 +1,507 @@
use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore};
use sqlx::{pool::PoolConnection, sqlite::SqlitePool, Sqlite};
// NOTE! This code was straight stolen from
// https://github.com/jbr/async-sqlx-session/blob/30d00bed44ab2034082698f098eba48b21600f36/src/sqlite.rs
// and used under the terms of the MIT license:
/*
Copyright 2022 Jacob Rothstein
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the Software), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial
portions of the Software.
THE SOFTWARE IS PROVIDED AS IS, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
/// sqlx sqlite session store for async-sessions
///
/// ```rust
/// use witch_watch::session_store::SqliteSessionStore;
/// use async_session::{Session, SessionStore, Result};
/// use std::time::Duration;
///
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await?;
///
/// let mut session = Session::new();
/// session.insert("key", vec![1,2,3]);
///
/// let cookie_value = store.store_session(session).await?.unwrap();
/// let session = store.load_session(cookie_value).await?.unwrap();
/// assert_eq!(session.get::<Vec<i8>>("key").unwrap(), vec![1,2,3]);
/// # Ok(()) }
#[derive(Clone, Debug)]
pub struct SqliteSessionStore {
client: SqlitePool,
table_name: String,
}
impl SqliteSessionStore {
/// constructs a new SqliteSessionStore from an existing
/// sqlx::SqlitePool. the default table name for this session
/// store will be "async_sessions". To override this, chain this
/// with [`with_table_name`](crate::SqliteSessionStore::with_table_name).
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
/// let store = SqliteSessionStore::from_client(pool)
/// .with_table_name("custom_table_name");
/// store.migrate().await;
/// # Ok(()) }
/// ```
pub fn from_client(client: SqlitePool) -> Self {
Self {
client,
table_name: "async_sessions".into(),
}
}
/// Constructs a new SqliteSessionStore from a sqlite: database url. note
/// that this documentation uses the special `:memory:` sqlite
/// database for convenient testing, but a real application would
/// use a path like `sqlite:///path/to/database.db`. The default
/// table name for this session store will be "async_sessions". To
/// override this, either chain with
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
/// use
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await;
/// # Ok(()) }
/// ```
pub async fn new(database_url: &str) -> sqlx::Result<Self> {
Ok(Self::from_client(SqlitePool::connect(database_url).await?))
}
/// constructs a new SqliteSessionStore from a sqlite: database url. the
/// default table name for this session store will be
/// "async_sessions". To override this, either chain with
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
/// use
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?;
/// store.migrate().await;
/// # Ok(()) }
/// ```
pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result<Self> {
Ok(Self::new(database_url).await?.with_table_name(table_name))
}
/// Chainable method to add a custom table name. This will panic
/// if the table name is not `[a-zA-Z0-9_-]+`.
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?
/// .with_table_name("custom_name");
/// store.migrate().await;
/// # Ok(()) }
/// ```
///
/// ```should_panic
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::Result;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?
/// .with_table_name("johnny (); drop users;");
/// # Ok(()) }
/// ```
pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
let table_name = table_name.as_ref();
if table_name.is_empty()
|| !table_name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
{
panic!(
"table name must be [a-zA-Z0-9_-]+, but {} was not",
table_name
);
}
self.table_name = table_name.to_owned();
self
}
/// Creates a session table if it does not already exist. If it
/// does, this will noop, making it safe to call repeatedly on
/// store initialization. In the future, this may make
/// exactly-once modifications to the schema of the session table
/// on breaking releases.
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::{Result, SessionStore, Session};
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// assert!(store.count().await.is_err());
/// store.migrate().await?;
/// store.store_session(Session::new()).await?;
/// store.migrate().await?; // calling it a second time is safe
/// assert_eq!(store.count().await?, 1);
/// # Ok(()) }
/// ```
pub async fn migrate(&self) -> sqlx::Result<()> {
log::info!("migrating sessions on `{}`", self.table_name);
let mut conn = self.client.acquire().await?;
sqlx::query(&self.substitute_table_name(
r#"
CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
id TEXT PRIMARY KEY NOT NULL,
expires INTEGER NULL,
session TEXT NOT NULL
)
"#,
))
.execute(&mut conn)
.await?;
Ok(())
}
// private utility function because sqlite does not support
// parametrized table names
fn substitute_table_name(&self, query: &str) -> String {
query.replace("%%TABLE_NAME%%", &self.table_name)
}
/// retrieve a connection from the pool
async fn connection(&self) -> sqlx::Result<PoolConnection<Sqlite>> {
self.client.acquire().await
}
/// Performs a one-time cleanup task that clears out stale
/// (expired) sessions. You may want to call this from cron.
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session};
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await?;
/// let mut session = Session::new();
/// session.set_expiry(Utc::now() - Duration::seconds(5));
/// store.store_session(session).await?;
/// assert_eq!(store.count().await?, 1);
/// store.cleanup().await?;
/// assert_eq!(store.count().await?, 0);
/// # Ok(()) }
/// ```
pub async fn cleanup(&self) -> sqlx::Result<()> {
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
DELETE FROM %%TABLE_NAME%%
WHERE expires < ?
"#,
))
.bind(Utc::now().timestamp())
.execute(&mut connection)
.await?;
Ok(())
}
/// retrieves the number of sessions currently stored, including
/// expired sessions
///
/// ```rust
/// # use witch_watch::session_store::SqliteSessionStore;
/// # use async_session::{Result, SessionStore, Session};
/// # use std::time::Duration;
/// # #[tokio::main]
/// # async fn main() -> Result {
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
/// store.migrate().await?;
/// assert_eq!(store.count().await?, 0);
/// store.store_session(Session::new()).await?;
/// assert_eq!(store.count().await?, 1);
/// # Ok(()) }
/// ```
pub async fn count(&self) -> sqlx::Result<i32> {
let (count,) =
sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
.fetch_one(&mut self.connection().await?)
.await?;
Ok(count)
}
}
#[async_trait]
impl SessionStore for SqliteSessionStore {
async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
let id = Session::id_from_cookie_value(&cookie_value)?;
let mut connection = self.connection().await?;
let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name(
r#"
SELECT session FROM %%TABLE_NAME%%
WHERE id = ? AND (expires IS NULL OR expires > ?)
"#,
))
.bind(&id)
.bind(Utc::now().timestamp())
.fetch_optional(&mut connection)
.await?;
Ok(result
.map(|(session,)| serde_json::from_str(&session))
.transpose()?)
}
async fn store_session(&self, session: Session) -> Result<Option<String>> {
let id = session.id();
let string = serde_json::to_string(&session)?;
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
INSERT INTO %%TABLE_NAME%%
(id, session, expires) VALUES (?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
expires = excluded.expires,
session = excluded.session
"#,
))
.bind(id)
.bind(&string)
.bind(session.expiry().map(|expiry| expiry.timestamp()))
.execute(&mut connection)
.await?;
Ok(session.into_cookie_value())
}
async fn destroy_session(&self, session: Session) -> Result {
let id = session.id();
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
DELETE FROM %%TABLE_NAME%% WHERE id = ?
"#,
))
.bind(id)
.execute(&mut connection)
.await?;
Ok(())
}
async fn clear_store(&self) -> Result {
let mut connection = self.connection().await?;
sqlx::query(&self.substitute_table_name(
r#"
DELETE FROM %%TABLE_NAME%%
"#,
))
.execute(&mut connection)
.await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
async fn test_store() -> SqliteSessionStore {
let store = SqliteSessionStore::new("sqlite::memory:")
.await
.expect("building a sqlite :memory: SqliteSessionStore");
store
.migrate()
.await
.expect("migrating a brand new :memory: SqliteSessionStore");
store
}
#[tokio::test]
async fn creating_a_new_session_with_no_expiry() -> Result {
let store = test_store().await;
let mut session = Session::new();
session.insert("key", "value")?;
let cloned = session.clone();
let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
sqlx::query_as("select id, expires, session, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(id, cloned.id());
assert_eq!(expires, None);
let deserialized_session: Session = serde_json::from_str(&serialized)?;
assert_eq!(cloned.id(), deserialized_session.id());
assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
let loaded_session = store.load_session(cookie_value).await?.unwrap();
assert_eq!(cloned.id(), loaded_session.id());
assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
assert!(!loaded_session.is_expired());
Ok(())
}
#[tokio::test]
async fn updating_a_session() -> Result {
let store = test_store().await;
let mut session = Session::new();
let original_id = session.id().to_owned();
session.insert("key", "value")?;
let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
session.insert("key", "other value")?;
assert_eq!(None, store.store_session(session).await?);
let session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(session.get::<String>("key").unwrap(), "other value");
let (id, count): (String, i64) = sqlx::query_as("select id, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(original_id, id);
Ok(())
}
#[tokio::test]
async fn updating_a_session_extending_expiry() -> Result {
let store = test_store().await;
let mut session = Session::new();
session.expire_in(Duration::from_secs(10));
let original_id = session.id().to_owned();
let original_expires = session.expiry().unwrap().clone();
let cookie_value = store.store_session(session).await?.unwrap();
let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(session.expiry().unwrap(), &original_expires);
session.expire_in(Duration::from_secs(20));
let new_expires = session.expiry().unwrap().clone();
store.store_session(session).await?;
let session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(session.expiry().unwrap(), &new_expires);
let (id, expires, count): (String, i64, i64) =
sqlx::query_as("select id, expires, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(expires, new_expires.timestamp());
assert_eq!(original_id, id);
Ok(())
}
#[tokio::test]
async fn creating_a_new_session_with_expiry() -> Result {
let store = test_store().await;
let mut session = Session::new();
session.expire_in(Duration::from_secs(1));
session.insert("key", "value")?;
let cloned = session.clone();
let cookie_value = store.store_session(session).await?.unwrap();
let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
sqlx::query_as("select id, expires, session, count(*) from async_sessions")
.fetch_one(&mut store.connection().await?)
.await?;
assert_eq!(1, count);
assert_eq!(id, cloned.id());
assert!(expires.unwrap() > Utc::now().timestamp());
let deserialized_session: Session = serde_json::from_str(&serialized)?;
assert_eq!(cloned.id(), deserialized_session.id());
assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap();
assert_eq!(cloned.id(), loaded_session.id());
assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
assert!(!loaded_session.is_expired());
tokio::time::sleep(Duration::from_secs(1)).await;
assert_eq!(None, store.load_session(cookie_value).await?);
Ok(())
}
#[tokio::test]
async fn destroying_a_single_session() -> Result {
let store = test_store().await;
for _ in 0..3i8 {
store.store_session(Session::new()).await?;
}
let cookie = store.store_session(Session::new()).await?.unwrap();
assert_eq!(4, store.count().await?);
let session = store.load_session(cookie.clone()).await?.unwrap();
store.destroy_session(session.clone()).await.unwrap();
assert_eq!(None, store.load_session(cookie).await?);
assert_eq!(3, store.count().await?);
// // attempting to destroy the session again is not an error
// assert!(store.destroy_session(session).await.is_ok());
Ok(())
}
#[tokio::test]
async fn clearing_the_whole_store() -> Result {
let store = test_store().await;
for _ in 0..3i8 {
store.store_session(Session::new()).await?;
}
assert_eq!(3, store.count().await?);
store.clear_store().await.unwrap();
assert_eq!(0, store.count().await?);
Ok(())
}
}

220
src/signup.rs Normal file
View File

@ -0,0 +1,220 @@
use argon2::{
password_hash::{rand_core::OsRng, PasswordHasher, SaltString},
Argon2,
};
use askama::Template;
use axum::{
extract::{Form, Path, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use sqlx::{query_as, SqlitePool};
use unicode_segmentation::UnicodeSegmentation;
use uuid::Uuid;
use crate::{templates::CreateUser, User};
const CREATE_QUERY: &str =
"insert into witches (id, username, displayname, email, pwhash) values ($1, $2, $3, $4, $5)";
const ID_QUERY: &str = "select * from witches where id = $1";
//-************************************************************************
// Result types for user creation
//-************************************************************************
#[derive(Debug, Clone, Template)]
#[template(path = "signup_success.html")]
pub struct CreateUserSuccess(User);
#[Error(desc = "Could not create user.")]
#[non_exhaustive]
pub struct CreateUserError(#[from] CreateUserErrorKind);
impl IntoResponse for CreateUserError {
fn into_response(self) -> askama_axum::Response {
match self.0 {
CreateUserErrorKind::UnknownDBError => {
(StatusCode::INTERNAL_SERVER_ERROR, format!("{self}")).into_response()
}
_ => (StatusCode::BAD_REQUEST, format!("{self}")).into_response(),
}
}
}
#[Error]
#[non_exhaustive]
pub enum CreateUserErrorKind {
AlreadyExists,
#[error(desc = "Usernames must be between 1 and 20 non-whitespace characters long")]
BadUsername,
PasswordMismatch,
#[error(desc = "Password must have at least 4 and at most 50 characters")]
BadPassword,
#[error(desc = "Display name must be less than 100 characters long")]
BadDisplayname,
BadEmail,
MissingFields,
UnknownDBError,
}
//-************************************************************************
// User creation route handlers
//-************************************************************************
/// Get Handler: displays the form to create a user
pub async fn get_create_user() -> CreateUser {
CreateUser::default()
}
/// Post Handler: validates form values and calls the actual, private user
/// creation function
#[axum::debug_handler]
pub async fn post_create_user(
State(pool): State<SqlitePool>,
Form(signup): Form<CreateUser>,
) -> Result<impl IntoResponse, CreateUserError> {
let username = &signup.username;
let displayname = &signup.displayname;
let email = &signup.email;
let password = &signup.password;
let verify = &signup.pw_verify;
let username = username.trim();
let name_len = username.graphemes(true).size_hint().1.unwrap();
// we are not ascii exclusivists around here
if !(1..=20).contains(&name_len) {
return Err(CreateUserErrorKind::BadUsername.into());
}
if password != verify {
return Err(CreateUserErrorKind::PasswordMismatch.into());
}
let password = urlencoding::decode(password)
.map_err(|_| CreateUserErrorKind::BadPassword)?
.to_string();
let password = password.trim();
let password = password.as_bytes();
if !(4..=50).contains(&password.len()) {
return Err(CreateUserErrorKind::BadPassword.into());
}
let displayname = if let Some(dn) = displayname {
let dn = urlencoding::decode(dn)
.map_err(|_| CreateUserErrorKind::BadDisplayname)?
.to_string()
.trim()
.to_string();
if dn.graphemes(true).size_hint().1.unwrap() > 100 {
return Err(CreateUserErrorKind::BadDisplayname.into());
}
Some(dn)
} else {
None
};
let displayname = &displayname;
// TODO(2023-05-17): validate email
let email = if let Some(email) = email {
let email = urlencoding::decode(email)
.map_err(|_| CreateUserErrorKind::BadEmail)?
.to_string();
Some(email)
} else {
None
};
let email = &email;
let user = create_user(username, displayname, email, password, &pool).await?;
tracing::debug!("created {user:?}");
let id = user.id.as_simple().to_string();
let location = format!("/signup_success/{id}");
let resp = axum::response::Redirect::temporary(&location);
Ok(resp)
}
/// Generic handler for successful signup
pub async fn handle_signup_success(
Path(id): Path<String>,
State(pool): State<SqlitePool>,
) -> Response {
let id = id.trim();
let user: User = {
let id = Uuid::try_parse(id).unwrap_or_default();
query_as(ID_QUERY)
.bind(id)
.fetch_one(&pool)
.await
.unwrap_or_default()
};
let mut resp = CreateUserSuccess(user.clone()).into_response();
if user.username.is_empty() || id.is_empty() {
// redirect to front page if we got here without a valid witch ID
*resp.status_mut() = StatusCode::TEMPORARY_REDIRECT;
resp.headers_mut().insert("Location", "/".parse().unwrap());
}
resp
}
//-************************************************************************
// private fns
//-************************************************************************
async fn create_user(
username: &str,
displayname: &Option<String>,
email: &Option<String>,
password: &[u8],
pool: &SqlitePool,
) -> Result<User, CreateUserError> {
// Argon2 with default params (Argon2id v19)
let argon2 = Argon2::default();
let salt = SaltString::generate(&mut OsRng);
let pwhash = argon2
.hash_password(password, &salt)
.unwrap() // safe to unwrap, we know the salt is valid
.to_string();
let id = Uuid::new_v4();
let res = sqlx::query(CREATE_QUERY)
.bind(id)
.bind(username)
.bind(displayname)
.bind(email)
.bind(&pwhash)
.execute(pool)
.await;
match res {
Ok(_) => {
let user = User {
id,
username: username.to_string(),
displayname: displayname.to_owned(),
email: email.to_owned(),
last_seen: None,
pwhash,
};
Ok(user)
}
Err(sqlx::Error::Database(db)) => {
if let Some(exit) = db.code() {
let exit = exit.parse().unwrap_or(0u32);
// https://www.sqlite.org/rescode.html codes for unique constraint violations:
if exit == 2067u32 || exit == 1555 {
Err(CreateUserErrorKind::AlreadyExists.into())
} else {
Err(CreateUserErrorKind::UnknownDBError.into())
}
} else {
Err(CreateUserErrorKind::UnknownDBError.into())
}
}
_ => Err(CreateUserErrorKind::UnknownDBError.into()),
}
}

View File

@ -1,7 +1,7 @@
use askama::Template; use askama::Template;
use serde::Deserialize; use serde::{Deserialize, Serialize};
#[derive(Debug, Default, Template, Deserialize)] #[derive(Debug, Default, Template, Deserialize, Serialize)]
#[template(path = "signup.html")] #[template(path = "signup.html")]
pub struct CreateUser { pub struct CreateUser {
pub username: String, pub username: String,
@ -10,3 +10,17 @@ pub struct CreateUser {
pub password: String, pub password: String,
pub pw_verify: String, pub pw_verify: String,
} }
#[derive(Debug, Default, Template, Deserialize, Serialize)]
#[template(path = "login_post.html")]
pub struct LoginPost {
pub username: String,
pub password: String,
}
#[derive(Debug, Default, Template, Deserialize, Serialize)]
#[template(path = "login_get.html")]
pub struct LoginGet {
pub username: String,
pub password: String,
}

View File

@ -1,33 +1,19 @@
use std::fmt::Display; use std::fmt::Display;
use argon2::{ use axum_login::{secrecy::SecretVec, AuthUser};
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, use sqlx::SqlitePool;
Argon2,
};
use askama::Template;
use axum::{
extract::{Form, Path, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use sqlx::{sqlite::SqliteRow, Row, SqlitePool};
use unicode_segmentation::UnicodeSegmentation;
use uuid::Uuid; use uuid::Uuid;
use crate::templates::CreateUser; const USERNAME_QUERY: &str = "select * from witches where username = $1";
const CREATE_QUERY: &str = #[derive(Debug, Default, Clone, PartialEq, Eq, sqlx::FromRow)]
"insert into witches (id, username, displayname, email, pwhash) values ($1, $2, $3, $4, $5)";
const ID_QUERY: &str = "select * from witches where id = $1";
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct User { pub struct User {
id: Uuid, pub id: Uuid,
username: String, pub username: String,
displayname: Option<String>, pub displayname: Option<String>,
email: Option<String>, pub email: Option<String>,
last_seen: Option<i64>, pub last_seen: Option<i64>,
pub(crate) pwhash: String,
} }
impl Display for User { impl Display for User {
@ -43,208 +29,21 @@ impl Display for User {
} }
} }
#[derive(Debug, Clone, Template)] impl AuthUser<Uuid> for User {
#[template(path = "signup_success.html")] fn get_id(&self) -> Uuid {
pub struct CreateUserSuccess(User); self.id
}
impl sqlx::FromRow<'_, SqliteRow> for User { fn get_password_hash(&self) -> SecretVec<u8> {
fn from_row(row: &SqliteRow) -> Result<Self, sqlx::Error> { SecretVec::new(self.pwhash.as_bytes().to_vec())
let bytes: Vec<u8> = row.get("id");
let bytes = bytes.as_slice();
let bytes: [u8; 16] = bytes.try_into().unwrap();
let id = Uuid::from_bytes_le(bytes);
let username: String = row.get("username");
let displayname: Option<String> = row.get("displayname");
let last_seen: Option<i64> = row.get("last_seen");
let email: Option<String> = row.get("email");
Ok(Self {
id,
username,
displayname,
email,
last_seen,
})
} }
} }
/// Get Handler: displays the form to create a user impl User {
pub async fn get_create_user() -> CreateUser { pub async fn get(username: &str, db: &SqlitePool) -> Result<User, impl std::error::Error> {
CreateUser::default() sqlx::query_as(USERNAME_QUERY)
} .bind(username)
.fetch_one(db)
/// Post Handler: validates form values and calls the actual, private user
/// creation function
#[axum::debug_handler]
pub async fn post_create_user(
State(pool): State<SqlitePool>,
Form(signup): Form<CreateUser>,
) -> Result<Response, CreateUserError> {
let username = &signup.username;
let displayname = &signup.displayname;
let email = &signup.email;
let password = &signup.password;
let verify = &signup.pw_verify;
let username = username.trim();
let name_len = username.graphemes(true).size_hint().1.unwrap();
// we are not ascii exclusivists around here
if !(1..=20).contains(&name_len) {
return Err(CreateUserErrorKind::BadUsername.into());
}
if let Some(ref dn) = displayname {
if dn.len() > 50 {
return Err(CreateUserErrorKind::BadDisplayname.into());
}
}
if password != verify {
return Err(CreateUserErrorKind::PasswordMismatch.into());
}
let password = urlencoding::decode(password)
.map_err(|_| CreateUserErrorKind::BadPassword)?
.to_string();
let password = password.as_bytes();
let displayname = if let Some(dn) = displayname {
let dn = urlencoding::decode(dn)
.map_err(|_| CreateUserErrorKind::BadDisplayname)?
.to_string();
Some(dn)
} else {
None
};
let displayname = &displayname;
// TODO(2023-05-17): validate email
let email = if let Some(email) = email {
let email = urlencoding::decode(email)
.map_err(|_| CreateUserErrorKind::BadEmail)?
.to_string();
Some(email)
} else {
None
};
let email = &email;
let user = create_user(username, displayname, email, password, &pool).await?;
tracing::debug!("created {user:?}");
let id = user.id.simple().to_string();
let location = format!("/signup_success/{id}");
let resp = axum::response::Redirect::temporary(&location).into_response();
Ok(resp)
}
/// Get handler for successful signup
pub async fn handle_signup_success(
Path(id): Path<String>,
State(pool): State<SqlitePool>,
) -> Response {
let user: User = {
let id = id;
let id = Uuid::try_parse(&id).unwrap_or_default();
let id_bytes = id.to_bytes_le();
sqlx::query_as(ID_QUERY)
.bind(id_bytes.as_slice())
.fetch_one(&pool)
.await .await
.unwrap_or_default()
};
let mut resp = CreateUserSuccess(user.clone()).into_response();
if user.username.is_empty() {
// redirect to front page if we got here without a valid witch header
*resp.status_mut() = StatusCode::TEMPORARY_REDIRECT;
resp.headers_mut().insert("Location", "/".parse().unwrap());
}
resp
}
async fn create_user(
username: &str,
displayname: &Option<String>,
email: &Option<String>,
password: &[u8],
pool: &SqlitePool,
) -> Result<User, CreateUserError> {
// Argon2 with default params (Argon2id v19)
let argon2 = Argon2::default();
let salt = SaltString::generate(&mut OsRng);
let pwhash = argon2
.hash_password(password, &salt)
.unwrap() // safe to unwrap, we know the salt is valid
.to_string();
let id = Uuid::new_v4();
let id_bytes = id.to_bytes_le();
let id_bytes = id_bytes.as_slice();
let res = sqlx::query(CREATE_QUERY)
.bind(id_bytes)
.bind(username)
.bind(displayname)
.bind(email)
.bind(pwhash)
.execute(pool)
.await;
match res {
Ok(_) => {
let user = User {
id,
username: username.to_string(),
displayname: displayname.to_owned(),
email: email.to_owned(),
last_seen: None,
};
Ok(user)
}
Err(sqlx::Error::Database(db)) => {
if let Some(exit) = db.code() {
let exit = exit.parse().unwrap_or(0u32);
// https://www.sqlite.org/rescode.html codes for unique constraint violations:
if exit == 2067u32 || exit == 1555 {
Err(CreateUserErrorKind::AlreadyExists.into())
} else {
Err(CreateUserErrorKind::UnknownDBError.into())
}
} else {
Err(CreateUserErrorKind::UnknownDBError.into())
}
}
_ => Err(CreateUserErrorKind::UnknownDBError.into()),
} }
} }
#[Error(desc = "Could not create user.")]
#[non_exhaustive]
pub struct CreateUserError(#[from] CreateUserErrorKind);
impl IntoResponse for CreateUserError {
fn into_response(self) -> askama_axum::Response {
match self.0 {
CreateUserErrorKind::UnknownDBError => {
(StatusCode::INTERNAL_SERVER_ERROR, format!("{self}")).into_response()
}
_ => (StatusCode::BAD_REQUEST, format!("{self}")).into_response(),
}
}
}
#[Error]
#[non_exhaustive]
pub enum CreateUserErrorKind {
AlreadyExists,
#[error(desc = "Usernames must be between 1 and 20 non-whitespace characters long")]
BadUsername,
PasswordMismatch,
BadPassword,
BadDisplayname,
BadEmail,
MissingFields,
UnknownDBError,
}

3
src/util.rs Normal file
View File

@ -0,0 +1,3 @@
pub fn form_decode<E: std::error::Error>(input: &str, err: E) -> Result<String, E> {
Ok(urlencoding::decode(input).map_err(|_| err)?.into_owned())
}

17
templates/login_get.html Normal file
View File

@ -0,0 +1,17 @@
{% extends "base.html" %}
{% block title %}Login to Witch Watch, Bish{% endblock %}
{% block content %}
<p>
<form action="/login" enctype="application/x-www-form-urlencoded" method="post">
<label for="username">Username</label>
<input type="text" name="username" id="username" minlength="1" maxlength="20" required></br>
<label for="password">Password</label>
<input type="password" name="password" id="password" required></br>
<input type="submit" value="Signup">
</form>
</p>
{% endblock %}

View File