Merge branch 'login'
This commit is contained in:
commit
113982ba27
File diff suppressed because it is too large
Load Diff
24
Cargo.toml
24
Cargo.toml
|
@ -4,23 +4,25 @@ version = "0.0.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
axum = { version = "0.6", features = ["macros", "tracing"] }
|
axum = { version = "0.6", features = ["macros", "headers"] }
|
||||||
askama = { version = "0.12", features = ["with-axum"] }
|
askama = { version = "0.12", features = ["with-axum"] }
|
||||||
askama_axum = "0.3"
|
askama_axum = "0.3"
|
||||||
axum-macros = "0.3"
|
axum-macros = "0.3"
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full", "tracing"], default-features = false }
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
tower = { version = "0.4", features = ["util", "timeout"] }
|
tower = { version = "0.4", features = ["util", "timeout"], default-features = false }
|
||||||
tower-http = { version = "0.4", features = ["add-extension", "trace"] }
|
tower-http = { version = "0.4", features = ["add-extension", "trace"] }
|
||||||
uuid = { version = "1.3", features = ["serde", "v4"] }
|
uuid = { version = "1", features = ["serde", "v4"] }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
sqlx = { version = "0.5.10", features = ["runtime-tokio-rustls", "any", "sqlite", "chrono", "time", "uuid"] }
|
sqlx = { version = "0.6", default-features = false, features = ["runtime-tokio-rustls", "any", "sqlite", "chrono", "time", "uuid"] }
|
||||||
argon2 = "0.5"
|
argon2 = "0.5"
|
||||||
rand_core = { version = "0.6", features = ["getrandom"] }
|
rand_core = { version = "0.6", features = ["getrandom"] }
|
||||||
thiserror = "1.0.40"
|
thiserror = "1"
|
||||||
justerror = "1.1.0"
|
justerror = "1"
|
||||||
password-hash = { version = "0.5.0", features = ["std", "getrandom"] }
|
password-hash = { version = "0.5", features = ["std", "getrandom"] }
|
||||||
axum-login = { version = "0.5.0", features = ["sqlite", "sqlx"] }
|
axum-login = { version = "0.5", features = ["sqlite", "sqlx"] }
|
||||||
unicode-segmentation = "1.10.1"
|
unicode-segmentation = "1"
|
||||||
urlencoding = "2.1.2"
|
urlencoding = "2"
|
||||||
|
async-session = "3"
|
||||||
|
|
||||||
|
|
|
@ -20,12 +20,12 @@ pub async fn get_pool() -> SqlitePool {
|
||||||
let conn_opts = SqliteConnectOptions::new()
|
let conn_opts = SqliteConnectOptions::new()
|
||||||
.foreign_keys(true)
|
.foreign_keys(true)
|
||||||
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
|
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
|
||||||
.filename(&db_filename);
|
.filename(&db_filename)
|
||||||
|
.busy_timeout(Duration::from_secs(TIMEOUT));
|
||||||
|
|
||||||
// setup connection pool
|
// setup connection pool
|
||||||
SqlitePoolOptions::new()
|
SqlitePoolOptions::new()
|
||||||
.max_connections(MAX_CONNS)
|
.max_connections(MAX_CONNS)
|
||||||
.connect_timeout(Duration::from_secs(TIMEOUT))
|
|
||||||
.connect_with(conn_opts)
|
.connect_with(conn_opts)
|
||||||
.await
|
.await
|
||||||
.expect("can't connect to database")
|
.expect("can't connect to database")
|
||||||
|
|
54
src/form.rs
54
src/form.rs
|
@ -1,54 +0,0 @@
|
||||||
use axum::{extract::Form, response::Html};
|
|
||||||
use serde::Deserialize;
|
|
||||||
|
|
||||||
pub(crate) async fn show_form() -> Html<&'static str> {
|
|
||||||
Html(
|
|
||||||
r#"
|
|
||||||
<!doctype html>
|
|
||||||
<html>
|
|
||||||
<head></head>
|
|
||||||
<body>
|
|
||||||
<form action="/" method="post">
|
|
||||||
<label for="name">
|
|
||||||
Enter your name:
|
|
||||||
<input type="text" name="name">
|
|
||||||
</label>
|
|
||||||
|
|
||||||
<label>
|
|
||||||
Enter your email:
|
|
||||||
<input type="text" name="email">
|
|
||||||
</label>
|
|
||||||
|
|
||||||
<input type="submit" value="Subscribe!">
|
|
||||||
</form>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub(crate) struct Input {
|
|
||||||
name: String,
|
|
||||||
email: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn accept_form(Form(input): Form<Input>) -> Html<String> {
|
|
||||||
let Input { name, email: _ } = input;
|
|
||||||
let html = format!(
|
|
||||||
r#"
|
|
||||||
<!doctype html>
|
|
||||||
<html>
|
|
||||||
<head></head>
|
|
||||||
<body>
|
|
||||||
<p>Hi, {}</p>
|
|
||||||
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
"#,
|
|
||||||
name
|
|
||||||
);
|
|
||||||
|
|
||||||
Html(html)
|
|
||||||
}
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
use axum::response::{IntoResponse, Redirect};
|
||||||
|
|
||||||
|
use crate::AuthContext;
|
||||||
|
|
||||||
|
pub async fn handle_slash_redir() -> impl IntoResponse {
|
||||||
|
Redirect::temporary("/")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse {
|
||||||
|
if let Some(user) = auth.current_user {
|
||||||
|
tracing::debug!("Logged in as: {user}");
|
||||||
|
} else {
|
||||||
|
tracing::debug!("Not logged in.")
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,55 +0,0 @@
|
||||||
use axum::{
|
|
||||||
async_trait,
|
|
||||||
extract::{FromRef, FromRequestParts, State},
|
|
||||||
http::{request::Parts, StatusCode},
|
|
||||||
};
|
|
||||||
use sqlx::SqlitePool;
|
|
||||||
|
|
||||||
pub async fn using_connection_pool_extractor(
|
|
||||||
State(pool): State<SqlitePool>,
|
|
||||||
) -> Result<String, (StatusCode, String)> {
|
|
||||||
sqlx::query_scalar("select 'hello world from sqlite get'")
|
|
||||||
.fetch_one(&pool)
|
|
||||||
.await
|
|
||||||
.map_err(internal_error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// we can also write a custom extractor that grabs a connection from the pool
|
|
||||||
// which setup is appropriate depends on your application
|
|
||||||
pub struct DatabaseConnection(sqlx::pool::PoolConnection<sqlx::Sqlite>);
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl<S> FromRequestParts<S> for DatabaseConnection
|
|
||||||
where
|
|
||||||
SqlitePool: FromRef<S>,
|
|
||||||
S: Send + Sync,
|
|
||||||
{
|
|
||||||
type Rejection = (StatusCode, String);
|
|
||||||
|
|
||||||
async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
|
||||||
let pool = SqlitePool::from_ref(state);
|
|
||||||
|
|
||||||
let conn = pool.acquire().await.map_err(internal_error)?;
|
|
||||||
|
|
||||||
Ok(Self(conn))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn using_connection_extractor(
|
|
||||||
DatabaseConnection(conn): DatabaseConnection,
|
|
||||||
) -> Result<String, (StatusCode, String)> {
|
|
||||||
let mut conn = conn;
|
|
||||||
sqlx::query_scalar("select 'hello world from sqlite post'")
|
|
||||||
.fetch_one(&mut conn)
|
|
||||||
.await
|
|
||||||
.map_err(internal_error)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Utility function for mapping any error into a `500 Internal Server Error`
|
|
||||||
/// response.
|
|
||||||
fn internal_error<E>(err: E) -> (StatusCode, String)
|
|
||||||
where
|
|
||||||
E: std::error::Error,
|
|
||||||
{
|
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
|
|
||||||
}
|
|
12
src/lib.rs
12
src/lib.rs
|
@ -1,7 +1,17 @@
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate justerror;
|
extern crate justerror;
|
||||||
|
|
||||||
|
use axum_login::SqliteStore;
|
||||||
|
pub use users::User;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
pub mod db;
|
pub mod db;
|
||||||
pub mod handlers;
|
pub mod generic_handlers;
|
||||||
|
pub mod login;
|
||||||
|
pub mod session_store;
|
||||||
|
pub mod signup;
|
||||||
pub(crate) mod templates;
|
pub(crate) mod templates;
|
||||||
pub mod users;
|
pub mod users;
|
||||||
|
pub(crate) mod util;
|
||||||
|
|
||||||
|
pub type AuthContext = axum_login::extractors::AuthContext<Uuid, User, SqliteStore<User>>;
|
||||||
|
|
|
@ -0,0 +1,107 @@
|
||||||
|
use argon2::{
|
||||||
|
password_hash::{PasswordHash, PasswordVerifier},
|
||||||
|
Argon2,
|
||||||
|
};
|
||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
http::StatusCode,
|
||||||
|
response::{IntoResponse, Redirect, Response},
|
||||||
|
Form,
|
||||||
|
};
|
||||||
|
use sqlx::SqlitePool;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
templates::{LoginGet, LoginPost},
|
||||||
|
util::form_decode,
|
||||||
|
AuthContext, User,
|
||||||
|
};
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// Constants
|
||||||
|
//-************************************************************************
|
||||||
|
|
||||||
|
const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1";
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// Login error and success types
|
||||||
|
//-************************************************************************
|
||||||
|
|
||||||
|
#[Error]
|
||||||
|
pub struct LoginError(#[from] LoginErrorKind);
|
||||||
|
|
||||||
|
#[Error]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub enum LoginErrorKind {
|
||||||
|
Internal,
|
||||||
|
BadPassword,
|
||||||
|
BadUsername,
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for LoginError {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
match self.0 {
|
||||||
|
LoginErrorKind::Unknown => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"An unknown error occurred; you cursed, brah?",
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
_ => (StatusCode::BAD_REQUEST, format!("{self}")).into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// Login handlers
|
||||||
|
//-************************************************************************
|
||||||
|
|
||||||
|
/// Handle login queries
|
||||||
|
#[axum::debug_handler]
|
||||||
|
pub async fn post_login(
|
||||||
|
mut auth: AuthContext,
|
||||||
|
State(pool): State<SqlitePool>,
|
||||||
|
Form(login): Form<LoginPost>,
|
||||||
|
) -> Result<impl IntoResponse, LoginError> {
|
||||||
|
let username = form_decode(&login.username, LoginErrorKind::BadUsername)?;
|
||||||
|
let username = username.trim();
|
||||||
|
|
||||||
|
let pw = form_decode(&login.password, LoginErrorKind::BadPassword)?;
|
||||||
|
let pw = pw.trim();
|
||||||
|
|
||||||
|
let user = User::get(username, &pool)
|
||||||
|
.await
|
||||||
|
.map_err(|_| LoginErrorKind::Unknown)?;
|
||||||
|
|
||||||
|
let verifier = Argon2::default();
|
||||||
|
let hash = PasswordHash::new(&user.pwhash).map_err(|_| LoginErrorKind::Internal)?;
|
||||||
|
match verifier.verify_password(pw.as_bytes(), &hash) {
|
||||||
|
Ok(_) => {
|
||||||
|
// log them in and set a session cookie
|
||||||
|
auth.login(&user)
|
||||||
|
.await
|
||||||
|
.map_err(|_| LoginErrorKind::Internal)?;
|
||||||
|
|
||||||
|
// update last_seen; maybe this is ok to fail?
|
||||||
|
sqlx::query(LAST_SEEN_QUERY)
|
||||||
|
.bind(user.id)
|
||||||
|
.execute(&pool)
|
||||||
|
.await
|
||||||
|
.map_err(|_| LoginErrorKind::Internal)?;
|
||||||
|
|
||||||
|
Ok(Redirect::temporary("/"))
|
||||||
|
}
|
||||||
|
_ => Err(LoginErrorKind::BadPassword.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_login() -> impl IntoResponse {
|
||||||
|
LoginGet::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_logout() -> impl IntoResponse {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn post_logout() -> impl IntoResponse {
|
||||||
|
todo!()
|
||||||
|
}
|
34
src/main.rs
34
src/main.rs
|
@ -1,10 +1,16 @@
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
use axum::{routing::get, Router};
|
use axum::{routing::get, Router};
|
||||||
|
use axum_login::{axum_sessions::SessionLayer, AuthLayer, SqliteStore};
|
||||||
|
use rand_core::{OsRng, RngCore};
|
||||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||||
use witch_watch::{
|
use witch_watch::{
|
||||||
db,
|
db,
|
||||||
users::{get_create_user, handle_signup_success, post_create_user},
|
generic_handlers::{handle_slash, handle_slash_redir},
|
||||||
|
login::{get_login, get_logout, post_login, post_logout},
|
||||||
|
session_store::SqliteSessionStore,
|
||||||
|
signup::{get_create_user, handle_signup_success, post_create_user},
|
||||||
|
User,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
@ -19,13 +25,37 @@ async fn main() {
|
||||||
|
|
||||||
let pool = db::get_pool().await;
|
let pool = db::get_pool().await;
|
||||||
|
|
||||||
// build our application with some routes
|
let secret = {
|
||||||
|
let mut bytes = [0u8; 128];
|
||||||
|
let mut rng = OsRng;
|
||||||
|
rng.fill_bytes(&mut bytes);
|
||||||
|
bytes
|
||||||
|
};
|
||||||
|
|
||||||
|
let session_layer = {
|
||||||
|
let store = SqliteSessionStore::from_client(pool.clone());
|
||||||
|
store.migrate().await.expect("Could not migrate session DB");
|
||||||
|
SessionLayer::new(store, &secret).with_secure(true)
|
||||||
|
};
|
||||||
|
|
||||||
|
let auth_layer = {
|
||||||
|
const QUERY: &str = "select * from witches where id = $1";
|
||||||
|
let store = SqliteStore::<User>::new(pool.clone()).with_query(QUERY);
|
||||||
|
AuthLayer::new(store, &secret)
|
||||||
|
};
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
|
.route("/", get(handle_slash).post(handle_slash))
|
||||||
.route("/signup", get(get_create_user).post(post_create_user))
|
.route("/signup", get(get_create_user).post(post_create_user))
|
||||||
.route(
|
.route(
|
||||||
"/signup_success/:id",
|
"/signup_success/:id",
|
||||||
get(handle_signup_success).post(handle_signup_success),
|
get(handle_signup_success).post(handle_signup_success),
|
||||||
)
|
)
|
||||||
|
.route("/login", get(get_login).post(post_login))
|
||||||
|
.route("/logout", get(get_logout).post(post_logout))
|
||||||
|
.fallback(handle_slash_redir)
|
||||||
|
.layer(auth_layer)
|
||||||
|
.layer(session_layer)
|
||||||
.with_state(pool);
|
.with_state(pool);
|
||||||
|
|
||||||
tracing::debug!("binding to 0.0.0.0:3000");
|
tracing::debug!("binding to 0.0.0.0:3000");
|
||||||
|
|
|
@ -0,0 +1,507 @@
|
||||||
|
use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore};
|
||||||
|
use sqlx::{pool::PoolConnection, sqlite::SqlitePool, Sqlite};
|
||||||
|
|
||||||
|
// NOTE! This code was straight stolen from
|
||||||
|
// https://github.com/jbr/async-sqlx-session/blob/30d00bed44ab2034082698f098eba48b21600f36/src/sqlite.rs
|
||||||
|
// and used under the terms of the MIT license:
|
||||||
|
|
||||||
|
/*
|
||||||
|
Copyright 2022 Jacob Rothstein
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||||
|
associated documentation files (the “Software”), to deal in the Software without restriction,
|
||||||
|
including without limitation the rights to use, copy, modify, merge, publish, distribute,
|
||||||
|
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all copies or substantial
|
||||||
|
portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
|
||||||
|
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
|
||||||
|
OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||||
|
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/// sqlx sqlite session store for async-sessions
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use witch_watch::session_store::SqliteSessionStore;
|
||||||
|
/// use async_session::{Session, SessionStore, Result};
|
||||||
|
/// use std::time::Duration;
|
||||||
|
///
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() -> Result {
|
||||||
|
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
|
||||||
|
/// store.migrate().await?;
|
||||||
|
///
|
||||||
|
/// let mut session = Session::new();
|
||||||
|
/// session.insert("key", vec![1,2,3]);
|
||||||
|
///
|
||||||
|
/// let cookie_value = store.store_session(session).await?.unwrap();
|
||||||
|
/// let session = store.load_session(cookie_value).await?.unwrap();
|
||||||
|
/// assert_eq!(session.get::<Vec<i8>>("key").unwrap(), vec![1,2,3]);
|
||||||
|
/// # Ok(()) }
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct SqliteSessionStore {
|
||||||
|
client: SqlitePool,
|
||||||
|
table_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SqliteSessionStore {
|
||||||
|
/// constructs a new SqliteSessionStore from an existing
|
||||||
|
/// sqlx::SqlitePool. the default table name for this session
|
||||||
|
/// store will be "async_sessions". To override this, chain this
|
||||||
|
/// with [`with_table_name`](crate::SqliteSessionStore::with_table_name).
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use witch_watch::session_store::SqliteSessionStore;
|
||||||
|
/// # use async_session::Result;
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() -> Result {
|
||||||
|
/// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
|
||||||
|
/// let store = SqliteSessionStore::from_client(pool)
|
||||||
|
/// .with_table_name("custom_table_name");
|
||||||
|
/// store.migrate().await;
|
||||||
|
/// # Ok(()) }
|
||||||
|
/// ```
|
||||||
|
pub fn from_client(client: SqlitePool) -> Self {
|
||||||
|
Self {
|
||||||
|
client,
|
||||||
|
table_name: "async_sessions".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Constructs a new SqliteSessionStore from a sqlite: database url. note
|
||||||
|
/// that this documentation uses the special `:memory:` sqlite
|
||||||
|
/// database for convenient testing, but a real application would
|
||||||
|
/// use a path like `sqlite:///path/to/database.db`. The default
|
||||||
|
/// table name for this session store will be "async_sessions". To
|
||||||
|
/// override this, either chain with
|
||||||
|
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
|
||||||
|
/// use
|
||||||
|
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use witch_watch::session_store::SqliteSessionStore;
|
||||||
|
/// # use async_session::Result;
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() -> Result {
|
||||||
|
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
|
||||||
|
/// store.migrate().await;
|
||||||
|
/// # Ok(()) }
|
||||||
|
/// ```
|
||||||
|
pub async fn new(database_url: &str) -> sqlx::Result<Self> {
|
||||||
|
Ok(Self::from_client(SqlitePool::connect(database_url).await?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// constructs a new SqliteSessionStore from a sqlite: database url. the
|
||||||
|
/// default table name for this session store will be
|
||||||
|
/// "async_sessions". To override this, either chain with
|
||||||
|
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
|
||||||
|
/// use
|
||||||
|
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use witch_watch::session_store::SqliteSessionStore;
|
||||||
|
/// # use async_session::Result;
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() -> Result {
|
||||||
|
/// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?;
|
||||||
|
/// store.migrate().await;
|
||||||
|
/// # Ok(()) }
|
||||||
|
/// ```
|
||||||
|
pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result<Self> {
|
||||||
|
Ok(Self::new(database_url).await?.with_table_name(table_name))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Chainable method to add a custom table name. This will panic
|
||||||
|
/// if the table name is not `[a-zA-Z0-9_-]+`.
|
||||||
|
/// ```rust
|
||||||
|
/// # use witch_watch::session_store::SqliteSessionStore;
|
||||||
|
/// # use async_session::Result;
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() -> Result {
|
||||||
|
/// let store = SqliteSessionStore::new("sqlite::memory:").await?
|
||||||
|
/// .with_table_name("custom_name");
|
||||||
|
/// store.migrate().await;
|
||||||
|
/// # Ok(()) }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// ```should_panic
|
||||||
|
/// # use witch_watch::session_store::SqliteSessionStore;
|
||||||
|
/// # use async_session::Result;
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() -> Result {
|
||||||
|
/// let store = SqliteSessionStore::new("sqlite::memory:").await?
|
||||||
|
/// .with_table_name("johnny (); drop users;");
|
||||||
|
/// # Ok(()) }
|
||||||
|
/// ```
|
||||||
|
pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
|
||||||
|
let table_name = table_name.as_ref();
|
||||||
|
if table_name.is_empty()
|
||||||
|
|| !table_name
|
||||||
|
.chars()
|
||||||
|
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
|
||||||
|
{
|
||||||
|
panic!(
|
||||||
|
"table name must be [a-zA-Z0-9_-]+, but {} was not",
|
||||||
|
table_name
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.table_name = table_name.to_owned();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a session table if it does not already exist. If it
|
||||||
|
/// does, this will noop, making it safe to call repeatedly on
|
||||||
|
/// store initialization. In the future, this may make
|
||||||
|
/// exactly-once modifications to the schema of the session table
|
||||||
|
/// on breaking releases.
|
||||||
|
/// ```rust
|
||||||
|
/// # use witch_watch::session_store::SqliteSessionStore;
|
||||||
|
/// # use async_session::{Result, SessionStore, Session};
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() -> Result {
|
||||||
|
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
|
||||||
|
/// assert!(store.count().await.is_err());
|
||||||
|
/// store.migrate().await?;
|
||||||
|
/// store.store_session(Session::new()).await?;
|
||||||
|
/// store.migrate().await?; // calling it a second time is safe
|
||||||
|
/// assert_eq!(store.count().await?, 1);
|
||||||
|
/// # Ok(()) }
|
||||||
|
/// ```
|
||||||
|
pub async fn migrate(&self) -> sqlx::Result<()> {
|
||||||
|
log::info!("migrating sessions on `{}`", self.table_name);
|
||||||
|
|
||||||
|
let mut conn = self.client.acquire().await?;
|
||||||
|
sqlx::query(&self.substitute_table_name(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
|
||||||
|
id TEXT PRIMARY KEY NOT NULL,
|
||||||
|
expires INTEGER NULL,
|
||||||
|
session TEXT NOT NULL
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
))
|
||||||
|
.execute(&mut conn)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// private utility function because sqlite does not support
|
||||||
|
// parametrized table names
|
||||||
|
fn substitute_table_name(&self, query: &str) -> String {
|
||||||
|
query.replace("%%TABLE_NAME%%", &self.table_name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// retrieve a connection from the pool
|
||||||
|
async fn connection(&self) -> sqlx::Result<PoolConnection<Sqlite>> {
|
||||||
|
self.client.acquire().await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Performs a one-time cleanup task that clears out stale
|
||||||
|
/// (expired) sessions. You may want to call this from cron.
|
||||||
|
/// ```rust
|
||||||
|
/// # use witch_watch::session_store::SqliteSessionStore;
|
||||||
|
/// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session};
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() -> Result {
|
||||||
|
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
|
||||||
|
/// store.migrate().await?;
|
||||||
|
/// let mut session = Session::new();
|
||||||
|
/// session.set_expiry(Utc::now() - Duration::seconds(5));
|
||||||
|
/// store.store_session(session).await?;
|
||||||
|
/// assert_eq!(store.count().await?, 1);
|
||||||
|
/// store.cleanup().await?;
|
||||||
|
/// assert_eq!(store.count().await?, 0);
|
||||||
|
/// # Ok(()) }
|
||||||
|
/// ```
|
||||||
|
pub async fn cleanup(&self) -> sqlx::Result<()> {
|
||||||
|
let mut connection = self.connection().await?;
|
||||||
|
sqlx::query(&self.substitute_table_name(
|
||||||
|
r#"
|
||||||
|
DELETE FROM %%TABLE_NAME%%
|
||||||
|
WHERE expires < ?
|
||||||
|
"#,
|
||||||
|
))
|
||||||
|
.bind(Utc::now().timestamp())
|
||||||
|
.execute(&mut connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// retrieves the number of sessions currently stored, including
|
||||||
|
/// expired sessions
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use witch_watch::session_store::SqliteSessionStore;
|
||||||
|
/// # use async_session::{Result, SessionStore, Session};
|
||||||
|
/// # use std::time::Duration;
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() -> Result {
|
||||||
|
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
|
||||||
|
/// store.migrate().await?;
|
||||||
|
/// assert_eq!(store.count().await?, 0);
|
||||||
|
/// store.store_session(Session::new()).await?;
|
||||||
|
/// assert_eq!(store.count().await?, 1);
|
||||||
|
/// # Ok(()) }
|
||||||
|
/// ```
|
||||||
|
pub async fn count(&self) -> sqlx::Result<i32> {
|
||||||
|
let (count,) =
|
||||||
|
sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
|
||||||
|
.fetch_one(&mut self.connection().await?)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl SessionStore for SqliteSessionStore {
|
||||||
|
async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
|
||||||
|
let id = Session::id_from_cookie_value(&cookie_value)?;
|
||||||
|
let mut connection = self.connection().await?;
|
||||||
|
|
||||||
|
let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name(
|
||||||
|
r#"
|
||||||
|
SELECT session FROM %%TABLE_NAME%%
|
||||||
|
WHERE id = ? AND (expires IS NULL OR expires > ?)
|
||||||
|
"#,
|
||||||
|
))
|
||||||
|
.bind(&id)
|
||||||
|
.bind(Utc::now().timestamp())
|
||||||
|
.fetch_optional(&mut connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(result
|
||||||
|
.map(|(session,)| serde_json::from_str(&session))
|
||||||
|
.transpose()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn store_session(&self, session: Session) -> Result<Option<String>> {
|
||||||
|
let id = session.id();
|
||||||
|
let string = serde_json::to_string(&session)?;
|
||||||
|
let mut connection = self.connection().await?;
|
||||||
|
|
||||||
|
sqlx::query(&self.substitute_table_name(
|
||||||
|
r#"
|
||||||
|
INSERT INTO %%TABLE_NAME%%
|
||||||
|
(id, session, expires) VALUES (?, ?, ?)
|
||||||
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
|
expires = excluded.expires,
|
||||||
|
session = excluded.session
|
||||||
|
"#,
|
||||||
|
))
|
||||||
|
.bind(id)
|
||||||
|
.bind(&string)
|
||||||
|
.bind(session.expiry().map(|expiry| expiry.timestamp()))
|
||||||
|
.execute(&mut connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(session.into_cookie_value())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn destroy_session(&self, session: Session) -> Result {
|
||||||
|
let id = session.id();
|
||||||
|
let mut connection = self.connection().await?;
|
||||||
|
sqlx::query(&self.substitute_table_name(
|
||||||
|
r#"
|
||||||
|
DELETE FROM %%TABLE_NAME%% WHERE id = ?
|
||||||
|
"#,
|
||||||
|
))
|
||||||
|
.bind(id)
|
||||||
|
.execute(&mut connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn clear_store(&self) -> Result {
|
||||||
|
let mut connection = self.connection().await?;
|
||||||
|
sqlx::query(&self.substitute_table_name(
|
||||||
|
r#"
|
||||||
|
DELETE FROM %%TABLE_NAME%%
|
||||||
|
"#,
|
||||||
|
))
|
||||||
|
.execute(&mut connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
async fn test_store() -> SqliteSessionStore {
|
||||||
|
let store = SqliteSessionStore::new("sqlite::memory:")
|
||||||
|
.await
|
||||||
|
.expect("building a sqlite :memory: SqliteSessionStore");
|
||||||
|
store
|
||||||
|
.migrate()
|
||||||
|
.await
|
||||||
|
.expect("migrating a brand new :memory: SqliteSessionStore");
|
||||||
|
store
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn creating_a_new_session_with_no_expiry() -> Result {
|
||||||
|
let store = test_store().await;
|
||||||
|
let mut session = Session::new();
|
||||||
|
session.insert("key", "value")?;
|
||||||
|
let cloned = session.clone();
|
||||||
|
let cookie_value = store.store_session(session).await?.unwrap();
|
||||||
|
|
||||||
|
let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
|
||||||
|
sqlx::query_as("select id, expires, session, count(*) from async_sessions")
|
||||||
|
.fetch_one(&mut store.connection().await?)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert_eq!(1, count);
|
||||||
|
assert_eq!(id, cloned.id());
|
||||||
|
assert_eq!(expires, None);
|
||||||
|
|
||||||
|
let deserialized_session: Session = serde_json::from_str(&serialized)?;
|
||||||
|
assert_eq!(cloned.id(), deserialized_session.id());
|
||||||
|
assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
|
||||||
|
|
||||||
|
let loaded_session = store.load_session(cookie_value).await?.unwrap();
|
||||||
|
assert_eq!(cloned.id(), loaded_session.id());
|
||||||
|
assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
|
||||||
|
|
||||||
|
assert!(!loaded_session.is_expired());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn updating_a_session() -> Result {
|
||||||
|
let store = test_store().await;
|
||||||
|
let mut session = Session::new();
|
||||||
|
let original_id = session.id().to_owned();
|
||||||
|
|
||||||
|
session.insert("key", "value")?;
|
||||||
|
let cookie_value = store.store_session(session).await?.unwrap();
|
||||||
|
|
||||||
|
let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
|
||||||
|
session.insert("key", "other value")?;
|
||||||
|
assert_eq!(None, store.store_session(session).await?);
|
||||||
|
|
||||||
|
let session = store.load_session(cookie_value.clone()).await?.unwrap();
|
||||||
|
assert_eq!(session.get::<String>("key").unwrap(), "other value");
|
||||||
|
|
||||||
|
let (id, count): (String, i64) = sqlx::query_as("select id, count(*) from async_sessions")
|
||||||
|
.fetch_one(&mut store.connection().await?)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert_eq!(1, count);
|
||||||
|
assert_eq!(original_id, id);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn updating_a_session_extending_expiry() -> Result {
|
||||||
|
let store = test_store().await;
|
||||||
|
let mut session = Session::new();
|
||||||
|
session.expire_in(Duration::from_secs(10));
|
||||||
|
let original_id = session.id().to_owned();
|
||||||
|
let original_expires = session.expiry().unwrap().clone();
|
||||||
|
let cookie_value = store.store_session(session).await?.unwrap();
|
||||||
|
|
||||||
|
let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
|
||||||
|
assert_eq!(session.expiry().unwrap(), &original_expires);
|
||||||
|
session.expire_in(Duration::from_secs(20));
|
||||||
|
let new_expires = session.expiry().unwrap().clone();
|
||||||
|
store.store_session(session).await?;
|
||||||
|
|
||||||
|
let session = store.load_session(cookie_value.clone()).await?.unwrap();
|
||||||
|
assert_eq!(session.expiry().unwrap(), &new_expires);
|
||||||
|
|
||||||
|
let (id, expires, count): (String, i64, i64) =
|
||||||
|
sqlx::query_as("select id, expires, count(*) from async_sessions")
|
||||||
|
.fetch_one(&mut store.connection().await?)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert_eq!(1, count);
|
||||||
|
assert_eq!(expires, new_expires.timestamp());
|
||||||
|
assert_eq!(original_id, id);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn creating_a_new_session_with_expiry() -> Result {
|
||||||
|
let store = test_store().await;
|
||||||
|
let mut session = Session::new();
|
||||||
|
session.expire_in(Duration::from_secs(1));
|
||||||
|
session.insert("key", "value")?;
|
||||||
|
let cloned = session.clone();
|
||||||
|
|
||||||
|
let cookie_value = store.store_session(session).await?.unwrap();
|
||||||
|
|
||||||
|
let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
|
||||||
|
sqlx::query_as("select id, expires, session, count(*) from async_sessions")
|
||||||
|
.fetch_one(&mut store.connection().await?)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert_eq!(1, count);
|
||||||
|
assert_eq!(id, cloned.id());
|
||||||
|
assert!(expires.unwrap() > Utc::now().timestamp());
|
||||||
|
|
||||||
|
let deserialized_session: Session = serde_json::from_str(&serialized)?;
|
||||||
|
assert_eq!(cloned.id(), deserialized_session.id());
|
||||||
|
assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
|
||||||
|
|
||||||
|
let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap();
|
||||||
|
assert_eq!(cloned.id(), loaded_session.id());
|
||||||
|
assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
|
||||||
|
|
||||||
|
assert!(!loaded_session.is_expired());
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||||
|
assert_eq!(None, store.load_session(cookie_value).await?);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn destroying_a_single_session() -> Result {
|
||||||
|
let store = test_store().await;
|
||||||
|
for _ in 0..3i8 {
|
||||||
|
store.store_session(Session::new()).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let cookie = store.store_session(Session::new()).await?.unwrap();
|
||||||
|
assert_eq!(4, store.count().await?);
|
||||||
|
let session = store.load_session(cookie.clone()).await?.unwrap();
|
||||||
|
store.destroy_session(session.clone()).await.unwrap();
|
||||||
|
assert_eq!(None, store.load_session(cookie).await?);
|
||||||
|
assert_eq!(3, store.count().await?);
|
||||||
|
|
||||||
|
// // attempting to destroy the session again is not an error
|
||||||
|
// assert!(store.destroy_session(session).await.is_ok());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn clearing_the_whole_store() -> Result {
|
||||||
|
let store = test_store().await;
|
||||||
|
for _ in 0..3i8 {
|
||||||
|
store.store_session(Session::new()).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(3, store.count().await?);
|
||||||
|
store.clear_store().await.unwrap();
|
||||||
|
assert_eq!(0, store.count().await?);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,220 @@
|
||||||
|
use argon2::{
|
||||||
|
password_hash::{rand_core::OsRng, PasswordHasher, SaltString},
|
||||||
|
Argon2,
|
||||||
|
};
|
||||||
|
use askama::Template;
|
||||||
|
use axum::{
|
||||||
|
extract::{Form, Path, State},
|
||||||
|
http::StatusCode,
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
|
use sqlx::{query_as, SqlitePool};
|
||||||
|
use unicode_segmentation::UnicodeSegmentation;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{templates::CreateUser, User};
|
||||||
|
|
||||||
|
const CREATE_QUERY: &str =
|
||||||
|
"insert into witches (id, username, displayname, email, pwhash) values ($1, $2, $3, $4, $5)";
|
||||||
|
const ID_QUERY: &str = "select * from witches where id = $1";
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// Result types for user creation
|
||||||
|
//-************************************************************************
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Template)]
|
||||||
|
#[template(path = "signup_success.html")]
|
||||||
|
pub struct CreateUserSuccess(User);
|
||||||
|
|
||||||
|
#[Error(desc = "Could not create user.")]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub struct CreateUserError(#[from] CreateUserErrorKind);
|
||||||
|
|
||||||
|
impl IntoResponse for CreateUserError {
|
||||||
|
fn into_response(self) -> askama_axum::Response {
|
||||||
|
match self.0 {
|
||||||
|
CreateUserErrorKind::UnknownDBError => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, format!("{self}")).into_response()
|
||||||
|
}
|
||||||
|
_ => (StatusCode::BAD_REQUEST, format!("{self}")).into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[Error]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub enum CreateUserErrorKind {
|
||||||
|
AlreadyExists,
|
||||||
|
#[error(desc = "Usernames must be between 1 and 20 non-whitespace characters long")]
|
||||||
|
BadUsername,
|
||||||
|
PasswordMismatch,
|
||||||
|
#[error(desc = "Password must have at least 4 and at most 50 characters")]
|
||||||
|
BadPassword,
|
||||||
|
#[error(desc = "Display name must be less than 100 characters long")]
|
||||||
|
BadDisplayname,
|
||||||
|
BadEmail,
|
||||||
|
MissingFields,
|
||||||
|
UnknownDBError,
|
||||||
|
}
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// User creation route handlers
|
||||||
|
//-************************************************************************
|
||||||
|
|
||||||
|
/// Get Handler: displays the form to create a user
|
||||||
|
pub async fn get_create_user() -> CreateUser {
|
||||||
|
CreateUser::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Post Handler: validates form values and calls the actual, private user
|
||||||
|
/// creation function
|
||||||
|
#[axum::debug_handler]
|
||||||
|
pub async fn post_create_user(
|
||||||
|
State(pool): State<SqlitePool>,
|
||||||
|
Form(signup): Form<CreateUser>,
|
||||||
|
) -> Result<impl IntoResponse, CreateUserError> {
|
||||||
|
let username = &signup.username;
|
||||||
|
let displayname = &signup.displayname;
|
||||||
|
let email = &signup.email;
|
||||||
|
let password = &signup.password;
|
||||||
|
let verify = &signup.pw_verify;
|
||||||
|
let username = username.trim();
|
||||||
|
|
||||||
|
let name_len = username.graphemes(true).size_hint().1.unwrap();
|
||||||
|
// we are not ascii exclusivists around here
|
||||||
|
if !(1..=20).contains(&name_len) {
|
||||||
|
return Err(CreateUserErrorKind::BadUsername.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
if password != verify {
|
||||||
|
return Err(CreateUserErrorKind::PasswordMismatch.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let password = urlencoding::decode(password)
|
||||||
|
.map_err(|_| CreateUserErrorKind::BadPassword)?
|
||||||
|
.to_string();
|
||||||
|
let password = password.trim();
|
||||||
|
let password = password.as_bytes();
|
||||||
|
if !(4..=50).contains(&password.len()) {
|
||||||
|
return Err(CreateUserErrorKind::BadPassword.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let displayname = if let Some(dn) = displayname {
|
||||||
|
let dn = urlencoding::decode(dn)
|
||||||
|
.map_err(|_| CreateUserErrorKind::BadDisplayname)?
|
||||||
|
.to_string()
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
if dn.graphemes(true).size_hint().1.unwrap() > 100 {
|
||||||
|
return Err(CreateUserErrorKind::BadDisplayname.into());
|
||||||
|
}
|
||||||
|
Some(dn)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let displayname = &displayname;
|
||||||
|
|
||||||
|
// TODO(2023-05-17): validate email
|
||||||
|
let email = if let Some(email) = email {
|
||||||
|
let email = urlencoding::decode(email)
|
||||||
|
.map_err(|_| CreateUserErrorKind::BadEmail)?
|
||||||
|
.to_string();
|
||||||
|
Some(email)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let email = &email;
|
||||||
|
|
||||||
|
let user = create_user(username, displayname, email, password, &pool).await?;
|
||||||
|
tracing::debug!("created {user:?}");
|
||||||
|
let id = user.id.as_simple().to_string();
|
||||||
|
let location = format!("/signup_success/{id}");
|
||||||
|
|
||||||
|
let resp = axum::response::Redirect::temporary(&location);
|
||||||
|
|
||||||
|
Ok(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generic handler for successful signup
|
||||||
|
pub async fn handle_signup_success(
|
||||||
|
Path(id): Path<String>,
|
||||||
|
State(pool): State<SqlitePool>,
|
||||||
|
) -> Response {
|
||||||
|
let id = id.trim();
|
||||||
|
let user: User = {
|
||||||
|
let id = Uuid::try_parse(id).unwrap_or_default();
|
||||||
|
query_as(ID_QUERY)
|
||||||
|
.bind(id)
|
||||||
|
.fetch_one(&pool)
|
||||||
|
.await
|
||||||
|
.unwrap_or_default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut resp = CreateUserSuccess(user.clone()).into_response();
|
||||||
|
|
||||||
|
if user.username.is_empty() || id.is_empty() {
|
||||||
|
// redirect to front page if we got here without a valid witch ID
|
||||||
|
*resp.status_mut() = StatusCode::TEMPORARY_REDIRECT;
|
||||||
|
resp.headers_mut().insert("Location", "/".parse().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// private fns
|
||||||
|
//-************************************************************************
|
||||||
|
|
||||||
|
async fn create_user(
|
||||||
|
username: &str,
|
||||||
|
displayname: &Option<String>,
|
||||||
|
email: &Option<String>,
|
||||||
|
password: &[u8],
|
||||||
|
pool: &SqlitePool,
|
||||||
|
) -> Result<User, CreateUserError> {
|
||||||
|
// Argon2 with default params (Argon2id v19)
|
||||||
|
let argon2 = Argon2::default();
|
||||||
|
let salt = SaltString::generate(&mut OsRng);
|
||||||
|
let pwhash = argon2
|
||||||
|
.hash_password(password, &salt)
|
||||||
|
.unwrap() // safe to unwrap, we know the salt is valid
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let id = Uuid::new_v4();
|
||||||
|
let res = sqlx::query(CREATE_QUERY)
|
||||||
|
.bind(id)
|
||||||
|
.bind(username)
|
||||||
|
.bind(displayname)
|
||||||
|
.bind(email)
|
||||||
|
.bind(&pwhash)
|
||||||
|
.execute(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match res {
|
||||||
|
Ok(_) => {
|
||||||
|
let user = User {
|
||||||
|
id,
|
||||||
|
username: username.to_string(),
|
||||||
|
displayname: displayname.to_owned(),
|
||||||
|
email: email.to_owned(),
|
||||||
|
last_seen: None,
|
||||||
|
pwhash,
|
||||||
|
};
|
||||||
|
Ok(user)
|
||||||
|
}
|
||||||
|
Err(sqlx::Error::Database(db)) => {
|
||||||
|
if let Some(exit) = db.code() {
|
||||||
|
let exit = exit.parse().unwrap_or(0u32);
|
||||||
|
// https://www.sqlite.org/rescode.html codes for unique constraint violations:
|
||||||
|
if exit == 2067u32 || exit == 1555 {
|
||||||
|
Err(CreateUserErrorKind::AlreadyExists.into())
|
||||||
|
} else {
|
||||||
|
Err(CreateUserErrorKind::UnknownDBError.into())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(CreateUserErrorKind::UnknownDBError.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(CreateUserErrorKind::UnknownDBError.into()),
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,7 +1,7 @@
|
||||||
use askama::Template;
|
use askama::Template;
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
#[derive(Debug, Default, Template, Deserialize)]
|
#[derive(Debug, Default, Template, Deserialize, Serialize)]
|
||||||
#[template(path = "signup.html")]
|
#[template(path = "signup.html")]
|
||||||
pub struct CreateUser {
|
pub struct CreateUser {
|
||||||
pub username: String,
|
pub username: String,
|
||||||
|
@ -10,3 +10,17 @@ pub struct CreateUser {
|
||||||
pub password: String,
|
pub password: String,
|
||||||
pub pw_verify: String,
|
pub pw_verify: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Template, Deserialize, Serialize)]
|
||||||
|
#[template(path = "login_post.html")]
|
||||||
|
pub struct LoginPost {
|
||||||
|
pub username: String,
|
||||||
|
pub password: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Template, Deserialize, Serialize)]
|
||||||
|
#[template(path = "login_get.html")]
|
||||||
|
pub struct LoginGet {
|
||||||
|
pub username: String,
|
||||||
|
pub password: String,
|
||||||
|
}
|
||||||
|
|
243
src/users.rs
243
src/users.rs
|
@ -1,33 +1,19 @@
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
|
|
||||||
use argon2::{
|
use axum_login::{secrecy::SecretVec, AuthUser};
|
||||||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
use sqlx::SqlitePool;
|
||||||
Argon2,
|
|
||||||
};
|
|
||||||
use askama::Template;
|
|
||||||
use axum::{
|
|
||||||
extract::{Form, Path, State},
|
|
||||||
http::StatusCode,
|
|
||||||
response::{IntoResponse, Response},
|
|
||||||
};
|
|
||||||
use sqlx::{sqlite::SqliteRow, Row, SqlitePool};
|
|
||||||
use unicode_segmentation::UnicodeSegmentation;
|
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::templates::CreateUser;
|
const USERNAME_QUERY: &str = "select * from witches where username = $1";
|
||||||
|
|
||||||
const CREATE_QUERY: &str =
|
#[derive(Debug, Default, Clone, PartialEq, Eq, sqlx::FromRow)]
|
||||||
"insert into witches (id, username, displayname, email, pwhash) values ($1, $2, $3, $4, $5)";
|
|
||||||
|
|
||||||
const ID_QUERY: &str = "select * from witches where id = $1";
|
|
||||||
|
|
||||||
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
|
||||||
pub struct User {
|
pub struct User {
|
||||||
id: Uuid,
|
pub id: Uuid,
|
||||||
username: String,
|
pub username: String,
|
||||||
displayname: Option<String>,
|
pub displayname: Option<String>,
|
||||||
email: Option<String>,
|
pub email: Option<String>,
|
||||||
last_seen: Option<i64>,
|
pub last_seen: Option<i64>,
|
||||||
|
pub(crate) pwhash: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for User {
|
impl Display for User {
|
||||||
|
@ -43,208 +29,21 @@ impl Display for User {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Template)]
|
impl AuthUser<Uuid> for User {
|
||||||
#[template(path = "signup_success.html")]
|
fn get_id(&self) -> Uuid {
|
||||||
pub struct CreateUserSuccess(User);
|
self.id
|
||||||
|
}
|
||||||
|
|
||||||
impl sqlx::FromRow<'_, SqliteRow> for User {
|
fn get_password_hash(&self) -> SecretVec<u8> {
|
||||||
fn from_row(row: &SqliteRow) -> Result<Self, sqlx::Error> {
|
SecretVec::new(self.pwhash.as_bytes().to_vec())
|
||||||
let bytes: Vec<u8> = row.get("id");
|
|
||||||
let bytes = bytes.as_slice();
|
|
||||||
let bytes: [u8; 16] = bytes.try_into().unwrap();
|
|
||||||
let id = Uuid::from_bytes_le(bytes);
|
|
||||||
let username: String = row.get("username");
|
|
||||||
let displayname: Option<String> = row.get("displayname");
|
|
||||||
let last_seen: Option<i64> = row.get("last_seen");
|
|
||||||
let email: Option<String> = row.get("email");
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
id,
|
|
||||||
username,
|
|
||||||
displayname,
|
|
||||||
email,
|
|
||||||
last_seen,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get Handler: displays the form to create a user
|
impl User {
|
||||||
pub async fn get_create_user() -> CreateUser {
|
pub async fn get(username: &str, db: &SqlitePool) -> Result<User, impl std::error::Error> {
|
||||||
CreateUser::default()
|
sqlx::query_as(USERNAME_QUERY)
|
||||||
}
|
.bind(username)
|
||||||
|
.fetch_one(db)
|
||||||
/// Post Handler: validates form values and calls the actual, private user
|
|
||||||
/// creation function
|
|
||||||
#[axum::debug_handler]
|
|
||||||
pub async fn post_create_user(
|
|
||||||
State(pool): State<SqlitePool>,
|
|
||||||
Form(signup): Form<CreateUser>,
|
|
||||||
) -> Result<Response, CreateUserError> {
|
|
||||||
let username = &signup.username;
|
|
||||||
let displayname = &signup.displayname;
|
|
||||||
let email = &signup.email;
|
|
||||||
let password = &signup.password;
|
|
||||||
let verify = &signup.pw_verify;
|
|
||||||
let username = username.trim();
|
|
||||||
|
|
||||||
let name_len = username.graphemes(true).size_hint().1.unwrap();
|
|
||||||
// we are not ascii exclusivists around here
|
|
||||||
if !(1..=20).contains(&name_len) {
|
|
||||||
return Err(CreateUserErrorKind::BadUsername.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref dn) = displayname {
|
|
||||||
if dn.len() > 50 {
|
|
||||||
return Err(CreateUserErrorKind::BadDisplayname.into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if password != verify {
|
|
||||||
return Err(CreateUserErrorKind::PasswordMismatch.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let password = urlencoding::decode(password)
|
|
||||||
.map_err(|_| CreateUserErrorKind::BadPassword)?
|
|
||||||
.to_string();
|
|
||||||
let password = password.as_bytes();
|
|
||||||
|
|
||||||
let displayname = if let Some(dn) = displayname {
|
|
||||||
let dn = urlencoding::decode(dn)
|
|
||||||
.map_err(|_| CreateUserErrorKind::BadDisplayname)?
|
|
||||||
.to_string();
|
|
||||||
Some(dn)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
let displayname = &displayname;
|
|
||||||
|
|
||||||
// TODO(2023-05-17): validate email
|
|
||||||
let email = if let Some(email) = email {
|
|
||||||
let email = urlencoding::decode(email)
|
|
||||||
.map_err(|_| CreateUserErrorKind::BadEmail)?
|
|
||||||
.to_string();
|
|
||||||
Some(email)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
let email = &email;
|
|
||||||
|
|
||||||
let user = create_user(username, displayname, email, password, &pool).await?;
|
|
||||||
tracing::debug!("created {user:?}");
|
|
||||||
let id = user.id.simple().to_string();
|
|
||||||
let location = format!("/signup_success/{id}");
|
|
||||||
|
|
||||||
let resp = axum::response::Redirect::temporary(&location).into_response();
|
|
||||||
|
|
||||||
Ok(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get handler for successful signup
|
|
||||||
pub async fn handle_signup_success(
|
|
||||||
Path(id): Path<String>,
|
|
||||||
State(pool): State<SqlitePool>,
|
|
||||||
) -> Response {
|
|
||||||
let user: User = {
|
|
||||||
let id = id;
|
|
||||||
let id = Uuid::try_parse(&id).unwrap_or_default();
|
|
||||||
let id_bytes = id.to_bytes_le();
|
|
||||||
sqlx::query_as(ID_QUERY)
|
|
||||||
.bind(id_bytes.as_slice())
|
|
||||||
.fetch_one(&pool)
|
|
||||||
.await
|
.await
|
||||||
.unwrap_or_default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut resp = CreateUserSuccess(user.clone()).into_response();
|
|
||||||
|
|
||||||
if user.username.is_empty() {
|
|
||||||
// redirect to front page if we got here without a valid witch header
|
|
||||||
*resp.status_mut() = StatusCode::TEMPORARY_REDIRECT;
|
|
||||||
resp.headers_mut().insert("Location", "/".parse().unwrap());
|
|
||||||
}
|
|
||||||
resp
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn create_user(
|
|
||||||
username: &str,
|
|
||||||
displayname: &Option<String>,
|
|
||||||
email: &Option<String>,
|
|
||||||
password: &[u8],
|
|
||||||
pool: &SqlitePool,
|
|
||||||
) -> Result<User, CreateUserError> {
|
|
||||||
// Argon2 with default params (Argon2id v19)
|
|
||||||
let argon2 = Argon2::default();
|
|
||||||
let salt = SaltString::generate(&mut OsRng);
|
|
||||||
let pwhash = argon2
|
|
||||||
.hash_password(password, &salt)
|
|
||||||
.unwrap() // safe to unwrap, we know the salt is valid
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
let id = Uuid::new_v4();
|
|
||||||
let id_bytes = id.to_bytes_le();
|
|
||||||
let id_bytes = id_bytes.as_slice();
|
|
||||||
let res = sqlx::query(CREATE_QUERY)
|
|
||||||
.bind(id_bytes)
|
|
||||||
.bind(username)
|
|
||||||
.bind(displayname)
|
|
||||||
.bind(email)
|
|
||||||
.bind(pwhash)
|
|
||||||
.execute(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match res {
|
|
||||||
Ok(_) => {
|
|
||||||
let user = User {
|
|
||||||
id,
|
|
||||||
username: username.to_string(),
|
|
||||||
displayname: displayname.to_owned(),
|
|
||||||
email: email.to_owned(),
|
|
||||||
last_seen: None,
|
|
||||||
};
|
|
||||||
Ok(user)
|
|
||||||
}
|
|
||||||
Err(sqlx::Error::Database(db)) => {
|
|
||||||
if let Some(exit) = db.code() {
|
|
||||||
let exit = exit.parse().unwrap_or(0u32);
|
|
||||||
// https://www.sqlite.org/rescode.html codes for unique constraint violations:
|
|
||||||
if exit == 2067u32 || exit == 1555 {
|
|
||||||
Err(CreateUserErrorKind::AlreadyExists.into())
|
|
||||||
} else {
|
|
||||||
Err(CreateUserErrorKind::UnknownDBError.into())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Err(CreateUserErrorKind::UnknownDBError.into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => Err(CreateUserErrorKind::UnknownDBError.into()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[Error(desc = "Could not create user.")]
|
|
||||||
#[non_exhaustive]
|
|
||||||
pub struct CreateUserError(#[from] CreateUserErrorKind);
|
|
||||||
|
|
||||||
impl IntoResponse for CreateUserError {
|
|
||||||
fn into_response(self) -> askama_axum::Response {
|
|
||||||
match self.0 {
|
|
||||||
CreateUserErrorKind::UnknownDBError => {
|
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, format!("{self}")).into_response()
|
|
||||||
}
|
|
||||||
_ => (StatusCode::BAD_REQUEST, format!("{self}")).into_response(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[Error]
|
|
||||||
#[non_exhaustive]
|
|
||||||
pub enum CreateUserErrorKind {
|
|
||||||
AlreadyExists,
|
|
||||||
#[error(desc = "Usernames must be between 1 and 20 non-whitespace characters long")]
|
|
||||||
BadUsername,
|
|
||||||
PasswordMismatch,
|
|
||||||
BadPassword,
|
|
||||||
BadDisplayname,
|
|
||||||
BadEmail,
|
|
||||||
MissingFields,
|
|
||||||
UnknownDBError,
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
pub fn form_decode<E: std::error::Error>(input: &str, err: E) -> Result<String, E> {
|
||||||
|
Ok(urlencoding::decode(input).map_err(|_| err)?.into_owned())
|
||||||
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block title %}Login to Witch Watch, Bish{% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
|
||||||
|
<p>
|
||||||
|
<form action="/login" enctype="application/x-www-form-urlencoded" method="post">
|
||||||
|
<label for="username">Username</label>
|
||||||
|
<input type="text" name="username" id="username" minlength="1" maxlength="20" required></br>
|
||||||
|
<label for="password">Password</label>
|
||||||
|
<input type="password" name="password" id="password" required></br>
|
||||||
|
<input type="submit" value="Signup">
|
||||||
|
</form>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
{% endblock %}
|
Loading…
Reference in New Issue