diff --git a/src/generic_handlers.rs b/src/generic_handlers.rs new file mode 100644 index 0000000..7ad4bcd --- /dev/null +++ b/src/generic_handlers.rs @@ -0,0 +1,7 @@ +use axum::response::{IntoResponse, Redirect}; + +pub async fn handle_slash_redir() -> impl IntoResponse { + Redirect::temporary("/") +} + +pub async fn handle_slash() -> impl IntoResponse {} diff --git a/src/handlers.rs b/src/handlers.rs deleted file mode 100644 index 5ba9bd4..0000000 --- a/src/handlers.rs +++ /dev/null @@ -1,55 +0,0 @@ -use axum::{ - async_trait, - extract::{FromRef, FromRequestParts, State}, - http::{request::Parts, StatusCode}, -}; -use sqlx::SqlitePool; - -pub async fn using_connection_pool_extractor( - State(pool): State, -) -> Result { - sqlx::query_scalar("select 'hello world from sqlite get'") - .fetch_one(&pool) - .await - .map_err(internal_error) -} - -// we can also write a custom extractor that grabs a connection from the pool -// which setup is appropriate depends on your application -pub struct DatabaseConnection(sqlx::pool::PoolConnection); - -#[async_trait] -impl FromRequestParts for DatabaseConnection -where - SqlitePool: FromRef, - S: Send + Sync, -{ - type Rejection = (StatusCode, String); - - async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { - let pool = SqlitePool::from_ref(state); - - let conn = pool.acquire().await.map_err(internal_error)?; - - Ok(Self(conn)) - } -} - -pub async fn using_connection_extractor( - DatabaseConnection(conn): DatabaseConnection, -) -> Result { - let mut conn = conn; - sqlx::query_scalar("select 'hello world from sqlite post'") - .fetch_one(&mut conn) - .await - .map_err(internal_error) -} - -/// Utility function for mapping any error into a `500 Internal Server Error` -/// response. -fn internal_error(err: E) -> (StatusCode, String) -where - E: std::error::Error, -{ - (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) -} diff --git a/src/lib.rs b/src/lib.rs index f97526d..bf90918 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,6 @@ extern crate justerror; pub mod db; -pub mod handlers; +pub mod generic_handlers; pub(crate) mod templates; pub mod users; diff --git a/src/main.rs b/src/main.rs index 51e79c6..f99d315 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ use axum::{routing::get, Router}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use witch_watch::{ db, + generic_handlers::{handle_slash, handle_slash_redir}, users::{get_create_user, handle_signup_success, post_create_user}, }; @@ -21,11 +22,13 @@ async fn main() { // build our application with some routes let app = Router::new() + .route("/", get(handle_slash).post(handle_slash)) .route("/signup", get(get_create_user).post(post_create_user)) .route( "/signup_success/:id", get(handle_signup_success).post(handle_signup_success), ) + .fallback(handle_slash_redir) .with_state(pool); tracing::debug!("binding to 0.0.0.0:3000"); diff --git a/src/users.rs b/src/users.rs index a140739..e6f067d 100644 --- a/src/users.rs +++ b/src/users.rs @@ -145,8 +145,8 @@ pub async fn handle_signup_success( State(pool): State, ) -> Response { let user: User = { - let id = id; - let id = Uuid::try_parse(&id).unwrap_or_default(); + let id = id.trim(); + let id = Uuid::try_parse(id).unwrap_or_default(); let id_bytes = id.to_bytes_le(); sqlx::query_as(ID_QUERY) .bind(id_bytes.as_slice()) @@ -157,8 +157,8 @@ pub async fn handle_signup_success( let mut resp = CreateUserSuccess(user.clone()).into_response(); - if user.username.is_empty() { - // redirect to front page if we got here without a valid witch header + if user.username.is_empty() || id.is_empty() { + // redirect to front page if we got here without a valid witch ID *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT; resp.headers_mut().insert("Location", "/".parse().unwrap()); }