use std::{borrow::Cow, fmt::Display}; use serde::{de::Visitor, Deserialize, Serialize}; use sqlx::{ encode::IsNull, sqlite::{SqliteArgumentValue, SqliteValueRef}, Decode, Encode, Sqlite, }; use ulid::Ulid; #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct DbId(pub Ulid); impl From for DbId { fn from(value: Ulid) -> Self { DbId(value) } } impl sqlx::Type for DbId { fn type_info() -> ::TypeInfo { <&[u8] as sqlx::Type>::type_info() } } impl<'q> Encode<'q, Sqlite> for DbId { fn encode(self, args: &mut Vec>) -> IsNull { args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec()))); IsNull::No } fn encode_by_ref(&self, args: &mut Vec>) -> IsNull { args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec()))); IsNull::No } fn produces(&self) -> Option<::TypeInfo> { // `produces` is inherently a hook to allow database drivers to produce // value-dependent type information; if the driver doesn't need this, it // can leave this as `None` None } fn size_hint(&self) -> usize { std::mem::size_of_val(self) } } impl<'r> Decode<'r, Sqlite> for DbId { fn decode(value: SqliteValueRef<'r>) -> Result { let bytes = <&[u8] as Decode>::decode(value)?; let bytes: [u8; 16] = bytes.try_into().unwrap_or_default(); let id: Ulid = u128::from_ne_bytes(bytes).into(); Ok(id.into()) } } impl Display for DbId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0.to_string()) } } impl DbId { pub fn bytes(&self) -> [u8; 16] { self.0 .0.to_ne_bytes() } pub fn is_nil(&self) -> bool { self.0.is_nil() } pub fn new() -> Self { let id = Ulid::new(); Self(id) } } impl Serialize for DbId { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_bytes(&self.0 .0.to_ne_bytes()) } } struct IdVisitor; impl<'de> Visitor<'de> for IdVisitor { type Value = DbId; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("a 128-bit number") } fn visit_i128(self, v: i128) -> Result where E: serde::de::Error, { Ok(DbId(Ulid(v as u128))) } fn visit_u128(self, v: u128) -> Result where E: serde::de::Error, { Ok(DbId(Ulid(v))) } fn visit_bytes(self, v: &[u8]) -> Result where E: serde::de::Error, { match std::convert::TryInto::<[u8; 16]>::try_into(v) { Ok(v) => Ok(DbId(Ulid(u128::from_ne_bytes(v)))), Err(_) => Err(serde::de::Error::invalid_length(v.len(), &self)), } } fn visit_string(self, v: String) -> Result where E: serde::de::Error, { match Ulid::from_string(&v) { Ok(v) => Ok(DbId(v)), Err(_) => Err(serde::de::Error::invalid_value( serde::de::Unexpected::Str(&format!("could not convert {v} to a ULID")), &self, )), } } } impl<'de> Deserialize<'de> for DbId { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { deserializer.deserialize_bytes(IdVisitor) } }