From 8b1eef17f77fba1323f306741727b5c356c81414 Mon Sep 17 00:00:00 2001 From: Joe Ardent Date: Sat, 10 Jun 2023 15:30:36 -0700 Subject: [PATCH] Simplify form input handling. The form input values were coming to the handlers already-urldecoded, so no need to re-decode them. --- Cargo.lock | 7 ------- Cargo.toml | 1 - src/db.rs | 4 ++++ src/login.rs | 7 +++---- src/signup.rs | 49 ++++++++++++------------------------------------- src/users.rs | 17 +++++++++++++++-- src/util.rs | 19 +++++++++++++++++-- 7 files changed, 51 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6acaf5d..5d15bd0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2070,12 +2070,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "urlencoding" -version = "2.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" - [[package]] name = "uuid" version = "1.3.3" @@ -2390,7 +2384,6 @@ dependencies = [ "tracing", "tracing-subscriber", "unicode-segmentation", - "urlencoding", "uuid", ] diff --git a/Cargo.toml b/Cargo.toml index 800803f..27e5949 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,6 @@ justerror = "1" password-hash = { version = "0.5", features = ["std", "getrandom"] } axum-login = { version = "0.5", features = ["sqlite", "sqlx"] } unicode-segmentation = "1" -urlencoding = "2" async-session = "3" [dev-dependencies] diff --git a/src/db.rs b/src/db.rs index 7e9d39e..6129518 100644 --- a/src/db.rs +++ b/src/db.rs @@ -60,6 +60,10 @@ pub async fn get_db_pool() -> SqlitePool { .await .expect("can't connect to database"); + // let the filesystem settle before trying anything + // possibly not effective? + tokio::time::sleep(Duration::from_millis(500)).await; + { let mut m = Migrator::new(std::path::Path::new("./migrations")) .await diff --git a/src/login.rs b/src/login.rs index 5467de1..a2dea90 100644 --- a/src/login.rs +++ b/src/login.rs @@ -10,7 +10,7 @@ use axum::{ }; use sqlx::SqlitePool; -use crate::{util::form_decode, AuthContext, LoginGet, LoginPost, LogoutGet, LogoutPost, User}; +use crate::{AuthContext, LoginGet, LoginPost, LogoutGet, LogoutPost, User}; //-************************************************************************ // Constants @@ -28,7 +28,6 @@ pub struct LoginError(#[from] LoginErrorKind); pub enum LoginErrorKind { Internal, BadPassword, - BadUsername, Unknown, } @@ -57,10 +56,10 @@ pub async fn post_login( State(pool): State, Form(login): Form, ) -> Result { - let username = form_decode(&login.username, LoginErrorKind::BadUsername)?; + let username = &login.username; let username = username.trim(); - let pw = form_decode(&login.password, LoginErrorKind::BadPassword)?; + let pw = &login.password; let pw = pw.trim(); let user = User::try_get(username, &pool) diff --git a/src/signup.rs b/src/signup.rs index fb1e081..171a643 100644 --- a/src/signup.rs +++ b/src/signup.rs @@ -67,16 +67,11 @@ pub async fn post_create_user( State(pool): State, Form(signup): Form, ) -> Result { - let username = &signup.username; - let displayname = &signup.displayname; - let email = &signup.email; - let password = &signup.password; - let verify = &signup.pw_verify; + use crate::util::validate_optional_length; + let username = signup.username.trim(); + let password = signup.password.trim(); + let verify = signup.pw_verify.trim(); - let username = urlencoding::decode(username) - .map_err(|_| CreateUserErrorKind::BadUsername)? - .to_string(); - let username = username.trim(); let name_len = username.graphemes(true).size_hint().1.unwrap(); // we are not ascii exclusivists around here if !(1..=20).contains(&name_len) { @@ -87,42 +82,22 @@ pub async fn post_create_user( return Err(CreateUserErrorKind::PasswordMismatch.into()); } - let password = urlencoding::decode(password) - .map_err(|_| CreateUserErrorKind::BadPassword)? - .to_string(); let password = password.trim(); let password = password.as_bytes(); if !(4..=50).contains(&password.len()) { return Err(CreateUserErrorKind::BadPassword.into()); } - let displayname = if let Some(dn) = displayname { - let dn = urlencoding::decode(dn) - .map_err(|_| CreateUserErrorKind::BadDisplayname)? - .to_string() - .trim() - .to_string(); - if dn.graphemes(true).size_hint().1.unwrap() > 100 { - return Err(CreateUserErrorKind::BadDisplayname.into()); - } - Some(dn) - } else { - None - }; - let displayname = &displayname; + // clean up the optionals + let displayname = validate_optional_length( + &signup.displayname, + 0..100, + CreateUserErrorKind::BadDisplayname, + )?; - // TODO(2023-05-17): validate email - let email = if let Some(email) = email { - let email = urlencoding::decode(email) - .map_err(|_| CreateUserErrorKind::BadEmail)? - .to_string(); - Some(email) - } else { - None - }; - let email = &email; + let email = validate_optional_length(&signup.email, 5..30, CreateUserErrorKind::BadEmail)?; - let user = create_user(username, displayname, email, password, &pool).await?; + let user = create_user(username, &displayname, &email, password, &pool).await?; tracing::debug!("created {user:?}"); let id = user.id.as_simple().to_string(); let location = format!("/signup_success/{id}"); diff --git a/src/users.rs b/src/users.rs index 26f2dfd..433130c 100644 --- a/src/users.rs +++ b/src/users.rs @@ -1,5 +1,5 @@ use std::{ - fmt::Display, + fmt::{Debug, Display}, time::{SystemTime, UNIX_EPOCH}, }; @@ -14,7 +14,7 @@ use crate::AuthContext; const USERNAME_QUERY: &str = "select * from witches where username = $1"; const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1"; -#[derive(Debug, Default, Clone, PartialEq, Eq, sqlx::FromRow, Serialize, Deserialize)] +#[derive(Default, Clone, PartialEq, Eq, sqlx::FromRow, Serialize, Deserialize)] pub struct User { pub id: Uuid, pub username: String, @@ -24,6 +24,19 @@ pub struct User { pub(crate) pwhash: String, } +impl Debug for User { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("User") + .field("id", &self.id) + .field("username", &self.username) + .field("displayname", &self.displayname) + .field("email", &self.email) + .field("last_seen", &self.last_seen) + .field("pwhash", &"") + .finish() + } +} + impl Display for User { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let uname = &self.username; diff --git a/src/util.rs b/src/util.rs index 1b15b86..466660c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,3 +1,18 @@ -pub fn form_decode(input: &str, err: E) -> Result { - Ok(urlencoding::decode(input).map_err(|_| err)?.into_owned()) +use std::{error::Error, ops::Range}; + +pub fn validate_optional_length( + opt: &Option, + len_range: Range, + err: E, +) -> Result, E> { + if let Some(opt) = opt { + let opt = opt.trim(); + if !len_range.contains(&opt.len()) { + Err(err) + } else { + Ok(Some(opt.to_string())) + } + } else { + Ok(None) + } }