what2watch/src/ids.rs
2023-06-20 16:32:36 -07:00

143 lines
3.7 KiB
Rust

use std::{borrow::Cow, fmt::Display};
use serde::{de::Visitor, Deserialize, Serialize};
use sqlx::{
encode::IsNull,
sqlite::{SqliteArgumentValue, SqliteValueRef},
Decode, Encode, Sqlite,
};
use ulid::Ulid;
/// Identifier type used by this module: a newtype around a [`Ulid`].
///
/// The wrapped `Ulid` is public and reachable as `.0`. The derived `Default`
/// is whatever `Ulid::default()` produces, and ordering/equality/hashing are
/// derived from the wrapped value.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DbId(pub Ulid);
impl From<Ulid> for DbId {
fn from(value: Ulid) -> Self {
DbId(value)
}
}
impl sqlx::Type<sqlx::Sqlite> for DbId {
fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
<&[u8] as sqlx::Type<sqlx::Sqlite>>::type_info()
}
}
/// Binds a [`DbId`] as a 16-byte blob argument (see [`DbId::bytes`] for the
/// byte layout).
impl<'q> Encode<'q, Sqlite> for DbId {
    /// By-value encoding; identical to [`Encode::encode_by_ref`].
    fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
        self.encode_by_ref(args)
    }

    /// Pushes an owned copy of the id's bytes onto the argument list.
    /// The value is never SQL `NULL`.
    fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
        let blob: Cow<'q, [u8]> = Cow::Owned(Vec::from(self.bytes()));
        args.push(SqliteArgumentValue::Blob(blob));
        IsNull::No
    }

    /// No value-dependent type information is needed for this type.
    fn produces(&self) -> Option<<Sqlite as sqlx::Database>::TypeInfo> {
        // `produces` lets a driver derive type information from the concrete
        // value being encoded; this type does not use it, so it stays `None`.
        None
    }

    /// In-memory size of a `DbId`, used as the encoded-size hint.
    fn size_hint(&self) -> usize {
        std::mem::size_of::<Self>()
    }
}
impl<'r> Decode<'r, Sqlite> for DbId {
fn decode(value: SqliteValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
let bytes = <&[u8] as Decode<Sqlite>>::decode(value)?;
let bytes: [u8; 16] = bytes.try_into().unwrap_or_default();
let id: Ulid = u128::from_ne_bytes(bytes).into();
Ok(id.into())
}
}
/// Formats a [`DbId`] exactly as its wrapped [`Ulid`] is displayed.
impl Display for DbId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Format the inner value directly; going through `to_string()` first
        // would allocate a temporary `String` for the same output.
        write!(f, "{}", self.0)
    }
}
impl DbId {
pub fn bytes(&self) -> [u8; 16] {
self.0 .0.to_ne_bytes()
}
pub fn is_nil(&self) -> bool {
self.0.is_nil()
}
pub fn new() -> Self {
let id = Ulid::new();
Self(id)
}
}
impl Serialize for DbId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_bytes(&self.0 .0.to_ne_bytes())
}
}
struct IdVisitor;
impl<'de> Visitor<'de> for IdVisitor {
type Value = DbId;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a 128-bit number")
}
fn visit_i128<E>(self, v: i128) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(DbId(Ulid(v as u128)))
}
fn visit_u128<E>(self, v: u128) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(DbId(Ulid(v)))
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
match std::convert::TryInto::<[u8; 16]>::try_into(v) {
Ok(v) => Ok(DbId(Ulid(u128::from_ne_bytes(v)))),
Err(_) => Err(serde::de::Error::invalid_length(v.len(), &self)),
}
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
match Ulid::from_string(&v) {
Ok(v) => Ok(DbId(v)),
Err(_) => Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Str(&format!("could not convert {v} to a ULID")),
&self,
)),
}
}
}
impl<'de> Deserialize<'de> for DbId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_bytes(IdVisitor)
}
}