add working ULID-based IDs for primary keys
This commit is contained in:
parent
be96100237
commit
656e6dceed
13 changed files with 312 additions and 195 deletions
17
Cargo.lock
generated
17
Cargo.lock
generated
|
@ -1522,6 +1522,15 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_test"
|
||||||
|
version = "1.0.164"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "797c38160e2546a56e1e3439496439597e938669673ffd8af02a12f070da648f"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_urlencoded"
|
name = "serde_urlencoded"
|
||||||
version = "0.7.1"
|
version = "0.7.1"
|
||||||
|
@ -1693,7 +1702,6 @@ dependencies = [
|
||||||
"time",
|
"time",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"url",
|
"url",
|
||||||
"uuid",
|
|
||||||
"webpki-roots",
|
"webpki-roots",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -2105,12 +2113,6 @@ dependencies = [
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "uuid"
|
|
||||||
version = "1.3.4"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "valuable"
|
name = "valuable"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
@ -2341,6 +2343,7 @@ dependencies = [
|
||||||
"password-hash",
|
"password-hash",
|
||||||
"rand_core",
|
"rand_core",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_test",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|
|
@ -14,7 +14,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
tower = { version = "0.4", features = ["util", "timeout"], default-features = false }
|
tower = { version = "0.4", features = ["util", "timeout"], default-features = false }
|
||||||
tower-http = { version = "0.4", features = ["add-extension", "trace"] }
|
tower-http = { version = "0.4", features = ["add-extension", "trace"] }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
sqlx = { version = "0.6", default-features = false, features = ["runtime-tokio-rustls", "any", "sqlite", "chrono", "time", "uuid"] }
|
sqlx = { version = "0.6", default-features = false, features = ["runtime-tokio-rustls", "any", "sqlite", "chrono", "time"] }
|
||||||
argon2 = "0.5"
|
argon2 = "0.5"
|
||||||
rand_core = { version = "0.6", features = ["getrandom"] }
|
rand_core = { version = "0.6", features = ["getrandom"] }
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
|
@ -30,4 +30,5 @@ optional_optional_user = {path = "optional_optional_user"}
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
axum-test = "9.0.0"
|
axum-test = "9.0.0"
|
||||||
|
serde_test = "1.0.164"
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ use sqlx::{
|
||||||
SqlitePool,
|
SqlitePool,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{ids::DbId, User};
|
use crate::{db_id::DbId, User};
|
||||||
|
|
||||||
const MAX_CONNS: u32 = 200;
|
const MAX_CONNS: u32 = 200;
|
||||||
const MIN_CONNS: u32 = 5;
|
const MIN_CONNS: u32 = 5;
|
||||||
|
|
251
src/db_id.rs
Normal file
251
src/db_id.rs
Normal file
|
@ -0,0 +1,251 @@
|
||||||
|
use std::{
|
||||||
|
borrow::Cow,
|
||||||
|
fmt::{Debug, Display},
|
||||||
|
};
|
||||||
|
|
||||||
|
use serde::{de::Visitor, Deserialize, Serialize};
|
||||||
|
use sqlx::{
|
||||||
|
encode::IsNull,
|
||||||
|
sqlite::{SqliteArgumentValue, SqliteValueRef},
|
||||||
|
Decode, Encode, Sqlite,
|
||||||
|
};
|
||||||
|
use ulid::Ulid;
|
||||||
|
|
||||||
|
#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
|
pub struct DbId(pub Ulid);
|
||||||
|
|
||||||
|
impl DbId {
|
||||||
|
pub fn bytes(&self) -> [u8; 16] {
|
||||||
|
self.as_be_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_be_bytes(&self) -> [u8; 16] {
|
||||||
|
self.0 .0.to_be_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_nil(&self) -> bool {
|
||||||
|
self.0.is_nil()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let id = Ulid::new();
|
||||||
|
Self(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_str(s: &str) -> Result<Self, ulid::DecodeError> {
|
||||||
|
let id = Ulid::from_string(s)?;
|
||||||
|
Ok(id.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// standard trait impls
|
||||||
|
//-************************************************************************
|
||||||
|
|
||||||
|
impl Display for DbId {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{}", self.0.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Debug for DbId {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_tuple("DbId").field(&self.bytes()).finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Ulid> for DbId {
|
||||||
|
fn from(value: Ulid) -> Self {
|
||||||
|
DbId(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<u128> for DbId {
|
||||||
|
fn from(value: u128) -> Self {
|
||||||
|
DbId(value.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// sqlx traits for going in and out of the db
|
||||||
|
//-************************************************************************
|
||||||
|
|
||||||
|
impl sqlx::Type<sqlx::Sqlite> for DbId {
|
||||||
|
fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
|
||||||
|
<&[u8] as sqlx::Type<sqlx::Sqlite>>::type_info()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqlx traits for marshalling in and out
|
||||||
|
impl<'q> Encode<'q, Sqlite> for DbId {
|
||||||
|
fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
|
||||||
|
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec())));
|
||||||
|
IsNull::No
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
|
||||||
|
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec())));
|
||||||
|
IsNull::No
|
||||||
|
}
|
||||||
|
|
||||||
|
fn produces(&self) -> Option<<Sqlite as sqlx::Database>::TypeInfo> {
|
||||||
|
// `produces` is inherently a hook to allow database drivers to produce
|
||||||
|
// value-dependent type information; if the driver doesn't need this, it
|
||||||
|
// can leave this as `None`
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size_hint(&self) -> usize {
|
||||||
|
std::mem::size_of_val(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'r> Decode<'r, Sqlite> for DbId {
|
||||||
|
fn decode(value: SqliteValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
|
||||||
|
let bytes = <&[u8] as Decode<Sqlite>>::decode(value)?;
|
||||||
|
let bytes: [u8; 16] = bytes.try_into().unwrap_or_default();
|
||||||
|
let id: Ulid = u128::from_ne_bytes(bytes).into();
|
||||||
|
Ok(id.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// serde traits
|
||||||
|
//-************************************************************************
|
||||||
|
impl Serialize for DbId {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
serializer.serialize_bytes(&self.bytes())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DbIdVisitor;
|
||||||
|
|
||||||
|
impl<'de> Visitor<'de> for DbIdVisitor {
|
||||||
|
type Value = DbId;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
formatter.write_str("a 128-bit number")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_i128<E>(self, v: i128) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
Ok(DbId(Ulid(v as u128)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_u128<E>(self, v: u128) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
Ok(DbId(Ulid(v)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
match std::convert::TryInto::<[u8; 16]>::try_into(v) {
|
||||||
|
Ok(v) => Ok(DbId(Ulid(u128::from_be_bytes(v)))),
|
||||||
|
Err(_) => Err(serde::de::Error::invalid_length(v.len(), &self)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
let len = v.len();
|
||||||
|
match std::convert::TryInto::<[u8; 16]>::try_into(v) {
|
||||||
|
Ok(v) => Ok(DbId(Ulid(u128::from_be_bytes(v)))),
|
||||||
|
Err(_) => Err(serde::de::Error::invalid_length(len, &self)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
match Ulid::from_string(&v) {
|
||||||
|
Ok(v) => Ok(DbId(v)),
|
||||||
|
Err(_) => Err(serde::de::Error::invalid_value(
|
||||||
|
serde::de::Unexpected::Str(&format!("could not convert {v} to a ULID")),
|
||||||
|
&self,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
match Ulid::from_string(v) {
|
||||||
|
Ok(v) => Ok(DbId(v)),
|
||||||
|
Err(_) => Err(serde::de::Error::invalid_value(
|
||||||
|
serde::de::Unexpected::Str(&format!("could not convert {v} to a ULID")),
|
||||||
|
&self,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: serde::de::SeqAccess<'de>,
|
||||||
|
{
|
||||||
|
let mut bytes = [0u8; 16];
|
||||||
|
let size = seq.size_hint().unwrap_or(0);
|
||||||
|
let mut count = 0;
|
||||||
|
while let Some(val) = seq.next_element()? {
|
||||||
|
if count >= 16 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
bytes[count] = val;
|
||||||
|
count += 1;
|
||||||
|
}
|
||||||
|
if count != 16 || size > 16 {
|
||||||
|
let sz = if count < 16 { count } else { size };
|
||||||
|
Err(serde::de::Error::invalid_length(sz, &self))
|
||||||
|
} else {
|
||||||
|
let id = u128::from_ne_bytes(bytes);
|
||||||
|
Ok(id.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for DbId {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
deserializer.deserialize_bytes(DbIdVisitor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//-************************************************************************
|
||||||
|
// serialization tests
|
||||||
|
//-************************************************************************
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use serde_test::{assert_tokens, Token};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ser_de() {
|
||||||
|
let bytes: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
|
||||||
|
let be_num = u128::from_be_bytes(bytes);
|
||||||
|
let le_ulid = Ulid(be_num);
|
||||||
|
let le_id = DbId(le_ulid);
|
||||||
|
|
||||||
|
assert_tokens(
|
||||||
|
&le_id,
|
||||||
|
&[Token::Bytes(&[
|
||||||
|
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||||
|
])],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
143
src/ids.rs
143
src/ids.rs
|
@ -1,143 +0,0 @@
|
||||||
use std::{borrow::Cow, fmt::Display};
|
|
||||||
|
|
||||||
use serde::{de::Visitor, Deserialize, Serialize};
|
|
||||||
use sqlx::{
|
|
||||||
encode::IsNull,
|
|
||||||
sqlite::{SqliteArgumentValue, SqliteValueRef},
|
|
||||||
Decode, Encode, Sqlite,
|
|
||||||
};
|
|
||||||
use ulid::Ulid;
|
|
||||||
|
|
||||||
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
|
||||||
pub struct DbId(pub Ulid);
|
|
||||||
|
|
||||||
impl From<Ulid> for DbId {
|
|
||||||
fn from(value: Ulid) -> Self {
|
|
||||||
DbId(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl sqlx::Type<sqlx::Sqlite> for DbId {
|
|
||||||
fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
|
|
||||||
<&[u8] as sqlx::Type<sqlx::Sqlite>>::type_info()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'q> Encode<'q, Sqlite> for DbId {
|
|
||||||
fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
|
|
||||||
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec())));
|
|
||||||
IsNull::No
|
|
||||||
}
|
|
||||||
|
|
||||||
fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
|
|
||||||
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec())));
|
|
||||||
|
|
||||||
IsNull::No
|
|
||||||
}
|
|
||||||
|
|
||||||
fn produces(&self) -> Option<<Sqlite as sqlx::Database>::TypeInfo> {
|
|
||||||
// `produces` is inherently a hook to allow database drivers to produce
|
|
||||||
// value-dependent type information; if the driver doesn't need this, it
|
|
||||||
// can leave this as `None`
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
fn size_hint(&self) -> usize {
|
|
||||||
std::mem::size_of_val(self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'r> Decode<'r, Sqlite> for DbId {
|
|
||||||
fn decode(value: SqliteValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
|
|
||||||
let bytes = <&[u8] as Decode<Sqlite>>::decode(value)?;
|
|
||||||
let bytes: [u8; 16] = bytes.try_into().unwrap_or_default();
|
|
||||||
let id: Ulid = u128::from_ne_bytes(bytes).into();
|
|
||||||
Ok(id.into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for DbId {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "{}", self.0.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DbId {
|
|
||||||
pub fn bytes(&self) -> [u8; 16] {
|
|
||||||
self.0 .0.to_ne_bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_nil(&self) -> bool {
|
|
||||||
self.0.is_nil()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new() -> Self {
|
|
||||||
let id = Ulid::new();
|
|
||||||
Self(id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Serialize for DbId {
|
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
||||||
where
|
|
||||||
S: serde::Serializer,
|
|
||||||
{
|
|
||||||
serializer.serialize_bytes(&self.0 .0.to_ne_bytes())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct IdVisitor;
|
|
||||||
|
|
||||||
impl<'de> Visitor<'de> for IdVisitor {
|
|
||||||
type Value = DbId;
|
|
||||||
|
|
||||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
||||||
formatter.write_str("a 128-bit number")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_i128<E>(self, v: i128) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: serde::de::Error,
|
|
||||||
{
|
|
||||||
Ok(DbId(Ulid(v as u128)))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_u128<E>(self, v: u128) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: serde::de::Error,
|
|
||||||
{
|
|
||||||
Ok(DbId(Ulid(v)))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: serde::de::Error,
|
|
||||||
{
|
|
||||||
match std::convert::TryInto::<[u8; 16]>::try_into(v) {
|
|
||||||
Ok(v) => Ok(DbId(Ulid(u128::from_ne_bytes(v)))),
|
|
||||||
Err(_) => Err(serde::de::Error::invalid_length(v.len(), &self)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: serde::de::Error,
|
|
||||||
{
|
|
||||||
match Ulid::from_string(&v) {
|
|
||||||
Ok(v) => Ok(DbId(v)),
|
|
||||||
Err(_) => Err(serde::de::Error::invalid_value(
|
|
||||||
serde::de::Unexpected::Str(&format!("could not convert {v} to a ULID")),
|
|
||||||
&self,
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'de> Deserialize<'de> for DbId {
|
|
||||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
||||||
where
|
|
||||||
D: serde::Deserializer<'de>,
|
|
||||||
{
|
|
||||||
deserializer.deserialize_bytes(IdVisitor)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -9,8 +9,8 @@ pub use db::get_db_pool;
|
||||||
|
|
||||||
// everything else is private to the crate
|
// everything else is private to the crate
|
||||||
mod db;
|
mod db;
|
||||||
|
mod db_id;
|
||||||
mod generic_handlers;
|
mod generic_handlers;
|
||||||
mod ids;
|
|
||||||
mod login;
|
mod login;
|
||||||
mod signup;
|
mod signup;
|
||||||
mod templates;
|
mod templates;
|
||||||
|
@ -19,7 +19,7 @@ mod util;
|
||||||
mod watches;
|
mod watches;
|
||||||
|
|
||||||
// things we want in the crate namespace
|
// things we want in the crate namespace
|
||||||
use ids::DbId;
|
use db_id::DbId;
|
||||||
use optional_optional_user::OptionalOptionalUser;
|
use optional_optional_user::OptionalOptionalUser;
|
||||||
use templates::*;
|
use templates::*;
|
||||||
use users::User;
|
use users::User;
|
||||||
|
|
|
@ -206,8 +206,8 @@ mod test {
|
||||||
}
|
}
|
||||||
.to_string();
|
.to_string();
|
||||||
|
|
||||||
let idx = s.get("/").await;
|
let main_page = s.get("/").await;
|
||||||
let body = std::str::from_utf8(idx.bytes()).unwrap();
|
let body = std::str::from_utf8(main_page.bytes()).unwrap();
|
||||||
assert_eq!(&logged_in, body);
|
assert_eq!(&logged_in, body);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,11 +7,11 @@ use axum::{
|
||||||
http::StatusCode,
|
http::StatusCode,
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
|
use serde::Deserialize;
|
||||||
use sqlx::{query_as, SqlitePool};
|
use sqlx::{query_as, SqlitePool};
|
||||||
use ulid::Ulid;
|
|
||||||
use unicode_segmentation::UnicodeSegmentation;
|
use unicode_segmentation::UnicodeSegmentation;
|
||||||
|
|
||||||
use crate::{CreateUser, CreateUserSuccess, DbId, User};
|
use crate::{util::empty_string_as_none, CreateUserSuccess, DbId, SignupPage, User};
|
||||||
|
|
||||||
pub(crate) const CREATE_QUERY: &str =
|
pub(crate) const CREATE_QUERY: &str =
|
||||||
"insert into witches (id, username, displayname, email, pwhash) values ($1, $2, $3, $4, $5)";
|
"insert into witches (id, username, displayname, email, pwhash) values ($1, $2, $3, $4, $5)";
|
||||||
|
@ -51,13 +51,24 @@ pub enum CreateUserErrorKind {
|
||||||
UnknownDBError,
|
UnknownDBError,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Deserialize, PartialEq, Eq)]
|
||||||
|
pub struct SignupForm {
|
||||||
|
pub username: String,
|
||||||
|
#[serde(default, deserialize_with = "empty_string_as_none")]
|
||||||
|
pub displayname: Option<String>,
|
||||||
|
#[serde(default, deserialize_with = "empty_string_as_none")]
|
||||||
|
pub email: Option<String>,
|
||||||
|
pub password: String,
|
||||||
|
pub pw_verify: String,
|
||||||
|
}
|
||||||
|
|
||||||
//-************************************************************************
|
//-************************************************************************
|
||||||
// User creation route handlers
|
// User creation route handlers
|
||||||
//-************************************************************************
|
//-************************************************************************
|
||||||
|
|
||||||
/// Get Handler: displays the form to create a user
|
/// Get Handler: displays the form to create a user
|
||||||
pub async fn get_create_user() -> CreateUser {
|
pub async fn get_create_user() -> SignupPage {
|
||||||
CreateUser::default()
|
SignupPage::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Post Handler: validates form values and calls the actual, private user
|
/// Post Handler: validates form values and calls the actual, private user
|
||||||
|
@ -65,7 +76,7 @@ pub async fn get_create_user() -> CreateUser {
|
||||||
#[axum::debug_handler]
|
#[axum::debug_handler]
|
||||||
pub async fn post_create_user(
|
pub async fn post_create_user(
|
||||||
State(pool): State<SqlitePool>,
|
State(pool): State<SqlitePool>,
|
||||||
Form(signup): Form<CreateUser>,
|
Form(signup): Form<SignupForm>,
|
||||||
) -> Result<impl IntoResponse, CreateUserError> {
|
) -> Result<impl IntoResponse, CreateUserError> {
|
||||||
use crate::util::validate_optional_length;
|
use crate::util::validate_optional_length;
|
||||||
let username = signup.username.trim();
|
let username = signup.username.trim();
|
||||||
|
@ -114,10 +125,10 @@ pub async fn get_signup_success(
|
||||||
State(pool): State<SqlitePool>,
|
State(pool): State<SqlitePool>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let id = id.trim();
|
let id = id.trim();
|
||||||
|
let id = DbId::from_str(id).unwrap_or_default();
|
||||||
let user: User = {
|
let user: User = {
|
||||||
let id: DbId = Ulid::from_string(id).unwrap_or_default().into();
|
|
||||||
query_as(ID_QUERY)
|
query_as(ID_QUERY)
|
||||||
.bind(id.bytes().as_slice())
|
.bind(id)
|
||||||
.fetch_one(&pool)
|
.fetch_one(&pool)
|
||||||
.await
|
.await
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
|
@ -125,7 +136,7 @@ pub async fn get_signup_success(
|
||||||
|
|
||||||
let mut resp = CreateUserSuccess(user.clone()).into_response();
|
let mut resp = CreateUserSuccess(user.clone()).into_response();
|
||||||
|
|
||||||
if user.username.is_empty() || id.is_empty() {
|
if user.username.is_empty() || id.is_nil() {
|
||||||
// redirect to front page if we got here without a valid witch ID
|
// redirect to front page if we got here without a valid witch ID
|
||||||
*resp.status_mut() = StatusCode::SEE_OTHER;
|
*resp.status_mut() = StatusCode::SEE_OTHER;
|
||||||
resp.headers_mut().insert("Location", "/".parse().unwrap());
|
resp.headers_mut().insert("Location", "/".parse().unwrap());
|
||||||
|
@ -154,9 +165,8 @@ pub(crate) async fn create_user(
|
||||||
.unwrap() // safe to unwrap, we know the salt is valid
|
.unwrap() // safe to unwrap, we know the salt is valid
|
||||||
.to_string();
|
.to_string();
|
||||||
|
|
||||||
let bytes = &id.bytes();
|
|
||||||
let query = sqlx::query(CREATE_QUERY)
|
let query = sqlx::query(CREATE_QUERY)
|
||||||
.bind(bytes.as_slice())
|
.bind(id)
|
||||||
.bind(username)
|
.bind(username)
|
||||||
.bind(displayname)
|
.bind(displayname)
|
||||||
.bind(email)
|
.bind(email)
|
||||||
|
@ -203,7 +213,7 @@ mod test {
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
db::get_db_pool,
|
db::get_db_pool,
|
||||||
templates::{CreateUser, CreateUserSuccess},
|
templates::{CreateUserSuccess, SignupPage},
|
||||||
test_utils::{get_test_user, insert_user, massage, server_with_pool, FORM_CONTENT_TYPE},
|
test_utils::{get_test_user, insert_user, massage, server_with_pool, FORM_CONTENT_TYPE},
|
||||||
User,
|
User,
|
||||||
};
|
};
|
||||||
|
@ -237,7 +247,7 @@ mod test {
|
||||||
|
|
||||||
let resp = server.get("/signup").await;
|
let resp = server.get("/signup").await;
|
||||||
let body = std::str::from_utf8(resp.bytes()).unwrap();
|
let body = std::str::from_utf8(resp.bytes()).unwrap();
|
||||||
let expected = CreateUser::default().to_string();
|
let expected = SignupPage::default().to_string();
|
||||||
assert_eq!(&expected, body);
|
assert_eq!(&expected, body);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ use crate::{OptionalOptionalUser, User};
|
||||||
|
|
||||||
#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq, OptionalOptionalUser)]
|
#[derive(Debug, Default, Template, Deserialize, Serialize, PartialEq, Eq, OptionalOptionalUser)]
|
||||||
#[template(path = "signup.html")]
|
#[template(path = "signup.html")]
|
||||||
pub struct CreateUser {
|
pub struct SignupPage {
|
||||||
pub username: String,
|
pub username: String,
|
||||||
pub displayname: Option<String>,
|
pub displayname: Option<String>,
|
||||||
pub email: Option<String>,
|
pub email: Option<String>,
|
||||||
|
|
|
@ -11,7 +11,7 @@ pub fn get_test_user() -> User {
|
||||||
username: "test_user".to_string(),
|
username: "test_user".to_string(),
|
||||||
// corresponding to a password of "a":
|
// corresponding to a password of "a":
|
||||||
pwhash: "$argon2id$v=19$m=19456,t=2,p=1$GWsCH1w5RYaP9WWmq+xw0g$hmOEqC+MU+vnEk3bOdkoE+z01mOmmOeX08XyPyjqua8".to_string(),
|
pwhash: "$argon2id$v=19$m=19456,t=2,p=1$GWsCH1w5RYaP9WWmq+xw0g$hmOEqC+MU+vnEk3bOdkoE+z01mOmmOeX08XyPyjqua8".to_string(),
|
||||||
id: DbId::default(),
|
id: DbId::from_str("00041061050R3GG28A1C60T3GF").unwrap(),
|
||||||
displayname: Some("Test User".to_string()),
|
displayname: Some("Test User".to_string()),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ use crate::{AuthContext, DbId};
|
||||||
const USERNAME_QUERY: &str = "select * from witches where username = $1";
|
const USERNAME_QUERY: &str = "select * from witches where username = $1";
|
||||||
const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1";
|
const LAST_SEEN_QUERY: &str = "update witches set last_seen = (select unixepoch()) where id = $1";
|
||||||
|
|
||||||
#[derive(Default, Clone, PartialEq, Eq, sqlx::FromRow, Serialize, Deserialize)]
|
#[derive(Default, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
|
||||||
pub struct User {
|
pub struct User {
|
||||||
pub id: DbId,
|
pub id: DbId,
|
||||||
pub username: String,
|
pub username: String,
|
||||||
|
|
16
src/util.rs
16
src/util.rs
|
@ -19,3 +19,19 @@ pub fn validate_optional_length<E: Error>(
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Serde deserialization decorator to map empty Strings to None,
|
||||||
|
pub fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
T: std::str::FromStr,
|
||||||
|
T::Err: std::fmt::Display,
|
||||||
|
{
|
||||||
|
let opt = <Option<String> as serde::Deserialize>::deserialize(de)?;
|
||||||
|
match opt.as_deref() {
|
||||||
|
None | Some("") => Ok(None),
|
||||||
|
Some(s) => std::str::FromStr::from_str(s)
|
||||||
|
.map_err(serde::de::Error::custom)
|
||||||
|
.map(Some),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -3,12 +3,11 @@ use axum::{
|
||||||
http::StatusCode,
|
http::StatusCode,
|
||||||
response::{IntoResponse, Redirect, Response},
|
response::{IntoResponse, Redirect, Response},
|
||||||
};
|
};
|
||||||
use serde::{de, Deserialize, Deserializer};
|
use serde::Deserialize;
|
||||||
use sqlx::{query, query_as, SqlitePool};
|
use sqlx::{query, query_as, SqlitePool};
|
||||||
use ulid::Ulid;
|
|
||||||
|
|
||||||
use super::templates::{AddNewWatchPage, GetWatchPage, SearchWatchesPage};
|
use super::templates::{AddNewWatchPage, GetWatchPage, SearchWatchesPage};
|
||||||
use crate::{ids::DbId, AuthContext, MyWatchesPage, ShowKind, Watch};
|
use crate::{db_id::DbId, util::empty_string_as_none, AuthContext, MyWatchesPage, ShowKind, Watch};
|
||||||
|
|
||||||
//-************************************************************************
|
//-************************************************************************
|
||||||
// Constants
|
// Constants
|
||||||
|
@ -178,7 +177,7 @@ pub async fn get_watch(
|
||||||
"".to_string()
|
"".to_string()
|
||||||
};
|
};
|
||||||
let id = id.trim();
|
let id = id.trim();
|
||||||
let id: DbId = Ulid::from_string(id).unwrap_or_default().into();
|
let id = DbId::from_str(id).unwrap_or_default();
|
||||||
let watch: Option<Watch> = query_as(GET_WATCH_QUERY)
|
let watch: Option<Watch> = query_as(GET_WATCH_QUERY)
|
||||||
.bind(id)
|
.bind(id)
|
||||||
.fetch_one(&pool)
|
.fetch_one(&pool)
|
||||||
|
@ -233,23 +232,3 @@ pub async fn get_search_watch(
|
||||||
search,
|
search,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//-************************************************************************
|
|
||||||
// helper fns
|
|
||||||
//-************************************************************************
|
|
||||||
|
|
||||||
/// Serde deserialization decorator to map empty Strings to None,
|
|
||||||
fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error>
|
|
||||||
where
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
T: std::str::FromStr,
|
|
||||||
T::Err: std::fmt::Display,
|
|
||||||
{
|
|
||||||
let opt = Option::<String>::deserialize(de)?;
|
|
||||||
match opt.as_deref() {
|
|
||||||
None | Some("") => Ok(None),
|
|
||||||
Some(s) => std::str::FromStr::from_str(s)
|
|
||||||
.map_err(de::Error::custom)
|
|
||||||
.map(Some),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in a new issue