Merge branch 'tests'
This commit is contained in:
commit
613814a648
7 changed files with 258 additions and 136 deletions
33
Cargo.lock
generated
33
Cargo.lock
generated
|
@ -158,6 +158,12 @@ dependencies = [
|
||||||
"num-traits",
|
"num-traits",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "auto-future"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3c1e7e457ea78e524f48639f551fd79703ac3f2237f5ecccdf4708f8a75ad373"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "autocfg"
|
name = "autocfg"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
|
@ -290,6 +296,23 @@ dependencies = [
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum-test"
|
||||||
|
version = "9.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "75a01b0885dcea3124d990b24fd5ad9d0735f77d53385161665a0d9fd08a03e6"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"auto-future",
|
||||||
|
"axum",
|
||||||
|
"cookie",
|
||||||
|
"hyper",
|
||||||
|
"portpicker",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "base64"
|
name = "base64"
|
||||||
version = "0.13.1"
|
version = "0.13.1"
|
||||||
|
@ -1258,6 +1281,15 @@ version = "0.3.27"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964"
|
checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "portpicker"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "be97d76faf1bfab666e1375477b23fde79eccf0276e9b63b92a39d676a889ba9"
|
||||||
|
dependencies = [
|
||||||
|
"rand",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ppv-lite86"
|
name = "ppv-lite86"
|
||||||
version = "0.2.17"
|
version = "0.2.17"
|
||||||
|
@ -2345,6 +2377,7 @@ dependencies = [
|
||||||
"axum",
|
"axum",
|
||||||
"axum-login",
|
"axum-login",
|
||||||
"axum-macros",
|
"axum-macros",
|
||||||
|
"axum-test",
|
||||||
"justerror",
|
"justerror",
|
||||||
"password-hash",
|
"password-hash",
|
||||||
"rand_core",
|
"rand_core",
|
||||||
|
|
|
@ -26,3 +26,6 @@ unicode-segmentation = "1"
|
||||||
urlencoding = "2"
|
urlencoding = "2"
|
||||||
async-session = "3"
|
async-session = "3"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
axum-test = "9.0.0"
|
||||||
|
|
||||||
|
|
176
src/db.rs
176
src/db.rs
|
@ -7,6 +7,7 @@ use axum_login::{
|
||||||
};
|
};
|
||||||
use session_store::SqliteSessionStore;
|
use session_store::SqliteSessionStore;
|
||||||
use sqlx::{
|
use sqlx::{
|
||||||
|
migrate::Migrator,
|
||||||
sqlite::{SqliteConnectOptions, SqlitePoolOptions},
|
sqlite::{SqliteConnectOptions, SqlitePoolOptions},
|
||||||
SqlitePool,
|
SqlitePool,
|
||||||
};
|
};
|
||||||
|
@ -15,30 +16,62 @@ use uuid::Uuid;
|
||||||
use crate::User;
|
use crate::User;
|
||||||
|
|
||||||
const MAX_CONNS: u32 = 100;
|
const MAX_CONNS: u32 = 100;
|
||||||
const TIMEOUT: u64 = 5;
|
const MIN_CONNS: u32 = 10;
|
||||||
|
const TIMEOUT: u64 = 11;
|
||||||
const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64);
|
const SESSION_TTL: Duration = Duration::from_secs((365.2422 * 24. * 3600.0) as u64);
|
||||||
|
|
||||||
pub async fn get_pool() -> SqlitePool {
|
pub async fn get_pool() -> SqlitePool {
|
||||||
let db_filename = {
|
let db_filename = {
|
||||||
std::env::var("DATABASE_FILE").unwrap_or_else(|_| {
|
std::env::var("DATABASE_FILE").unwrap_or_else(|_| {
|
||||||
|
#[cfg(not(test))]
|
||||||
|
{
|
||||||
let home =
|
let home =
|
||||||
std::env::var("HOME").expect("Could not determine $HOME for finding db file");
|
std::env::var("HOME").expect("Could not determine $HOME for finding db file");
|
||||||
format!("{home}/.witch-watch.db")
|
format!("{home}/.witch-watch.db")
|
||||||
|
}
|
||||||
|
#[cfg(test)]
|
||||||
|
{
|
||||||
|
use rand_core::RngCore;
|
||||||
|
let mut rng = rand_core::OsRng;
|
||||||
|
let id = rng.next_u64();
|
||||||
|
format!("file:testdb-{id}?mode=memory&cache=shared")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
};
|
};
|
||||||
|
|
||||||
|
tracing::info!("Connecting to DB at {db_filename}");
|
||||||
|
|
||||||
let conn_opts = SqliteConnectOptions::new()
|
let conn_opts = SqliteConnectOptions::new()
|
||||||
.foreign_keys(true)
|
.foreign_keys(true)
|
||||||
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
|
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
|
||||||
.filename(&db_filename)
|
.filename(&db_filename)
|
||||||
.busy_timeout(Duration::from_secs(TIMEOUT));
|
.busy_timeout(Duration::from_secs(TIMEOUT))
|
||||||
|
.create_if_missing(true);
|
||||||
|
|
||||||
// setup connection pool
|
let pool = SqlitePoolOptions::new()
|
||||||
SqlitePoolOptions::new()
|
|
||||||
.max_connections(MAX_CONNS)
|
.max_connections(MAX_CONNS)
|
||||||
|
.min_connections(MIN_CONNS)
|
||||||
|
.idle_timeout(Some(Duration::from_secs(10)))
|
||||||
|
.max_lifetime(Some(Duration::from_secs(3600)))
|
||||||
.connect_with(conn_opts)
|
.connect_with(conn_opts)
|
||||||
.await
|
.await
|
||||||
.expect("can't connect to database")
|
.expect("can't connect to database");
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut m = Migrator::new(std::path::Path::new("./migrations"))
|
||||||
|
.await
|
||||||
|
.expect("Should be able to read the migration directory.");
|
||||||
|
|
||||||
|
let m = m.set_locking(true);
|
||||||
|
|
||||||
|
m.run(&pool)
|
||||||
|
.await
|
||||||
|
.expect("Should be able to run the migration.");
|
||||||
|
|
||||||
|
tracing::info!("Ran migrations");
|
||||||
|
}
|
||||||
|
|
||||||
|
pool
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn session_layer(pool: SqlitePool, secret: &[u8]) -> SessionLayer<SqliteSessionStore> {
|
pub async fn session_layer(pool: SqlitePool, secret: &[u8]) -> SessionLayer<SqliteSessionStore> {
|
||||||
|
@ -71,6 +104,26 @@ pub async fn auth_layer(
|
||||||
AuthLayer::new(store, secret)
|
AuthLayer::new(store, secret)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// Tests for `db` module.
|
||||||
|
//-************************************************************************
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn it_migrates_the_db() {
|
||||||
|
let db = super::get_pool().await;
|
||||||
|
let r = sqlx::query("select count(*) from witches")
|
||||||
|
.fetch_one(&db)
|
||||||
|
.await;
|
||||||
|
assert!(r.is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// End public interface.
|
||||||
|
//-************************************************************************
|
||||||
|
|
||||||
//-************************************************************************
|
//-************************************************************************
|
||||||
// Session store sub-module, not a public lib.
|
// Session store sub-module, not a public lib.
|
||||||
//-************************************************************************
|
//-************************************************************************
|
||||||
|
@ -104,25 +157,6 @@ mod session_store {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/// sqlx sqlite session store for async-sessions
|
/// sqlx sqlite session store for async-sessions
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// use witch_watch::session_store::SqliteSessionStore;
|
|
||||||
/// use async_session::{Session, SessionStore, Result};
|
|
||||||
/// use std::time::Duration;
|
|
||||||
///
|
|
||||||
/// # #[tokio::main]
|
|
||||||
/// # async fn main() -> Result {
|
|
||||||
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
|
|
||||||
/// store.migrate().await?;
|
|
||||||
///
|
|
||||||
/// let mut session = Session::new();
|
|
||||||
/// session.insert("key", vec![1,2,3]);
|
|
||||||
///
|
|
||||||
/// let cookie_value = store.store_session(session).await?.unwrap();
|
|
||||||
/// let session = store.load_session(cookie_value).await?.unwrap();
|
|
||||||
/// assert_eq!(session.get::<Vec<i8>>("key").unwrap(), vec![1,2,3]);
|
|
||||||
/// # Ok(()) }
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct SqliteSessionStore {
|
pub struct SqliteSessionStore {
|
||||||
client: SqlitePool,
|
client: SqlitePool,
|
||||||
|
@ -134,18 +168,6 @@ mod session_store {
|
||||||
/// sqlx::SqlitePool. the default table name for this session
|
/// sqlx::SqlitePool. the default table name for this session
|
||||||
/// store will be "async_sessions". To override this, chain this
|
/// store will be "async_sessions". To override this, chain this
|
||||||
/// with [`with_table_name`](crate::SqliteSessionStore::with_table_name).
|
/// with [`with_table_name`](crate::SqliteSessionStore::with_table_name).
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use witch_watch::session_store::SqliteSessionStore;
|
|
||||||
/// # use async_session::Result;
|
|
||||||
/// # #[tokio::main]
|
|
||||||
/// # async fn main() -> Result {
|
|
||||||
/// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
|
|
||||||
/// let store = SqliteSessionStore::from_client(pool)
|
|
||||||
/// .with_table_name("custom_table_name");
|
|
||||||
/// store.migrate().await;
|
|
||||||
/// # Ok(()) }
|
|
||||||
/// ```
|
|
||||||
pub fn from_client(client: SqlitePool) -> Self {
|
pub fn from_client(client: SqlitePool) -> Self {
|
||||||
Self {
|
Self {
|
||||||
client,
|
client,
|
||||||
|
@ -163,16 +185,6 @@ mod session_store {
|
||||||
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
|
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
|
||||||
/// use
|
/// use
|
||||||
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
|
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use witch_watch::session_store::SqliteSessionStore;
|
|
||||||
/// # use async_session::Result;
|
|
||||||
/// # #[tokio::main]
|
|
||||||
/// # async fn main() -> Result {
|
|
||||||
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
|
|
||||||
/// store.migrate().await;
|
|
||||||
/// # Ok(()) }
|
|
||||||
/// ```
|
|
||||||
pub async fn new(database_url: &str) -> sqlx::Result<Self> {
|
pub async fn new(database_url: &str) -> sqlx::Result<Self> {
|
||||||
Ok(Self::from_client(SqlitePool::connect(database_url).await?))
|
Ok(Self::from_client(SqlitePool::connect(database_url).await?))
|
||||||
}
|
}
|
||||||
|
@ -183,16 +195,6 @@ mod session_store {
|
||||||
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
|
/// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
|
||||||
/// use
|
/// use
|
||||||
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
|
/// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use witch_watch::session_store::SqliteSessionStore;
|
|
||||||
/// # use async_session::Result;
|
|
||||||
/// # #[tokio::main]
|
|
||||||
/// # async fn main() -> Result {
|
|
||||||
/// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?;
|
|
||||||
/// store.migrate().await;
|
|
||||||
/// # Ok(()) }
|
|
||||||
/// ```
|
|
||||||
pub async fn new_with_table_name(
|
pub async fn new_with_table_name(
|
||||||
database_url: &str,
|
database_url: &str,
|
||||||
table_name: &str,
|
table_name: &str,
|
||||||
|
@ -202,26 +204,6 @@ mod session_store {
|
||||||
|
|
||||||
/// Chainable method to add a custom table name. This will panic
|
/// Chainable method to add a custom table name. This will panic
|
||||||
/// if the table name is not `[a-zA-Z0-9_-]+`.
|
/// if the table name is not `[a-zA-Z0-9_-]+`.
|
||||||
/// ```rust
|
|
||||||
/// # use witch_watch::session_store::SqliteSessionStore;
|
|
||||||
/// # use async_session::Result;
|
|
||||||
/// # #[tokio::main]
|
|
||||||
/// # async fn main() -> Result {
|
|
||||||
/// let store = SqliteSessionStore::new("sqlite::memory:").await?
|
|
||||||
/// .with_table_name("custom_name");
|
|
||||||
/// store.migrate().await;
|
|
||||||
/// # Ok(()) }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// ```should_panic
|
|
||||||
/// # use witch_watch::session_store::SqliteSessionStore;
|
|
||||||
/// # use async_session::Result;
|
|
||||||
/// # #[tokio::main]
|
|
||||||
/// # async fn main() -> Result {
|
|
||||||
/// let store = SqliteSessionStore::new("sqlite::memory:").await?
|
|
||||||
/// .with_table_name("johnny (); drop users;");
|
|
||||||
/// # Ok(()) }
|
|
||||||
/// ```
|
|
||||||
pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
|
pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
|
||||||
let table_name = table_name.as_ref();
|
let table_name = table_name.as_ref();
|
||||||
if table_name.is_empty()
|
if table_name.is_empty()
|
||||||
|
@ -244,19 +226,6 @@ mod session_store {
|
||||||
/// store initialization. In the future, this may make
|
/// store initialization. In the future, this may make
|
||||||
/// exactly-once modifications to the schema of the session table
|
/// exactly-once modifications to the schema of the session table
|
||||||
/// on breaking releases.
|
/// on breaking releases.
|
||||||
/// ```rust
|
|
||||||
/// # use witch_watch::session_store::SqliteSessionStore;
|
|
||||||
/// # use async_session::{Result, SessionStore, Session};
|
|
||||||
/// # #[tokio::main]
|
|
||||||
/// # async fn main() -> Result {
|
|
||||||
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
|
|
||||||
/// assert!(store.count().await.is_err());
|
|
||||||
/// store.migrate().await?;
|
|
||||||
/// store.store_session(Session::new()).await?;
|
|
||||||
/// store.migrate().await?; // calling it a second time is safe
|
|
||||||
/// assert_eq!(store.count().await?, 1);
|
|
||||||
/// # Ok(()) }
|
|
||||||
/// ```
|
|
||||||
pub async fn migrate(&self) -> sqlx::Result<()> {
|
pub async fn migrate(&self) -> sqlx::Result<()> {
|
||||||
log::info!("migrating sessions on `{}`", self.table_name);
|
log::info!("migrating sessions on `{}`", self.table_name);
|
||||||
|
|
||||||
|
@ -288,21 +257,6 @@ mod session_store {
|
||||||
|
|
||||||
/// Performs a one-time cleanup task that clears out stale
|
/// Performs a one-time cleanup task that clears out stale
|
||||||
/// (expired) sessions. You may want to call this from cron.
|
/// (expired) sessions. You may want to call this from cron.
|
||||||
/// ```rust
|
|
||||||
/// # use witch_watch::session_store::SqliteSessionStore;
|
|
||||||
/// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session};
|
|
||||||
/// # #[tokio::main]
|
|
||||||
/// # async fn main() -> Result {
|
|
||||||
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
|
|
||||||
/// store.migrate().await?;
|
|
||||||
/// let mut session = Session::new();
|
|
||||||
/// session.set_expiry(Utc::now() - Duration::seconds(5));
|
|
||||||
/// store.store_session(session).await?;
|
|
||||||
/// assert_eq!(store.count().await?, 1);
|
|
||||||
/// store.cleanup().await?;
|
|
||||||
/// assert_eq!(store.count().await?, 0);
|
|
||||||
/// # Ok(()) }
|
|
||||||
/// ```
|
|
||||||
pub async fn cleanup(&self) -> sqlx::Result<()> {
|
pub async fn cleanup(&self) -> sqlx::Result<()> {
|
||||||
let mut connection = self.connection().await?;
|
let mut connection = self.connection().await?;
|
||||||
sqlx::query(&self.substitute_table_name(
|
sqlx::query(&self.substitute_table_name(
|
||||||
|
@ -320,20 +274,6 @@ mod session_store {
|
||||||
|
|
||||||
/// retrieves the number of sessions currently stored, including
|
/// retrieves the number of sessions currently stored, including
|
||||||
/// expired sessions
|
/// expired sessions
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use witch_watch::session_store::SqliteSessionStore;
|
|
||||||
/// # use async_session::{Result, SessionStore, Session};
|
|
||||||
/// # use std::time::Duration;
|
|
||||||
/// # #[tokio::main]
|
|
||||||
/// # async fn main() -> Result {
|
|
||||||
/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
|
|
||||||
/// store.migrate().await?;
|
|
||||||
/// assert_eq!(store.count().await?, 0);
|
|
||||||
/// store.store_session(Session::new()).await?;
|
|
||||||
/// assert_eq!(store.count().await?, 1);
|
|
||||||
/// # Ok(()) }
|
|
||||||
/// ```
|
|
||||||
pub async fn count(&self) -> sqlx::Result<i32> {
|
pub async fn count(&self) -> sqlx::Result<i32> {
|
||||||
let (count,) =
|
let (count,) =
|
||||||
sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
|
sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
|
||||||
|
|
|
@ -3,7 +3,7 @@ use axum::response::{IntoResponse, Redirect};
|
||||||
use crate::{templates::Index, AuthContext};
|
use crate::{templates::Index, AuthContext};
|
||||||
|
|
||||||
pub async fn handle_slash_redir() -> impl IntoResponse {
|
pub async fn handle_slash_redir() -> impl IntoResponse {
|
||||||
Redirect::temporary("/")
|
Redirect::to("/")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse {
|
pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse {
|
||||||
|
@ -17,3 +17,38 @@ pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse {
|
||||||
user: auth.current_user,
|
user: auth.current_user,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use axum_test::TestServer;
|
||||||
|
|
||||||
|
use crate::db;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn slash_is_ok() {
|
||||||
|
let pool = db::get_pool().await;
|
||||||
|
let secret = [0u8; 64];
|
||||||
|
let app = crate::app(pool.clone(), &secret).await.into_make_service();
|
||||||
|
|
||||||
|
let server = TestServer::new(app).unwrap();
|
||||||
|
|
||||||
|
server.get("/").await.assert_status_ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn not_found_is_303() {
|
||||||
|
let pool = db::get_pool().await;
|
||||||
|
let secret = [0u8; 64];
|
||||||
|
let app = crate::app(pool, &secret).await.into_make_service();
|
||||||
|
|
||||||
|
let server = TestServer::new(app).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
server
|
||||||
|
.get("/no-actual-route")
|
||||||
|
.expect_failure()
|
||||||
|
.await
|
||||||
|
.status_code(),
|
||||||
|
303
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
105
src/login.rs
105
src/login.rs
|
@ -39,11 +39,12 @@ pub enum LoginErrorKind {
|
||||||
impl IntoResponse for LoginError {
|
impl IntoResponse for LoginError {
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
match self.0 {
|
match self.0 {
|
||||||
LoginErrorKind::Unknown | LoginErrorKind::Internal => (
|
LoginErrorKind::Internal => (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
"An unknown error occurred; you cursed, brah?",
|
"An unknown error occurred; you cursed, brah?",
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response(),
|
||||||
|
LoginErrorKind::Unknown => (StatusCode::OK, "Not successful.").into_response(),
|
||||||
_ => (StatusCode::OK, format!("{self}")).into_response(),
|
_ => (StatusCode::OK, format!("{self}")).into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -79,7 +80,7 @@ pub async fn post_login(
|
||||||
.await
|
.await
|
||||||
.map_err(|_| LoginErrorKind::Internal)?;
|
.map_err(|_| LoginErrorKind::Internal)?;
|
||||||
|
|
||||||
Ok(Redirect::temporary("/"))
|
Ok(Redirect::to("/"))
|
||||||
}
|
}
|
||||||
_ => Err(LoginErrorKind::BadPassword.into()),
|
_ => Err(LoginErrorKind::BadPassword.into()),
|
||||||
}
|
}
|
||||||
|
@ -99,3 +100,103 @@ pub async fn post_logout(mut auth: AuthContext) -> impl IntoResponse {
|
||||||
}
|
}
|
||||||
LogoutPost
|
LogoutPost
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use axum::body::Bytes;
|
||||||
|
use axum_test::TestServer;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
db,
|
||||||
|
signup::create_user,
|
||||||
|
templates::{LoginGet, LogoutGet, LogoutPost},
|
||||||
|
};
|
||||||
|
|
||||||
|
async fn tserver() -> TestServer {
|
||||||
|
let pool = db::get_pool().await;
|
||||||
|
let secret = [0u8; 64];
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||||
|
|
||||||
|
let _user = create_user(
|
||||||
|
"test_user",
|
||||||
|
&Some("Test User".to_string()),
|
||||||
|
&Some("mail@email".to_string()),
|
||||||
|
"aaaa".as_bytes(),
|
||||||
|
&pool,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let r = sqlx::query("select count(*) from witches")
|
||||||
|
.fetch_one(&pool)
|
||||||
|
.await;
|
||||||
|
assert!(r.is_ok());
|
||||||
|
|
||||||
|
let app = crate::app(pool, &secret).await.into_make_service();
|
||||||
|
|
||||||
|
TestServer::new(app).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn get_login() {
|
||||||
|
let s = tserver().await;
|
||||||
|
let resp = s.get("/login").await;
|
||||||
|
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string();
|
||||||
|
assert_eq!(body, LoginGet::default().to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn post_login_success() {
|
||||||
|
let s = tserver().await;
|
||||||
|
|
||||||
|
let form = "username=test_user&password=aaaa".to_string();
|
||||||
|
let bytes = form.as_bytes();
|
||||||
|
let body = Bytes::copy_from_slice(bytes);
|
||||||
|
|
||||||
|
let resp = s
|
||||||
|
.post("/login")
|
||||||
|
.expect_failure()
|
||||||
|
.content_type("application/x-www-form-urlencoded")
|
||||||
|
.bytes(body)
|
||||||
|
.await;
|
||||||
|
assert_eq!(resp.status_code(), 303);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn post_login_bad_user() {
|
||||||
|
let s = tserver().await;
|
||||||
|
|
||||||
|
let form = "username=test_LOSER&password=aaaa".to_string();
|
||||||
|
let bytes = form.as_bytes();
|
||||||
|
let body = Bytes::copy_from_slice(bytes);
|
||||||
|
|
||||||
|
let resp = s
|
||||||
|
.post("/login")
|
||||||
|
.expect_success()
|
||||||
|
.content_type("application/x-www-form-urlencoded")
|
||||||
|
.bytes(body)
|
||||||
|
.await;
|
||||||
|
assert_eq!(resp.status_code(), 200);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn get_logout() {
|
||||||
|
let s = tserver().await;
|
||||||
|
let resp = s.get("/logout").await;
|
||||||
|
let body = std::str::from_utf8(resp.bytes()).unwrap().to_string();
|
||||||
|
assert_eq!(body, LogoutGet.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn post_logout() {
|
||||||
|
let s = tserver().await;
|
||||||
|
let resp = s.post("/logout").await;
|
||||||
|
resp.assert_status_ok();
|
||||||
|
let body = std::str::from_utf8(resp.bytes()).unwrap();
|
||||||
|
let default = LogoutPost.to_string();
|
||||||
|
assert_eq!(body, &default);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -78,8 +78,11 @@ pub async fn post_create_user(
|
||||||
let email = &signup.email;
|
let email = &signup.email;
|
||||||
let password = &signup.password;
|
let password = &signup.password;
|
||||||
let verify = &signup.pw_verify;
|
let verify = &signup.pw_verify;
|
||||||
let username = username.trim();
|
|
||||||
|
|
||||||
|
let username = urlencoding::decode(username)
|
||||||
|
.map_err(|_| CreateUserErrorKind::BadUsername)?
|
||||||
|
.to_string();
|
||||||
|
let username = username.trim();
|
||||||
let name_len = username.graphemes(true).size_hint().1.unwrap();
|
let name_len = username.graphemes(true).size_hint().1.unwrap();
|
||||||
// we are not ascii exclusivists around here
|
// we are not ascii exclusivists around here
|
||||||
if !(1..=20).contains(&name_len) {
|
if !(1..=20).contains(&name_len) {
|
||||||
|
@ -165,7 +168,7 @@ pub async fn handle_signup_success(
|
||||||
// private fns
|
// private fns
|
||||||
//-************************************************************************
|
//-************************************************************************
|
||||||
|
|
||||||
async fn create_user(
|
pub(crate) async fn create_user(
|
||||||
username: &str,
|
username: &str,
|
||||||
displayname: &Option<String>,
|
displayname: &Option<String>,
|
||||||
email: &Option<String>,
|
email: &Option<String>,
|
||||||
|
@ -181,14 +184,21 @@ async fn create_user(
|
||||||
.to_string();
|
.to_string();
|
||||||
|
|
||||||
let id = Uuid::new_v4();
|
let id = Uuid::new_v4();
|
||||||
let res = sqlx::query(CREATE_QUERY)
|
let query = sqlx::query(CREATE_QUERY)
|
||||||
.bind(id)
|
.bind(id)
|
||||||
.bind(username)
|
.bind(username)
|
||||||
.bind(displayname)
|
.bind(displayname)
|
||||||
.bind(email)
|
.bind(email)
|
||||||
.bind(&pwhash)
|
.bind(&pwhash);
|
||||||
.execute(pool)
|
|
||||||
.await;
|
let res = {
|
||||||
|
let txn = pool.begin().await.expect("Could not beign transaction");
|
||||||
|
let r = query.execute(pool).await;
|
||||||
|
txn.commit()
|
||||||
|
.await
|
||||||
|
.expect("Should be able to commit transaction");
|
||||||
|
r
|
||||||
|
};
|
||||||
|
|
||||||
match res {
|
match res {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
|
|
|
@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::User;
|
use crate::User;
|
||||||
|
|
||||||
#[derive(Debug, Default, Template, Deserialize, Serialize)]
|
#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)]
|
||||||
#[template(path = "signup.html")]
|
#[template(path = "signup.html")]
|
||||||
pub struct CreateUser {
|
pub struct CreateUser {
|
||||||
pub username: String,
|
pub username: String,
|
||||||
|
@ -13,29 +13,29 @@ pub struct CreateUser {
|
||||||
pub pw_verify: String,
|
pub pw_verify: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, Template, Deserialize, Serialize)]
|
#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)]
|
||||||
#[template(path = "login_post.html")]
|
#[template(path = "login_post.html")]
|
||||||
pub struct LoginPost {
|
pub struct LoginPost {
|
||||||
pub username: String,
|
pub username: String,
|
||||||
pub password: String,
|
pub password: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, Template, Deserialize, Serialize)]
|
#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)]
|
||||||
#[template(path = "login_get.html")]
|
#[template(path = "login_get.html")]
|
||||||
pub struct LoginGet {
|
pub struct LoginGet {
|
||||||
pub username: String,
|
pub username: String,
|
||||||
pub password: String,
|
pub password: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, Template, Deserialize, Serialize)]
|
#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)]
|
||||||
#[template(path = "logout_get.html")]
|
#[template(path = "logout_get.html")]
|
||||||
pub struct LogoutGet;
|
pub struct LogoutGet;
|
||||||
|
|
||||||
#[derive(Debug, Default, Template, Deserialize, Serialize)]
|
#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)]
|
||||||
#[template(path = "logout_post.html")]
|
#[template(path = "logout_post.html")]
|
||||||
pub struct LogoutPost;
|
pub struct LogoutPost;
|
||||||
|
|
||||||
#[derive(Debug, Default, Template, Deserialize, Serialize)]
|
#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq)]
|
||||||
#[template(path = "index.html")]
|
#[template(path = "index.html")]
|
||||||
pub struct Index {
|
pub struct Index {
|
||||||
pub user: Option<User>,
|
pub user: Option<User>,
|
||||||
|
|
Loading…
Reference in a new issue