use std::{ borrow::Cow, fmt::{Debug, Display}, }; use chrono::Utc; use serde::{de::Visitor, Deserialize, Serialize}; use sqlx::{ encode::IsNull, sqlite::{SqliteArgumentValue, SqliteValueRef}, Decode, Encode, Sqlite, }; use ulid::Ulid; #[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct DbId(pub Ulid); impl DbId { pub fn bytes(&self) -> [u8; 16] { self.to_be_bytes() } pub fn to_be_bytes(self) -> [u8; 16] { self.0 .0.to_be_bytes() } pub fn is_nil(&self) -> bool { self.0.is_nil() } pub fn new() -> Self { Self(Ulid::new()) } pub fn from_string(s: &str) -> Result { let id = Ulid::from_string(s)?; Ok(id.into()) } pub fn as_string(&self) -> String { self.0.to_string() } pub fn created_at(&self) -> chrono::DateTime { self.0.datetime().into() } } //-************************************************************************ // standard trait impls //-************************************************************************ impl Display for DbId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0.to_string()) } } impl Debug for DbId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("DbId").field(&self.as_string()).finish() } } impl From for DbId { fn from(value: Ulid) -> Self { DbId(value) } } impl From for DbId { fn from(value: u128) -> Self { DbId(value.into()) } } //-************************************************************************ // sqlx traits for going in and out of the db //-************************************************************************ impl sqlx::Type for DbId { fn type_info() -> ::TypeInfo { <&[u8] as sqlx::Type>::type_info() } } impl<'q> Encode<'q, Sqlite> for DbId { fn encode_by_ref(&self, args: &mut Vec>) -> IsNull { args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec()))); IsNull::No } } impl Decode<'_, Sqlite> for DbId { fn decode(value: SqliteValueRef<'_>) -> Result { let bytes = <&[u8] as Decode>::decode(value)?; let bytes: [u8; 16] = bytes.try_into().unwrap_or_default(); Ok(u128::from_be_bytes(bytes).into()) } } //-************************************************************************ // serde traits //-************************************************************************ impl Serialize for DbId { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_bytes(&self.bytes()) } } struct DbIdVisitor; impl<'de> Visitor<'de> for DbIdVisitor { type Value = DbId; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("16 bytes") } fn visit_bytes(self, v: &[u8]) -> Result where E: serde::de::Error, { match std::convert::TryInto::<[u8; 16]>::try_into(v) { Ok(v) => Ok(u128::from_be_bytes(v).into()), Err(_) => Err(serde::de::Error::invalid_length(v.len(), &self)), } } fn visit_byte_buf(self, v: Vec) -> Result where E: serde::de::Error, { let len = v.len(); match std::convert::TryInto::<[u8; 16]>::try_into(v) { Ok(v) => Ok(u128::from_be_bytes(v).into()), Err(_) => Err(serde::de::Error::invalid_length(len, &self)), } } fn visit_seq(self, mut seq: A) -> Result where A: serde::de::SeqAccess<'de>, { let mut raw_bytes_from_db = [0u8; 16]; let size = seq.size_hint().unwrap_or(0); let mut count = 0; while let Some(val) = seq.next_element()? { if count > 15 { break; } raw_bytes_from_db[count] = val; count += 1; } if count != 16 || size > 16 { let sz = if count < 16 { count } else { size }; Err(serde::de::Error::invalid_length(sz, &self)) } else { Ok(u128::from_be_bytes(raw_bytes_from_db).into()) } } } impl<'de> Deserialize<'de> for DbId { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { deserializer.deserialize_bytes(DbIdVisitor) } } //-************************************************************************ // serialization tests //-************************************************************************ #[cfg(test)] mod test { use serde_test::{assert_tokens, Token}; use super::*; #[test] fn test_ser_de() { let bytes: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; let id: DbId = u128::from_be_bytes(bytes).into(); assert_tokens( &id, &[Token::Bytes(&[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, ])], ); } }