252 lines
6.8 KiB
Rust
252 lines
6.8 KiB
Rust
|
use std::{
|
||
|
borrow::Cow,
|
||
|
fmt::{Debug, Display},
|
||
|
};
|
||
|
|
||
|
use serde::{de::Visitor, Deserialize, Serialize};
|
||
|
use sqlx::{
|
||
|
encode::IsNull,
|
||
|
sqlite::{SqliteArgumentValue, SqliteValueRef},
|
||
|
Decode, Encode, Sqlite,
|
||
|
};
|
||
|
use ulid::Ulid;
|
||
|
|
||
|
#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||
|
pub struct DbId(pub Ulid);
|
||
|
|
||
|
impl DbId {
|
||
|
pub fn bytes(&self) -> [u8; 16] {
|
||
|
self.as_be_bytes()
|
||
|
}
|
||
|
|
||
|
pub fn as_be_bytes(&self) -> [u8; 16] {
|
||
|
self.0 .0.to_be_bytes()
|
||
|
}
|
||
|
|
||
|
pub fn is_nil(&self) -> bool {
|
||
|
self.0.is_nil()
|
||
|
}
|
||
|
|
||
|
pub fn new() -> Self {
|
||
|
let id = Ulid::new();
|
||
|
Self(id)
|
||
|
}
|
||
|
|
||
|
pub fn from_str(s: &str) -> Result<Self, ulid::DecodeError> {
|
||
|
let id = Ulid::from_string(s)?;
|
||
|
Ok(id.into())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//-************************************************************************
|
||
|
// standard trait impls
|
||
|
//-************************************************************************
|
||
|
|
||
|
impl Display for DbId {
|
||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
|
write!(f, "{}", self.0.to_string())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl Debug for DbId {
|
||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
|
f.debug_tuple("DbId").field(&self.bytes()).finish()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl From<Ulid> for DbId {
|
||
|
fn from(value: Ulid) -> Self {
|
||
|
DbId(value)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl From<u128> for DbId {
|
||
|
fn from(value: u128) -> Self {
|
||
|
DbId(value.into())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//-************************************************************************
|
||
|
// sqlx traits for going in and out of the db
|
||
|
//-************************************************************************
|
||
|
|
||
|
impl sqlx::Type<sqlx::Sqlite> for DbId {
|
||
|
fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
|
||
|
<&[u8] as sqlx::Type<sqlx::Sqlite>>::type_info()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// sqlx traits for marshalling in and out
|
||
|
impl<'q> Encode<'q, Sqlite> for DbId {
|
||
|
fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
|
||
|
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec())));
|
||
|
IsNull::No
|
||
|
}
|
||
|
|
||
|
fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
|
||
|
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.bytes().to_vec())));
|
||
|
IsNull::No
|
||
|
}
|
||
|
|
||
|
fn produces(&self) -> Option<<Sqlite as sqlx::Database>::TypeInfo> {
|
||
|
// `produces` is inherently a hook to allow database drivers to produce
|
||
|
// value-dependent type information; if the driver doesn't need this, it
|
||
|
// can leave this as `None`
|
||
|
None
|
||
|
}
|
||
|
|
||
|
fn size_hint(&self) -> usize {
|
||
|
std::mem::size_of_val(self)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl<'r> Decode<'r, Sqlite> for DbId {
|
||
|
fn decode(value: SqliteValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
|
||
|
let bytes = <&[u8] as Decode<Sqlite>>::decode(value)?;
|
||
|
let bytes: [u8; 16] = bytes.try_into().unwrap_or_default();
|
||
|
let id: Ulid = u128::from_ne_bytes(bytes).into();
|
||
|
Ok(id.into())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//-************************************************************************
|
||
|
// serde traits
|
||
|
//-************************************************************************
|
||
|
impl Serialize for DbId {
|
||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||
|
where
|
||
|
S: serde::Serializer,
|
||
|
{
|
||
|
serializer.serialize_bytes(&self.bytes())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
struct DbIdVisitor;
|
||
|
|
||
|
impl<'de> Visitor<'de> for DbIdVisitor {
|
||
|
type Value = DbId;
|
||
|
|
||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||
|
formatter.write_str("a 128-bit number")
|
||
|
}
|
||
|
|
||
|
fn visit_i128<E>(self, v: i128) -> Result<Self::Value, E>
|
||
|
where
|
||
|
E: serde::de::Error,
|
||
|
{
|
||
|
Ok(DbId(Ulid(v as u128)))
|
||
|
}
|
||
|
|
||
|
fn visit_u128<E>(self, v: u128) -> Result<Self::Value, E>
|
||
|
where
|
||
|
E: serde::de::Error,
|
||
|
{
|
||
|
Ok(DbId(Ulid(v)))
|
||
|
}
|
||
|
|
||
|
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
|
||
|
where
|
||
|
E: serde::de::Error,
|
||
|
{
|
||
|
match std::convert::TryInto::<[u8; 16]>::try_into(v) {
|
||
|
Ok(v) => Ok(DbId(Ulid(u128::from_be_bytes(v)))),
|
||
|
Err(_) => Err(serde::de::Error::invalid_length(v.len(), &self)),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
|
||
|
where
|
||
|
E: serde::de::Error,
|
||
|
{
|
||
|
let len = v.len();
|
||
|
match std::convert::TryInto::<[u8; 16]>::try_into(v) {
|
||
|
Ok(v) => Ok(DbId(Ulid(u128::from_be_bytes(v)))),
|
||
|
Err(_) => Err(serde::de::Error::invalid_length(len, &self)),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
|
||
|
where
|
||
|
E: serde::de::Error,
|
||
|
{
|
||
|
match Ulid::from_string(&v) {
|
||
|
Ok(v) => Ok(DbId(v)),
|
||
|
Err(_) => Err(serde::de::Error::invalid_value(
|
||
|
serde::de::Unexpected::Str(&format!("could not convert {v} to a ULID")),
|
||
|
&self,
|
||
|
)),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||
|
where
|
||
|
E: serde::de::Error,
|
||
|
{
|
||
|
match Ulid::from_string(v) {
|
||
|
Ok(v) => Ok(DbId(v)),
|
||
|
Err(_) => Err(serde::de::Error::invalid_value(
|
||
|
serde::de::Unexpected::Str(&format!("could not convert {v} to a ULID")),
|
||
|
&self,
|
||
|
)),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||
|
where
|
||
|
A: serde::de::SeqAccess<'de>,
|
||
|
{
|
||
|
let mut bytes = [0u8; 16];
|
||
|
let size = seq.size_hint().unwrap_or(0);
|
||
|
let mut count = 0;
|
||
|
while let Some(val) = seq.next_element()? {
|
||
|
if count >= 16 {
|
||
|
break;
|
||
|
}
|
||
|
bytes[count] = val;
|
||
|
count += 1;
|
||
|
}
|
||
|
if count != 16 || size > 16 {
|
||
|
let sz = if count < 16 { count } else { size };
|
||
|
Err(serde::de::Error::invalid_length(sz, &self))
|
||
|
} else {
|
||
|
let id = u128::from_ne_bytes(bytes);
|
||
|
Ok(id.into())
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl<'de> Deserialize<'de> for DbId {
|
||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||
|
where
|
||
|
D: serde::Deserializer<'de>,
|
||
|
{
|
||
|
deserializer.deserialize_bytes(DbIdVisitor)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//-************************************************************************
|
||
|
// serialization tests
|
||
|
//-************************************************************************
|
||
|
|
||
|
#[cfg(test)]
mod test {
    use serde_test::{assert_tokens, Token};

    use super::*;

    /// A `DbId` must serialize to its 16 big-endian bytes and deserialize back
    /// to an equal value.
    #[test]
    fn test_ser_de() {
        let raw: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        let id = DbId(Ulid(u128::from_be_bytes(raw)));

        let expected = [Token::Bytes(&[
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
        ])];
        assert_tokens(&id, &expected);
    }
}
|