From ba3e8625f6bdad77b77f38927082d4e5566d5ef5 Mon Sep 17 00:00:00 2001 From: Joe Ardent Date: Mon, 15 Jan 2024 13:10:13 -0800 Subject: [PATCH] add more invitation tests --- src/signup/handlers.rs | 26 +++++++------- src/signup/mod.rs | 79 +++++++++++++++++++++++++++++++----------- 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/src/signup/handlers.rs b/src/signup/handlers.rs index ec5c866..c0ba1e9 100644 --- a/src/signup/handlers.rs +++ b/src/signup/handlers.rs @@ -15,8 +15,6 @@ use unicode_segmentation::UnicodeSegmentation; use super::{templates::*, Invitation}; use crate::{util::empty_string_as_none, User}; -const ID_QUERY: &str = "select * from users where id = $1"; - //-************************************************************************ // Error types for user creation //-************************************************************************ @@ -147,6 +145,7 @@ pub async fn get_signup_success( Path(id): Path, State(pool): State, ) -> Response { + const ID_QUERY: &str = "select * from users where id = ?"; let id = id.trim(); let id = Julid::from_str(id).unwrap_or_default(); let user: User = { @@ -250,6 +249,14 @@ async fn validate_invitation( if remaining < 1 { return Err(CreateUserErrorKind::BadInvitation); } + + if let Some(ts) = invitation.expires_at { + let now = chrono::Utc::now().timestamp(); + if ts < now { + return Err(CreateUserErrorKind::BadInvitation); + } + } + let _ = sqlx::query("update invites set remaining = ? where id = ?") .bind(remaining - 1) .bind(invitation.id) @@ -260,13 +267,6 @@ async fn validate_invitation( CreateUserErrorKind::UnknownDBError })?; - if let Some(ts) = invitation.expires_at { - let now = chrono::Utc::now().timestamp(); - if ts < now { - return Err(CreateUserErrorKind::BadInvitation); - } - } - Ok(invitation.owner) } @@ -383,7 +383,7 @@ mod test { fn used_up_invite() { let lucky1 = "username=lucky1&password=aaaa&pw_verify=aaaa&invitation=0000000000000000000000001A"; let lucky2 = "username=lucky2&password=aaaa&pw_verify=aaaa&invitation=0000000000000000000000001A"; - let unlucky = "username=lucky3&password=aaaa&pw_verify=aaaa&invitation=0000000000000000000000001A"; + let unlucky = "username=unlucky&password=aaaa&pw_verify=aaaa&invitation=0000000000000000000000001A"; let pool = get_db_pool(); let rt = Runtime::new().unwrap(); rt.block_on(async { @@ -444,8 +444,10 @@ mod test { std::thread::sleep(Duration::from_millis(100)); + let username = "too slow"; + let tooslow = - format!("username=tooslow&password=aaaa&pw_verify=aaaa&invitation={invite}"); + format!("username={username}&password=aaaa&pw_verify=aaaa&invitation={invite}"); let body = massage(&tooslow); let resp = server @@ -455,7 +457,7 @@ mod test { .bytes(body) .content_type(FORM_CONTENT_TYPE) .await; - let user = User::try_get("unlucky", &pool).await; + let user = User::try_get(username, &pool).await; assert!(user.is_ok() && user.unwrap().is_none()); let body = String::from_utf8(resp.as_bytes().to_vec()).unwrap(); diff --git a/src/signup/mod.rs b/src/signup/mod.rs index 64ef5d8..255ba74 100644 --- a/src/signup/mod.rs +++ b/src/signup/mod.rs @@ -9,17 +9,19 @@ pub mod templates; #[Error(desc = "Could not create user.")] #[non_exhaustive] +#[derive(PartialEq, Eq)] pub struct CreateInviteError(#[from] CreateInviteErrorKind); #[Error] #[non_exhaustive] +#[derive(PartialEq, Eq)] pub enum CreateInviteErrorKind { DBError, NoOwner, Unknown, } -#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)] pub struct Invitation { id: Julid, owner: Julid, @@ -27,6 +29,17 @@ pub struct Invitation { remaining: i16, } +impl Default for Invitation { + fn default() -> Self { + Self { + id: 0.into(), + owner: 0.into(), + expires_at: None, + remaining: 1, + } + } +} + impl Invitation { pub async fn commit(&self, db: &SqlitePool) -> Result { sqlx::query_scalar( @@ -39,46 +52,70 @@ impl Invitation { .await .map_err(|e| { tracing::debug!("Got error creating invite: {e}"); - match e { - sqlx::Error::Database(dbe) => { - let exit = dbe.code().unwrap_or_default().parse().unwrap_or(0); - // https://www.sqlite.org/rescode.html#constraint_foreignkey - if exit == 787u32 { - CreateInviteErrorKind::NoOwner - } else { - CreateInviteErrorKind::DBError - } + if let sqlx::Error::Database(e) = e { + let exit = e.code().unwrap_or_default().parse().unwrap_or(0); + // https://www.sqlite.org/rescode.html#constraint_foreignkey + if exit == 787u32 { + CreateInviteErrorKind::NoOwner + } else { + CreateInviteErrorKind::DBError } - _ => CreateInviteErrorKind::Unknown, + } else { + CreateInviteErrorKind::Unknown } })? - .ok_or(CreateInviteErrorKind::DBError.into()) + .ok_or(CreateInviteErrorKind::Unknown.into()) } pub fn new(owner: Julid) -> Self { Self { - id: Julid::alpha(), // stand-in value, will let the db fill it in owner, - expires_at: None, - remaining: 1, + ..Default::default() } } pub fn with_uses(&self, uses: u8) -> Self { Self { - id: self.id, - owner: self.owner, - expires_at: self.expires_at, remaining: uses as i16, + ..*self } } pub fn with_expires_in(&self, expires_in: Duration) -> Self { Self { - id: self.id, - owner: self.owner, expires_at: Some((chrono::Utc::now() + expires_in).timestamp()), - remaining: self.remaining, + ..*self } } } + +#[cfg(test)] +mod test { + use tokio::runtime::Runtime; + + use super::*; + use crate::{get_db_pool, User}; + + #[test] + fn can_create() { + let pool = get_db_pool(); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + User::omega().try_insert(&pool).await.unwrap(); + let invite = Invitation::new(Julid::omega()); + invite.commit(&pool).await.unwrap(); + }); + } + + #[test] + fn bad_owner() { + let pool = get_db_pool(); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + User::omega().try_insert(&pool).await.unwrap(); + let invite = Invitation::new(Julid::alpha()); + let res = invite.commit(&pool).await; + assert_eq!(res, Err(CreateInviteErrorKind::NoOwner.into())); + }); + } +}