make valid ranges into constants
This commit is contained in:
parent
c7bad9350a
commit
1daec80430
2 changed files with 21 additions and 12 deletions
|
@ -1,4 +1,4 @@
|
|||
use std::{error::Error, fmt::Debug, ops::Range};
|
||||
use std::{error::Error, fmt::Debug, ops::RangeInclusive};
|
||||
|
||||
use axum::{
|
||||
extract::{Form, Path},
|
||||
|
@ -13,6 +13,11 @@ use crate::{templates::*, User};
|
|||
|
||||
const SIGNUP_KEY: &str = "meow";
|
||||
|
||||
const PASSWORD_LEN: RangeInclusive<usize> = 4..=100;
|
||||
const USERNAME_LEN: RangeInclusive<usize> = 1..=50;
|
||||
const DISPLAYNAME_LEN: RangeInclusive<usize> = 0..=100;
|
||||
const EMAIL_LEN: RangeInclusive<usize> = 4..=50;
|
||||
|
||||
#[Error(desc = "Could not create user.")]
|
||||
#[non_exhaustive]
|
||||
pub struct CreateUserError(#[from] CreateUserErrorKind);
|
||||
|
@ -27,10 +32,10 @@ impl IntoResponse for CreateUserError {
|
|||
#[non_exhaustive]
|
||||
pub enum CreateUserErrorKind {
|
||||
AlreadyExists,
|
||||
#[error(desc = "Usernames must be between 1 and 20 characters long")]
|
||||
#[error(desc = "Usernames must be between 1 and 50 characters long")]
|
||||
BadUsername,
|
||||
PasswordMismatch,
|
||||
#[error(desc = "Password must have at least 4 and at most 50 characters")]
|
||||
#[error(desc = "Password must have at least 4 and at most 100 characters")]
|
||||
BadPassword,
|
||||
#[error(desc = "Display name must be less than 100 characters long")]
|
||||
BadDisplayname,
|
||||
|
@ -60,7 +65,7 @@ pub async fn post_signup(
|
|||
session: Session,
|
||||
Form(form): Form<SignupForm>,
|
||||
) -> Result<impl IntoResponse, CreateUserError> {
|
||||
let user = verify_user(&form).await?;
|
||||
let user = validate_signup(&form).await?;
|
||||
session.insert(SIGNUP_KEY, user).await.unwrap();
|
||||
|
||||
Ok(Redirect::to(
|
||||
|
@ -96,14 +101,14 @@ pub async fn signup_success(session: Session, receipt: Option<Path<String>>) ->
|
|||
//-************************************************************************
|
||||
// helpers
|
||||
//-************************************************************************
|
||||
async fn verify_user(form: &SignupForm) -> Result<User, CreateUserError> {
|
||||
async fn validate_signup(form: &SignupForm) -> Result<User, CreateUserError> {
|
||||
let username = form.username.trim();
|
||||
let password = form.password.trim();
|
||||
let verify = form.pw_verify.trim();
|
||||
|
||||
let name_len = username.graphemes(true).size_hint().1.unwrap();
|
||||
// we are not ascii exclusivists around here
|
||||
if !(1..=20).contains(&name_len) {
|
||||
if !USERNAME_LEN.contains(&name_len) {
|
||||
return Err(CreateUserErrorKind::BadUsername.into());
|
||||
}
|
||||
|
||||
|
@ -111,18 +116,18 @@ async fn verify_user(form: &SignupForm) -> Result<User, CreateUserError> {
|
|||
return Err(CreateUserErrorKind::PasswordMismatch.into());
|
||||
}
|
||||
let pwlen = password.graphemes(true).size_hint().1.unwrap_or(0);
|
||||
if !(4..=50).contains(&pwlen) {
|
||||
if !PASSWORD_LEN.contains(&pwlen) {
|
||||
return Err(CreateUserErrorKind::BadPassword.into());
|
||||
}
|
||||
|
||||
// clean up the optionals
|
||||
let displayname = validate_optional_length(
|
||||
&form.displayname,
|
||||
0..100,
|
||||
DISPLAYNAME_LEN,
|
||||
CreateUserErrorKind::BadDisplayname,
|
||||
)?;
|
||||
|
||||
let email = validate_length(&form.email, 5..30, CreateUserErrorKind::BadEmail)?;
|
||||
let email = validate_length(&form.email, EMAIL_LEN, CreateUserErrorKind::BadEmail)?;
|
||||
|
||||
let user = User {
|
||||
username: username.to_string(),
|
||||
|
@ -152,7 +157,7 @@ where
|
|||
|
||||
pub(crate) fn validate_optional_length<E: Error>(
|
||||
opt: &Option<String>,
|
||||
len_range: Range<usize>,
|
||||
len_range: RangeInclusive<usize>,
|
||||
err: E,
|
||||
) -> Result<Option<String>, E> {
|
||||
if let Some(opt) = opt {
|
||||
|
@ -170,7 +175,7 @@ pub(crate) fn validate_optional_length<E: Error>(
|
|||
|
||||
pub(crate) fn validate_length<E: Error>(
|
||||
thing: &str,
|
||||
len_range: Range<usize>,
|
||||
len_range: RangeInclusive<usize>,
|
||||
err: E,
|
||||
) -> Result<String, E> {
|
||||
let thing = thing.trim();
|
||||
|
|
|
@ -5,6 +5,7 @@ use std::{
|
|||
|
||||
use axum::{routing::get, Router};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower_http::services::ServeDir;
|
||||
use tower_sessions::{MemoryStore, SessionManagerLayer};
|
||||
|
||||
#[macro_use]
|
||||
|
@ -20,8 +21,11 @@ async fn main() {
|
|||
let session_store = MemoryStore::default();
|
||||
let session_layer = SessionManagerLayer::new(session_store).with_secure(false);
|
||||
|
||||
let assets_dir = std::env::current_dir().unwrap().join("assets");
|
||||
let assets_svc = ServeDir::new(assets_dir.as_path());
|
||||
|
||||
let app = Router::new()
|
||||
//.nest_service("/assets", assets_svc)
|
||||
.nest_service("/assets", assets_svc)
|
||||
.route("/signup", get(get_signup).post(post_signup))
|
||||
.route("/signup_success/:receipt", get(signup_success))
|
||||
.layer(session_layer);
|
||||
|
|
Loading…
Reference in a new issue