Adds working login route.

This commit is contained in:
Joe Ardent 2023-05-28 17:55:16 -07:00
parent 559e277d9e
commit dbff72330e
8 changed files with 154 additions and 54 deletions

View file

@ -1,7 +1,15 @@
use axum::response::{IntoResponse, Redirect}; use axum::response::{IntoResponse, Redirect};
use crate::AuthContext;
pub async fn handle_slash_redir() -> impl IntoResponse { pub async fn handle_slash_redir() -> impl IntoResponse {
Redirect::temporary("/") Redirect::temporary("/")
} }
pub async fn handle_slash() -> impl IntoResponse {} pub async fn handle_slash(auth: AuthContext) -> impl IntoResponse {
if let Some(user) = auth.current_user {
tracing::debug!("Logged in as: {user}");
} else {
tracing::debug!("Not logged in.")
}
}

View file

@ -1,11 +1,17 @@
#[macro_use] #[macro_use]
extern crate justerror; extern crate justerror;
use axum_login::SqliteStore;
pub use users::User;
use uuid::Uuid;
pub mod db; pub mod db;
pub mod generic_handlers; pub mod generic_handlers;
pub mod login; pub mod login;
pub mod session_store; pub mod session_store;
pub mod signup; pub mod signup;
pub(crate) mod templates; pub(crate) mod templates;
pub mod users;
pub(crate) mod util;
pub use signup::User; pub type AuthContext = axum_login::extractors::AuthContext<Uuid, User, SqliteStore<User>>;

View file

@ -1,26 +1,26 @@
use argon2::PasswordVerifier; use argon2::{
password_hash::{PasswordHash, PasswordVerifier},
Argon2,
};
use axum::{ use axum::{
extract::State, extract::State,
http::StatusCode, http::StatusCode,
response::{IntoResponse, Response}, response::{IntoResponse, Redirect, Response},
Form,
}; };
use axum_login::{secrecy::SecretVec, AuthUser, SqliteStore};
use sqlx::SqlitePool; use sqlx::SqlitePool;
use uuid::Uuid;
use crate::{templates::LoginGet, User}; use crate::{
templates::{LoginGet, LoginPost},
util::form_decode,
AuthContext, User,
};
pub type AuthContext = axum_login::extractors::AuthContext<Uuid, User, SqliteStore<User>>; //-************************************************************************
// Constants
//-************************************************************************
impl AuthUser<Uuid> for User { const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1";
fn get_id(&self) -> Uuid {
self.id
}
fn get_password_hash(&self) -> SecretVec<u8> {
SecretVec::new(self.pwhash.as_bytes().to_vec())
}
}
//-************************************************************************ //-************************************************************************
// Login error and success types // Login error and success types
@ -32,7 +32,9 @@ pub struct LoginError(#[from] LoginErrorKind);
#[Error] #[Error]
#[non_exhaustive] #[non_exhaustive]
pub enum LoginErrorKind { pub enum LoginErrorKind {
Internal,
BadPassword, BadPassword,
BadUsername,
Unknown, Unknown,
} }
@ -58,8 +60,38 @@ impl IntoResponse for LoginError {
pub async fn post_login( pub async fn post_login(
mut auth: AuthContext, mut auth: AuthContext,
State(pool): State<SqlitePool>, State(pool): State<SqlitePool>,
) -> Result<(), LoginError> { Form(login): Form<LoginPost>,
Err(LoginErrorKind::Unknown.into()) ) -> Result<impl IntoResponse, LoginError> {
let username = form_decode(&login.username, LoginErrorKind::BadUsername)?;
let username = username.trim();
let pw = form_decode(&login.password, LoginErrorKind::BadPassword)?;
let pw = pw.trim();
let user = User::get(username, &pool)
.await
.map_err(|_| LoginErrorKind::Unknown)?;
let verifier = Argon2::default();
let hash = PasswordHash::new(&user.pwhash).map_err(|_| LoginErrorKind::Internal)?;
match verifier.verify_password(pw.as_bytes(), &hash) {
Ok(_) => {
// log them in and set a session cookie
auth.login(&user)
.await
.map_err(|_| LoginErrorKind::Internal)?;
// update last_seen; maybe this is ok to fail?
sqlx::query(LAST_SEEN_QUERY)
.bind(user.id)
.execute(&pool)
.await
.map_err(|_| LoginErrorKind::Internal)?;
Ok(Redirect::temporary("/"))
}
_ => Err(LoginErrorKind::BadPassword.into()),
}
} }
pub async fn get_login() -> impl IntoResponse { pub async fn get_login() -> impl IntoResponse {

View file

@ -1,7 +1,7 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use axum::{routing::get, Router}; use axum::{routing::get, Router};
use axum_login::axum_sessions::SessionLayer; use axum_login::{axum_sessions::SessionLayer, AuthLayer, SqliteStore};
use rand_core::{OsRng, RngCore}; use rand_core::{OsRng, RngCore};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use witch_watch::{ use witch_watch::{
@ -10,6 +10,7 @@ use witch_watch::{
login::{get_login, post_login}, login::{get_login, post_login},
session_store::SqliteSessionStore, session_store::SqliteSessionStore,
signup::{get_create_user, handle_signup_success, post_create_user}, signup::{get_create_user, handle_signup_success, post_create_user},
User,
}; };
#[tokio::main] #[tokio::main]
@ -23,18 +24,26 @@ async fn main() {
.init(); .init();
let pool = db::get_pool().await; let pool = db::get_pool().await;
let session_layer = {
let store = SqliteSessionStore::from_client(pool.clone());
store.migrate().await.expect("Could not migrate session DB");
let secret = { let secret = {
let mut bytes = [0u8; 128]; let mut bytes = [0u8; 128];
let mut rng = OsRng; let mut rng = OsRng;
rng.fill_bytes(&mut bytes); rng.fill_bytes(&mut bytes);
bytes bytes
}; };
let session_layer = {
let store = SqliteSessionStore::from_client(pool.clone());
store.migrate().await.expect("Could not migrate session DB");
SessionLayer::new(store, &secret).with_secure(true) SessionLayer::new(store, &secret).with_secure(true)
}; };
let auth_layer = {
const QUERY: &str = "select * from witches where id = $1";
let store = SqliteStore::<User>::new(pool.clone()).with_query(QUERY);
AuthLayer::new(store, &secret)
};
let app = Router::new() let app = Router::new()
.route("/", get(handle_slash).post(handle_slash)) .route("/", get(handle_slash).post(handle_slash))
.route("/signup", get(get_create_user).post(post_create_user)) .route("/signup", get(get_create_user).post(post_create_user))
@ -44,6 +53,7 @@ async fn main() {
) )
.route("/login", get(get_login).post(post_login)) .route("/login", get(get_login).post(post_login))
.fallback(handle_slash_redir) .fallback(handle_slash_redir)
.layer(auth_layer)
.layer(session_layer) .layer(session_layer)
.with_state(pool); .with_state(pool);

View file

@ -1,5 +1,3 @@
use std::fmt::Display;
use argon2::{ use argon2::{
password_hash::{rand_core::OsRng, PasswordHasher, SaltString}, password_hash::{rand_core::OsRng, PasswordHasher, SaltString},
Argon2, Argon2,
@ -14,35 +12,12 @@ use sqlx::{query_as, SqlitePool};
use unicode_segmentation::UnicodeSegmentation; use unicode_segmentation::UnicodeSegmentation;
use uuid::Uuid; use uuid::Uuid;
use crate::templates::CreateUser; use crate::{templates::CreateUser, User};
const CREATE_QUERY: &str = const CREATE_QUERY: &str =
"insert into witches (id, username, displayname, email, pwhash) values ($1, $2, $3, $4, $5)"; "insert into witches (id, username, displayname, email, pwhash) values ($1, $2, $3, $4, $5)";
const ID_QUERY: &str = "select * from witches where id = $1"; const ID_QUERY: &str = "select * from witches where id = $1";
#[derive(Debug, Default, Clone, PartialEq, Eq, sqlx::FromRow)]
pub struct User {
pub id: Uuid,
pub username: String,
pub displayname: Option<String>,
pub email: Option<String>,
pub last_seen: Option<i64>,
pub(crate) pwhash: String,
}
impl Display for User {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let uname = &self.username;
let dname = if let Some(ref n) = self.displayname {
n
} else {
""
};
let email = if let Some(ref e) = self.email { e } else { "" };
write!(f, "Username: {uname}\nDisplayname: {dname}\nEmail: {email}")
}
}
//-************************************************************************ //-************************************************************************
// Result types for user creation // Result types for user creation
//-************************************************************************ //-************************************************************************
@ -97,7 +72,7 @@ pub async fn get_create_user() -> CreateUser {
pub async fn post_create_user( pub async fn post_create_user(
State(pool): State<SqlitePool>, State(pool): State<SqlitePool>,
Form(signup): Form<CreateUser>, Form(signup): Form<CreateUser>,
) -> Result<Response, CreateUserError> { ) -> Result<impl IntoResponse, CreateUserError> {
let username = &signup.username; let username = &signup.username;
let displayname = &signup.displayname; let displayname = &signup.displayname;
let email = &signup.email; let email = &signup.email;
@ -155,7 +130,7 @@ pub async fn post_create_user(
let id = user.id.as_simple().to_string(); let id = user.id.as_simple().to_string();
let location = format!("/signup_success/{id}"); let location = format!("/signup_success/{id}");
let resp = axum::response::Redirect::temporary(&location).into_response(); let resp = axum::response::Redirect::temporary(&location);
Ok(resp) Ok(resp)
} }

49
src/users.rs Normal file
View file

@ -0,0 +1,49 @@
use std::fmt::Display;
use axum_login::{secrecy::SecretVec, AuthUser};
use sqlx::SqlitePool;
use uuid::Uuid;
const USERNAME_QUERY: &str = "select * from witches where username = $1";
#[derive(Debug, Default, Clone, PartialEq, Eq, sqlx::FromRow)]
pub struct User {
pub id: Uuid,
pub username: String,
pub displayname: Option<String>,
pub email: Option<String>,
pub last_seen: Option<i64>,
pub(crate) pwhash: String,
}
impl Display for User {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let uname = &self.username;
let dname = if let Some(ref n) = self.displayname {
n
} else {
""
};
let email = if let Some(ref e) = self.email { e } else { "" };
write!(f, "Username: {uname}\nDisplayname: {dname}\nEmail: {email}")
}
}
impl AuthUser<Uuid> for User {
fn get_id(&self) -> Uuid {
self.id
}
fn get_password_hash(&self) -> SecretVec<u8> {
SecretVec::new(self.pwhash.as_bytes().to_vec())
}
}
impl User {
pub async fn get(username: &str, db: &SqlitePool) -> Result<User, impl std::error::Error> {
sqlx::query_as(USERNAME_QUERY)
.bind(username)
.fetch_one(db)
.await
}
}

3
src/util.rs Normal file
View file

@ -0,0 +1,3 @@
pub fn form_decode<E: std::error::Error>(input: &str, err: E) -> Result<String, E> {
Ok(urlencoding::decode(input).map_err(|_| err)?.into_owned())
}

View file

@ -0,0 +1,17 @@
{% extends "base.html" %}
{% block title %}Login to Witch Watch, Bish{% endblock %}
{% block content %}
<p>
<form action="/login" enctype="application/x-www-form-urlencoded" method="post">
<label for="username">Username</label>
<input type="text" name="username" id="username" minlength="1" maxlength="20" required></br>
<label for="password">Password</label>
<input type="password" name="password" id="password" required></br>
<input type="submit" value="Signup">
</form>
</p>
{% endblock %}