diff --git a/Cargo.lock b/Cargo.lock index 7b83a10..cb98297 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1522,6 +1522,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_test" +version = "1.0.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "797c38160e2546a56e1e3439496439597e938669673ffd8af02a12f070da648f" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1693,7 +1702,6 @@ dependencies = [ "time", "tokio-stream", "url", - "uuid", "webpki-roots", ] @@ -2105,12 +2113,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "uuid" -version = "1.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81" - [[package]] name = "valuable" version = "0.1.0" @@ -2341,6 +2343,7 @@ dependencies = [ "password-hash", "rand_core", "serde", + "serde_test", "sqlx", "thiserror", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 954e241..7ca6f86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } tower = { version = "0.4", features = ["util", "timeout"], default-features = false } tower-http = { version = "0.4", features = ["add-extension", "trace"] } serde = { version = "1", features = ["derive"] } -sqlx = { version = "0.6", default-features = false, features = ["runtime-tokio-rustls", "any", "sqlite", "chrono", "time", "uuid"] } +sqlx = { version = "0.6", default-features = false, features = ["runtime-tokio-rustls", "any", "sqlite", "chrono", "time"] } argon2 = "0.5" rand_core = { version = "0.6", features = ["getrandom"] } thiserror = "1" @@ -30,4 +30,5 @@ optional_optional_user = {path = "optional_optional_user"} [dev-dependencies] axum-test = "9.0.0" +serde_test = "1.0.164" diff --git a/src/db.rs b/src/db.rs index 4cbe8ae..70540b6 100644 --- a/src/db.rs +++ b/src/db.rs @@ -12,7 +12,7 @@ use sqlx::{ SqlitePool, }; -use crate::{ids::DbId, User}; +use crate::{db_id::DbId, User}; const MAX_CONNS: u32 = 200; const MIN_CONNS: u32 = 5; diff --git a/src/db_id.rs b/src/db_id.rs new file mode 100644 index 0000000..5b7e2f6 --- /dev/null +++ b/src/db_id.rs @@ -0,0 +1,251 @@ +use std::{ + borrow::Cow, + fmt::{Debug, Display}, +}; + +use serde::{de::Visitor, Deserialize, Serialize}; +use sqlx::{ + encode::IsNull, + sqlite::{SqliteArgumentValue, SqliteValueRef}, + Decode, Encode, Sqlite, +}; +use ulid::Ulid; + +#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct DbId(pub Ulid); + +impl DbId { + pub fn bytes(&self) -> [u8; 16] { + self.as_be_bytes() + } + + pub fn as_be_bytes(&self) -> [u8; 16] { + self.0 .0.to_be_bytes() + } + + pub fn is_nil(&self) -> bool { + self.0.is_nil() + } + + pub fn new() -> Self { + let id = Ulid::new(); + Self(id) + } + + pub fn from_str(s: &str) -> Result { + let id = Ulid::from_string(s)?; + Ok(id.into()) + } +} + +//-************************************************************************ +// standard trait impls +//-************************************************************************ + +impl Display for DbId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0.to_string()) + } +} + +impl Debug for DbId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("DbId").field(&self.bytes()).finish() + } +} + +impl From for DbId { + fn from(value: Ulid) -> Self { + DbId(value) + } +} + +impl From for DbId { + fn from(value: u128) -> Self { + DbId(value.into()) + } +} + +//-************************************************************************ +// sqlx traits for going in and out of the db +//-************************************************************************ + +impl sqlx::Type for DbId { + fn type_info() -> ::TypeInfo { + <&[u8] as sqlx::Type>::type_info() + } +} + +// sqlx traits for marshalling in and out +impl<'q> Encode<'q, Sqlite> for DbId { + fn encode(self, args: &mut Vec>) -> IsNull { + args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec()))); + IsNull::No + } + + fn encode_by_ref(&self, args: &mut Vec>) -> IsNull { + args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec()))); + IsNull::No + } + + fn produces(&self) -> Option<::TypeInfo> { + // `produces` is inherently a hook to allow database drivers to produce + // value-dependent type information; if the driver doesn't need this, it + // can leave this as `None` + None + } + + fn size_hint(&self) -> usize { + std::mem::size_of_val(self) + } +} + +impl<'r> Decode<'r, Sqlite> for DbId { + fn decode(value: SqliteValueRef<'r>) -> Result { + let bytes = <&[u8] as Decode>::decode(value)?; + let bytes: [u8; 16] = bytes.try_into().unwrap_or_default(); + let id: Ulid = u128::from_ne_bytes(bytes).into(); + Ok(id.into()) + } +} + +//-************************************************************************ +// serde traits +//-************************************************************************ +impl Serialize for DbId { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(&self.bytes()) + } +} + +struct DbIdVisitor; + +impl<'de> Visitor<'de> for DbIdVisitor { + type Value = DbId; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a 128-bit number") + } + + fn visit_i128(self, v: i128) -> Result + where + E: serde::de::Error, + { + Ok(DbId(Ulid(v as u128))) + } + + fn visit_u128(self, v: u128) -> Result + where + E: serde::de::Error, + { + Ok(DbId(Ulid(v))) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + match std::convert::TryInto::<[u8; 16]>::try_into(v) { + Ok(v) => Ok(DbId(Ulid(u128::from_be_bytes(v)))), + Err(_) => Err(serde::de::Error::invalid_length(v.len(), &self)), + } + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: serde::de::Error, + { + let len = v.len(); + match std::convert::TryInto::<[u8; 16]>::try_into(v) { + Ok(v) => Ok(DbId(Ulid(u128::from_be_bytes(v)))), + Err(_) => Err(serde::de::Error::invalid_length(len, &self)), + } + } + + fn visit_string(self, v: String) -> Result + where + E: serde::de::Error, + { + match Ulid::from_string(&v) { + Ok(v) => Ok(DbId(v)), + Err(_) => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(&format!("could not convert {v} to a ULID")), + &self, + )), + } + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + match Ulid::from_string(v) { + Ok(v) => Ok(DbId(v)), + Err(_) => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(&format!("could not convert {v} to a ULID")), + &self, + )), + } + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut bytes = [0u8; 16]; + let size = seq.size_hint().unwrap_or(0); + let mut count = 0; + while let Some(val) = seq.next_element()? { + if count >= 16 { + break; + } + bytes[count] = val; + count += 1; + } + if count != 16 || size > 16 { + let sz = if count < 16 { count } else { size }; + Err(serde::de::Error::invalid_length(sz, &self)) + } else { + let id = u128::from_ne_bytes(bytes); + Ok(id.into()) + } + } +} + +impl<'de> Deserialize<'de> for DbId { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_bytes(DbIdVisitor) + } +} + +//-************************************************************************ +// serialization tests +//-************************************************************************ + +#[cfg(test)] +mod test { + use serde_test::{assert_tokens, Token}; + + use super::*; + + #[test] + fn test_ser_de() { + let bytes: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; + let be_num = u128::from_be_bytes(bytes); + let le_ulid = Ulid(be_num); + let le_id = DbId(le_ulid); + + assert_tokens( + &le_id, + &[Token::Bytes(&[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + ])], + ); + } +} diff --git a/src/ids.rs b/src/ids.rs deleted file mode 100644 index b8d4912..0000000 --- a/src/ids.rs +++ /dev/null @@ -1,143 +0,0 @@ -use std::{borrow::Cow, fmt::Display}; - -use serde::{de::Visitor, Deserialize, Serialize}; -use sqlx::{ - encode::IsNull, - sqlite::{SqliteArgumentValue, SqliteValueRef}, - Decode, Encode, Sqlite, -}; -use ulid::Ulid; - -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct DbId(pub Ulid); - -impl From for DbId { - fn from(value: Ulid) -> Self { - DbId(value) - } -} - -impl sqlx::Type for DbId { - fn type_info() -> ::TypeInfo { - <&[u8] as sqlx::Type>::type_info() - } -} - -impl<'q> Encode<'q, Sqlite> for DbId { - fn encode(self, args: &mut Vec>) -> IsNull { - args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec()))); - IsNull::No - } - - fn encode_by_ref(&self, args: &mut Vec>) -> IsNull { - args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec()))); - - IsNull::No - } - - fn produces(&self) -> Option<::TypeInfo> { - // `produces` is inherently a hook to allow database drivers to produce - // value-dependent type information; if the driver doesn't need this, it - // can leave this as `None` - None - } - - fn size_hint(&self) -> usize { - std::mem::size_of_val(self) - } -} - -impl<'r> Decode<'r, Sqlite> for DbId { - fn decode(value: SqliteValueRef<'r>) -> Result { - let bytes = <&[u8] as Decode>::decode(value)?; - let bytes: [u8; 16] = bytes.try_into().unwrap_or_default(); - let id: Ulid = u128::from_ne_bytes(bytes).into(); - Ok(id.into()) - } -} - -impl Display for DbId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0.to_string()) - } -} - -impl DbId { - pub fn bytes(&self) -> [u8; 16] { - self.0 .0.to_ne_bytes() - } - - pub fn is_nil(&self) -> bool { - self.0.is_nil() - } - - pub fn new() -> Self { - let id = Ulid::new(); - Self(id) - } -} - -impl Serialize for DbId { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_bytes(&self.0 .0.to_ne_bytes()) - } -} - -struct IdVisitor; - -impl<'de> Visitor<'de> for IdVisitor { - type Value = DbId; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a 128-bit number") - } - - fn visit_i128(self, v: i128) -> Result - where - E: serde::de::Error, - { - Ok(DbId(Ulid(v as u128))) - } - - fn visit_u128(self, v: u128) -> Result - where - E: serde::de::Error, - { - Ok(DbId(Ulid(v))) - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: serde::de::Error, - { - match std::convert::TryInto::<[u8; 16]>::try_into(v) { - Ok(v) => Ok(DbId(Ulid(u128::from_ne_bytes(v)))), - Err(_) => Err(serde::de::Error::invalid_length(v.len(), &self)), - } - } - - fn visit_string(self, v: String) -> Result - where - E: serde::de::Error, - { - match Ulid::from_string(&v) { - Ok(v) => Ok(DbId(v)), - Err(_) => Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Str(&format!("could not convert {v} to a ULID")), - &self, - )), - } - } -} - -impl<'de> Deserialize<'de> for DbId { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_bytes(IdVisitor) - } -} diff --git a/src/lib.rs b/src/lib.rs index ce5fb1c..35a3d5b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,8 +9,8 @@ pub use db::get_db_pool; // everything else is private to the crate mod db; +mod db_id; mod generic_handlers; -mod ids; mod login; mod signup; mod templates; @@ -19,7 +19,7 @@ mod util; mod watches; // things we want in the crate namespace -use ids::DbId; +use db_id::DbId; use optional_optional_user::OptionalOptionalUser; use templates::*; use users::User; diff --git a/src/login.rs b/src/login.rs index 33e9567..aac8edf 100644 --- a/src/login.rs +++ b/src/login.rs @@ -206,8 +206,8 @@ mod test { } .to_string(); - let idx = s.get("/").await; - let body = std::str::from_utf8(idx.bytes()).unwrap(); + let main_page = s.get("/").await; + let body = std::str::from_utf8(main_page.bytes()).unwrap(); assert_eq!(&logged_in, body); } diff --git a/src/signup.rs b/src/signup.rs index 40c879c..28005e8 100644 --- a/src/signup.rs +++ b/src/signup.rs @@ -7,11 +7,11 @@ use axum::{ http::StatusCode, response::{IntoResponse, Response}, }; +use serde::Deserialize; use sqlx::{query_as, SqlitePool}; -use ulid::Ulid; use unicode_segmentation::UnicodeSegmentation; -use crate::{CreateUser, CreateUserSuccess, DbId, User}; +use crate::{util::empty_string_as_none, CreateUserSuccess, DbId, SignupPage, User}; pub(crate) const CREATE_QUERY: &str = "insert into witches (id, username, displayname, email, pwhash) values ($1, $2, $3, $4, $5)"; @@ -51,13 +51,24 @@ pub enum CreateUserErrorKind { UnknownDBError, } +#[derive(Debug, Default, Deserialize, PartialEq, Eq)] +pub struct SignupForm { + pub username: String, + #[serde(default, deserialize_with = "empty_string_as_none")] + pub displayname: Option, + #[serde(default, deserialize_with = "empty_string_as_none")] + pub email: Option, + pub password: String, + pub pw_verify: String, +} + //-************************************************************************ // User creation route handlers //-************************************************************************ /// Get Handler: displays the form to create a user -pub async fn get_create_user() -> CreateUser { - CreateUser::default() +pub async fn get_create_user() -> SignupPage { + SignupPage::default() } /// Post Handler: validates form values and calls the actual, private user @@ -65,7 +76,7 @@ pub async fn get_create_user() -> CreateUser { #[axum::debug_handler] pub async fn post_create_user( State(pool): State, - Form(signup): Form, + Form(signup): Form, ) -> Result { use crate::util::validate_optional_length; let username = signup.username.trim(); @@ -114,10 +125,10 @@ pub async fn get_signup_success( State(pool): State, ) -> Response { let id = id.trim(); + let id = DbId::from_str(id).unwrap_or_default(); let user: User = { - let id: DbId = Ulid::from_string(id).unwrap_or_default().into(); query_as(ID_QUERY) - .bind(id.bytes().as_slice()) + .bind(id) .fetch_one(&pool) .await .unwrap_or_default() @@ -125,7 +136,7 @@ pub async fn get_signup_success( let mut resp = CreateUserSuccess(user.clone()).into_response(); - if user.username.is_empty() || id.is_empty() { + if user.username.is_empty() || id.is_nil() { // redirect to front page if we got here without a valid witch ID *resp.status_mut() = StatusCode::SEE_OTHER; resp.headers_mut().insert("Location", "/".parse().unwrap()); @@ -154,9 +165,8 @@ pub(crate) async fn create_user( .unwrap() // safe to unwrap, we know the salt is valid .to_string(); - let bytes = &id.bytes(); let query = sqlx::query(CREATE_QUERY) - .bind(bytes.as_slice()) + .bind(id) .bind(username) .bind(displayname) .bind(email) @@ -203,7 +213,7 @@ mod test { use crate::{ db::get_db_pool, - templates::{CreateUser, CreateUserSuccess}, + templates::{CreateUserSuccess, SignupPage}, test_utils::{get_test_user, insert_user, massage, server_with_pool, FORM_CONTENT_TYPE}, User, }; @@ -237,7 +247,7 @@ mod test { let resp = server.get("/signup").await; let body = std::str::from_utf8(resp.bytes()).unwrap(); - let expected = CreateUser::default().to_string(); + let expected = SignupPage::default().to_string(); assert_eq!(&expected, body); } diff --git a/src/templates.rs b/src/templates.rs index 778d96d..a01d23f 100644 --- a/src/templates.rs +++ b/src/templates.rs @@ -5,7 +5,7 @@ use crate::{OptionalOptionalUser, User}; #[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq, OptionalOptionalUser)] #[template(path = "signup.html")] -pub struct CreateUser { +pub struct SignupPage { pub username: String, pub displayname: Option, pub email: Option, diff --git a/src/test_utils.rs b/src/test_utils.rs index e3657eb..079be4e 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -11,7 +11,7 @@ pub fn get_test_user() -> User { username: "test_user".to_string(), // corresponding to a password of "a": pwhash: "$argon2id$v=19$m=19456,t=2,p=1$GWsCH1w5RYaP9WWmq+xw0g$hmOEqC+MU+vnEk3bOdkoE+z01mOmmOeX08XyPyjqua8".to_string(), - id: DbId::default(), + id: DbId::from_str("00041061050R3GG28A1C60T3GF").unwrap(), displayname: Some("Test User".to_string()), ..Default::default() } diff --git a/src/users.rs b/src/users.rs index cb97cc9..c04d810 100644 --- a/src/users.rs +++ b/src/users.rs @@ -13,7 +13,7 @@ use crate::{AuthContext, DbId}; const USERNAME_QUERY: &str = "select * from witches where username = $1"; const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1"; -#[derive(Default, Clone, PartialEq, Eq, sqlx::FromRow, Serialize, Deserialize)] +#[derive(Default, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)] pub struct User { pub id: DbId, pub username: String, diff --git a/src/util.rs b/src/util.rs index 022c0f8..dd16117 100644 --- a/src/util.rs +++ b/src/util.rs @@ -19,3 +19,19 @@ pub fn validate_optional_length( Ok(None) } } + +/// Serde deserialization decorator to map empty Strings to None, +pub fn empty_string_as_none<'de, D, T>(de: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, + T: std::str::FromStr, + T::Err: std::fmt::Display, +{ + let opt = as serde::Deserialize>::deserialize(de)?; + match opt.as_deref() { + None | Some("") => Ok(None), + Some(s) => std::str::FromStr::from_str(s) + .map_err(serde::de::Error::custom) + .map(Some), + } +} diff --git a/src/watches/handlers.rs b/src/watches/handlers.rs index 38e7c37..43e9403 100644 --- a/src/watches/handlers.rs +++ b/src/watches/handlers.rs @@ -3,12 +3,11 @@ use axum::{ http::StatusCode, response::{IntoResponse, Redirect, Response}, }; -use serde::{de, Deserialize, Deserializer}; +use serde::Deserialize; use sqlx::{query, query_as, SqlitePool}; -use ulid::Ulid; use super::templates::{AddNewWatchPage, GetWatchPage, SearchWatchesPage}; -use crate::{ids::DbId, AuthContext, MyWatchesPage, ShowKind, Watch}; +use crate::{db_id::DbId, util::empty_string_as_none, AuthContext, MyWatchesPage, ShowKind, Watch}; //-************************************************************************ // Constants @@ -178,7 +177,7 @@ pub async fn get_watch( "".to_string() }; let id = id.trim(); - let id: DbId = Ulid::from_string(id).unwrap_or_default().into(); + let id = DbId::from_str(id).unwrap_or_default(); let watch: Option = query_as(GET_WATCH_QUERY) .bind(id) .fetch_one(&pool) @@ -233,23 +232,3 @@ pub async fn get_search_watch( search, } } - -//-************************************************************************ -// helper fns -//-************************************************************************ - -/// Serde deserialization decorator to map empty Strings to None, -fn empty_string_as_none<'de, D, T>(de: D) -> Result, D::Error> -where - D: Deserializer<'de>, - T: std::str::FromStr, - T::Err: std::fmt::Display, -{ - let opt = Option::::deserialize(de)?; - match opt.as_deref() { - None | Some("") => Ok(None), - Some(s) => std::str::FromStr::from_str(s) - .map_err(de::Error::custom) - .map(Some), - } -}