start on centralizing db access

This commit is contained in:
joe 2025-12-28 14:58:26 -08:00
parent f65bae9014
commit 413dc6ab9a
2 changed files with 137 additions and 80 deletions

src/db.rs (new file, 129 additions)

@@ -0,0 +1,129 @@
use std::time::Duration;
use sqlx::{
Sqlite, SqlitePool,
query::Query,
sqlite::{
SqliteArguments, SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteRow,
},
};
use crate::BlogdorTheAggregator;
const MAX_CONNS: u32 = 200;
const MIN_CONNS: u32 = 5;
const TIMEOUT: u64 = 2000; // in milliseconds
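/// A single database operation: a prepared `sqlx` query paired with the
/// fetch strategy it should be executed with.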
pub enum DbAction<'q> {
Execute(Query<'q, Sqlite, SqliteArguments<'q>>),
FetchOne(Query<'q, Sqlite, SqliteArguments<'q>>),
FetchMany(Query<'q, Sqlite, SqliteArguments<'q>>),
FetchOptional(Query<'q, Sqlite, SqliteArguments<'q>>),
}
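/// The result of a `DbAction`, mirroring the four execution strategies.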
pub enum DbValue {
None,
Optional(Option<SqliteRow>),
One(SqliteRow),
Many(Vec<SqliteRow>),
}
impl BlogdorTheAggregator {
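/// Close the pool; with `optimize_on_close` enabled, each connection runs
/// `PRAGMA optimize` as it shuts down.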
pub async fn close_db(&self) {
self.db.close().await;
}
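/// Run a query against the shared pool, dispatching on the requested fetch
/// strategy; any `sqlx` error is stringified for the caller.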
pub async fn db_action<'q>(&self, query: DbAction<'q>) -> Result<DbValue, String> {
match query {
DbAction::Execute(q) => {
q.execute(&self.db).await.map_err(|e| format!("{e}"))?;
Ok(DbValue::None)
}
DbAction::FetchOne(q) => {
let r = q.fetch_one(&self.db).await.map_err(|e| format!("{e}"))?;
Ok(DbValue::One(r))
}
DbAction::FetchMany(q) => {
let r = q.fetch_all(&self.db).await.map_err(|e| format!("{e}"))?;
Ok(DbValue::Many(r))
}
DbAction::FetchOptional(q) => {
let r = q
.fetch_optional(&self.db)
.await
.map_err(|e| format!("{e}"))?;
Ok(DbValue::Optional(r))
}
}
}
}
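/// Build the shared `SqlitePool`: resolve the database path (each test gets
/// its own named in-memory database), apply the connection pragmas, and run
/// migrations before handing the pool back.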
pub async fn get_db_pool() -> SqlitePool {
let db_filename = {
std::env::var("DATABASE_FILE").unwrap_or_else(|_| {
#[cfg(not(test))]
{
tracing::info!("connecting to default db file");
"blogdor.db".to_string()
}
#[cfg(test)]
{
use rand::RngCore;
let mut rng = rand::rng();
let id = rng.next_u64();
// see https://www.sqlite.org/inmemorydb.html for meaning of the string;
// it allows each separate test to have its own dedicated memory-backed db that
// will live as long as the whole process
format!("file:testdb-{id}?mode=memory&cache=shared")
}
})
};
tracing::info!("Connecting to DB at {db_filename}");
let conn_opts = SqliteConnectOptions::new()
.foreign_keys(true)
.journal_mode(SqliteJournalMode::Wal)
.synchronous(sqlx::sqlite::SqliteSynchronous::Normal)
.filename(&db_filename)
.busy_timeout(Duration::from_millis(TIMEOUT))
.pragma("temp_store", "memory")
.create_if_missing(true)
.optimize_on_close(true, None)
.pragma("mmap_size", "3000000000");
let pool = SqlitePoolOptions::new()
.max_connections(MAX_CONNS)
.min_connections(MIN_CONNS)
.idle_timeout(Some(Duration::from_secs(3)))
.max_lifetime(Some(Duration::from_secs(3600)))
.connect_with(conn_opts)
.await
.expect("could not get sqlite pool");
sqlx::migrate!()
.run(&pool)
.await
.expect("could not run migrations");
tracing::info!("Ran migrations");
pool
}
//-************************************************************************
// Tests for `db` module.
//-************************************************************************
#[cfg(test)]
mod tests {
#[tokio::test]
async fn it_migrates_the_db() {
let db = super::get_db_pool().await;
let r = sqlx::query!("select count(*) as count from feeds")
.fetch_one(&db)
.await;
assert!(r.is_ok());
}
}
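For context, a call site wraps a prepared query in a DbAction and matches on the returned DbValue. A minimal sketch, assuming a hypothetical helper inside impl BlogdorTheAggregator in the crate root (ACTIVE_FEEDS_QUERY and db_action come from this commit; the method name, the DbValue import, and the error message are illustrative):

    // Hypothetical call site; lib.rs would also need `use db::DbValue;`
    // and `sqlx::sqlite::SqliteRow` in scope.
    async fn active_feeds(&self) -> Result<Vec<SqliteRow>, String> {
        let q = sqlx::query(ACTIVE_FEEDS_QUERY);
        match self.db_action(DbAction::FetchMany(q)).await? {
            DbValue::Many(rows) => Ok(rows),
            _ => Err("expected DbValue::Many from FetchMany".to_string()),
        }
    }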

src/lib.rs (8 additions, 80 deletions)

@@ -5,18 +5,16 @@ use reqwest::{Client, Response, StatusCode};
use server::ServerState;
use sqlx::{
SqlitePool,
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
types::chrono::{DateTime, Utc},
};
use tokio::{sync::mpsc::UnboundedSender, task::JoinSet};
use tokio_util::{bytes::Buf, sync::CancellationToken};
use unicode_segmentation::UnicodeSegmentation;
pub mod server;
mod db;
use db::DbAction;
const MAX_CONNS: u32 = 200;
const MIN_CONNS: u32 = 5;
const TIMEOUT: u64 = 2000; // in milliseconds
pub mod server;
const ZULIP_INTERVAL: Duration = Duration::from_millis(250);
const ZULIP_MESSAGE_CUTOFF: usize = 700;
@@ -25,6 +23,10 @@ const LAST_FETCHED: DateTime<Utc> = DateTime::from_timestamp_nanos(0);
const STALE_FETCH_THRESHOLD: Duration = Duration::from_hours(24);
const ADD_FEED_QUERY: &str = "";
const ACTIVE_FEEDS_QUERY: &str = "select id, url from feeds where active = true";
const STALE_FEEDS_QUERY: &str = "select id, url, added_by, created_at from feeds";
pub struct BlogdorTheAggregator {
db: SqlitePool,
client: reqwest::Client,
@@ -82,7 +84,7 @@ enum MessageType {
impl BlogdorTheAggregator {
pub async fn new() -> Self {
let db = get_db_pool().await;
let db = db::get_db_pool().await;
let client = reqwest::Client::new();
let cancel = CancellationToken::new();
let endpoint = std::env::var("ZULIP_URL").expect("ZULIP_URL must be set");
@@ -306,10 +308,6 @@ impl BlogdorTheAggregator {
}
}
pub async fn close_db(&self) {
self.db.close().await;
}
async fn send_zulip_message<'s>(&'s self, msg: &ZulipMessage<'s>) -> Result<Response, String> {
let msg = serde_urlencoded::to_string(msg).expect("serialize msg");
self.client
@@ -416,73 +414,3 @@ async fn fetch_and_parse_feed(url: &str, client: &Client) -> Result<feed_rs::mod
parse(feed.reader()).map_err(|e| format!("could not parse feed from {url}, got {e}"))
}
async fn get_db_pool() -> SqlitePool {
let db_filename = {
std::env::var("DATABASE_FILE").unwrap_or_else(|_| {
#[cfg(not(test))]
{
tracing::info!("connecting to default db file");
"blogdor.db".to_string()
}
#[cfg(test)]
{
use rand::RngCore;
let mut rng = rand::rng();
let id = rng.next_u64();
// see https://www.sqlite.org/inmemorydb.html for meaning of the string;
// it allows each separate test to have its own dedicated memory-backed db that
// will live as long as the whole process
format!("file:testdb-{id}?mode=memory&cache=shared")
}
})
};
tracing::info!("Connecting to DB at {db_filename}");
let conn_opts = SqliteConnectOptions::new()
.foreign_keys(true)
.journal_mode(SqliteJournalMode::Wal)
.synchronous(sqlx::sqlite::SqliteSynchronous::Normal)
.filename(&db_filename)
.busy_timeout(Duration::from_secs(TIMEOUT))
.pragma("temp_store", "memory")
.create_if_missing(true)
.optimize_on_close(true, None)
.pragma("mmap_size", "3000000000");
let pool = SqlitePoolOptions::new()
.max_connections(MAX_CONNS)
.min_connections(MIN_CONNS)
.idle_timeout(Some(Duration::from_secs(3)))
.max_lifetime(Some(Duration::from_secs(3600)))
.connect_with(conn_opts)
.await
.expect("could not get sqlite pool");
sqlx::migrate!()
.run(&pool)
.await
.expect("could not run migrations");
tracing::info!("Ran migrations");
pool
}
//-************************************************************************
// Tests for `db` module.
//-************************************************************************
#[cfg(test)]
mod tests {
#[tokio::test]
async fn it_migrates_the_db() {
let db = super::get_db_pool().await;
let r = sqlx::query!("select count(*) as count from feeds")
.fetch_one(&db)
.await;
assert!(r.is_ok());
}
}