diff --git a/src/login.rs b/src/login.rs index 225b216..3530840 100644 --- a/src/login.rs +++ b/src/login.rs @@ -20,8 +20,6 @@ use crate::{ // Constants //-************************************************************************ -const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1"; - //-************************************************************************ // Login error and success types //-************************************************************************ @@ -68,7 +66,7 @@ pub async fn post_login( let pw = form_decode(&login.password, LoginErrorKind::BadPassword)?; let pw = pw.trim(); - let user = User::get(username, &pool) + let user = User::try_get(username, &pool) .await .map_err(|_| LoginErrorKind::Unknown)?; @@ -81,13 +79,6 @@ pub async fn post_login( .await .map_err(|_| LoginErrorKind::Internal)?; - // update last_seen; maybe this is ok to fail? - sqlx::query(LAST_SEEN_QUERY) - .bind(user.id) - .execute(&pool) - .await - .map_err(|_| LoginErrorKind::Internal)?; - Ok(Redirect::temporary("/")) } _ => Err(LoginErrorKind::BadPassword.into()), diff --git a/src/main.rs b/src/main.rs index a8529a9..8ec01e8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; -use axum::{routing::get, Router}; +use axum::{middleware, routing::get, Router}; use rand_core::{OsRng, RngCore}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use witch_watch::{ @@ -8,6 +8,7 @@ use witch_watch::{ generic_handlers::{handle_slash, handle_slash_redir}, login::{get_login, get_logout, post_login, post_logout}, signup::{get_create_user, handle_signup_success, post_create_user}, + users, }; #[tokio::main] @@ -42,6 +43,10 @@ async fn main() { .route("/login", get(get_login).post(post_login)) .route("/logout", get(get_logout).post(post_logout)) .fallback(handle_slash_redir) + .route_layer(middleware::from_fn_with_state( + pool.clone(), + users::handle_update_last_seen, + )) .layer(auth_layer) .layer(session_layer) .with_state(pool); diff --git a/src/users.rs b/src/users.rs index baeaf2f..e54f04b 100644 --- a/src/users.rs +++ b/src/users.rs @@ -1,11 +1,15 @@ use std::fmt::Display; +use axum::{extract::State, http::Request, middleware::Next, response::IntoResponse}; use axum_login::{secrecy::SecretVec, AuthUser}; use serde::{Deserialize, Serialize}; use sqlx::SqlitePool; use uuid::Uuid; +use crate::AuthContext; + const USERNAME_QUERY: &str = "select * from witches where username = $1"; +const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1"; #[derive(Debug, Default, Clone, PartialEq, Eq, sqlx::FromRow, Serialize, Deserialize)] pub struct User { @@ -41,10 +45,40 @@ impl AuthUser for User { } impl User { - pub async fn get(username: &str, db: &SqlitePool) -> Result { + pub async fn try_get(username: &str, db: &SqlitePool) -> Result { sqlx::query_as(USERNAME_QUERY) .bind(username) .fetch_one(db) .await } + + pub async fn update_last_seen(&self, pool: &SqlitePool) { + match sqlx::query(LAST_SEEN_QUERY) + .bind(self.id) + .execute(pool) + .await + { + Ok(_) => {} + Err(e) => { + let id = self.id.as_simple(); + tracing::error!("Could not update last_seen for user {id}; got {e:?}"); + } + } + } +} + +//-************************************************************************ +// User-specific middleware +//-************************************************************************ + +pub async fn handle_update_last_seen( + State(pool): State, + auth: AuthContext, + request: Request, + next: Next, +) -> impl IntoResponse { + if let Some(user) = auth.current_user { + user.update_last_seen(&pool).await; + } + next.run(request).await }