From 61520964ec0c56b45aba8ad36521fa6bb551b5db Mon Sep 17 00:00:00 2001 From: Joe Ardent Date: Wed, 10 May 2023 12:08:03 -0700 Subject: [PATCH] break out main code into modules --- src/db.rs | 32 +++++++++++++++++ src/handlers.rs | 55 +++++++++++++++++++++++++++++ src/main.rs | 94 +++++-------------------------------------------- 3 files changed, 96 insertions(+), 85 deletions(-) create mode 100644 src/db.rs create mode 100644 src/handlers.rs diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..c9a55d9 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,32 @@ +use std::time::Duration; + +use sqlx::{ + sqlite::{SqliteConnectOptions, SqlitePoolOptions}, + SqlitePool, +}; + +const MAX_CONNS: u32 = 100; +const TIMEOUT: u64 = 5; + +pub async fn get_pool() -> SqlitePool { + let db_filename = { + std::env::var("DATABASE_FILE").unwrap_or_else(|_| { + let home = + std::env::var("HOME").expect("Could not determine $HOME for finding db file"); + format!("{home}/.witch-watch.db") + }) + }; + + let conn_opts = SqliteConnectOptions::new() + .foreign_keys(true) + .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental) + .filename(&db_filename); + + // setup connection pool + SqlitePoolOptions::new() + .max_connections(MAX_CONNS) + .connect_timeout(Duration::from_secs(TIMEOUT)) + .connect_with(conn_opts) + .await + .expect("can't connect to database") +} diff --git a/src/handlers.rs b/src/handlers.rs new file mode 100644 index 0000000..5ba9bd4 --- /dev/null +++ b/src/handlers.rs @@ -0,0 +1,55 @@ +use axum::{ + async_trait, + extract::{FromRef, FromRequestParts, State}, + http::{request::Parts, StatusCode}, +}; +use sqlx::SqlitePool; + +pub async fn using_connection_pool_extractor( + State(pool): State, +) -> Result { + sqlx::query_scalar("select 'hello world from sqlite get'") + .fetch_one(&pool) + .await + .map_err(internal_error) +} + +// we can also write a custom extractor that grabs a connection from the pool +// which setup is appropriate depends on your application +pub struct DatabaseConnection(sqlx::pool::PoolConnection); + +#[async_trait] +impl FromRequestParts for DatabaseConnection +where + SqlitePool: FromRef, + S: Send + Sync, +{ + type Rejection = (StatusCode, String); + + async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { + let pool = SqlitePool::from_ref(state); + + let conn = pool.acquire().await.map_err(internal_error)?; + + Ok(Self(conn)) + } +} + +pub async fn using_connection_extractor( + DatabaseConnection(conn): DatabaseConnection, +) -> Result { + let mut conn = conn; + sqlx::query_scalar("select 'hello world from sqlite post'") + .fetch_one(&mut conn) + .await + .map_err(internal_error) +} + +/// Utility function for mapping any error into a `500 Internal Server Error` +/// response. +fn internal_error(err: E) -> (StatusCode, String) +where + E: std::error::Error, +{ + (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) +} diff --git a/src/main.rs b/src/main.rs index 7a669ce..c464c5a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,48 +1,25 @@ -use std::{net::SocketAddr, time::Duration}; +use std::net::SocketAddr; -use axum::{ - async_trait, - extract::{FromRef, FromRequestParts, State}, - http::{request::Parts, StatusCode}, - routing::get, - Router, -}; -use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; -// use tokio::net::TcpListener; +use axum::{routing::get, Router}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +mod db; +mod handlers; + #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "ww_main=debug".into()), + .unwrap_or_else(|_| "witch_watch=debug,axum::routing=info".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); - let db_filename = { - std::env::var("DATABASE_FILE").unwrap_or_else(|_| { - let home = - std::env::var("HOME").expect("Could not determine $HOME for finding db file"); - format!("{home}/.witch-watch.db") - }) - }; - - let conn_opts = SqliteConnectOptions::new() - .foreign_keys(true) - .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental) - .filename(&db_filename); - - // setup connection pool - let pool = SqlitePoolOptions::new() - .max_connections(5) - .connect_timeout(Duration::from_secs(3)) - .connect_with(conn_opts) - .await - .expect("can't connect to database"); + let pool = db::get_pool().await; // build our application with some routes + use handlers::*; let app = Router::new() .route( "/", @@ -50,62 +27,9 @@ async fn main() { ) .with_state(pool); - // run it with hyper - //let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap(); - //axum::serve(listener, app).await.unwrap(); - tracing::info!("binding to 0.0.0.0:3000"); + tracing::debug!("binding to 0.0.0.0:3000"); axum::Server::bind(&SocketAddr::from(([0, 0, 0, 0], 3000))) .serve(app.into_make_service()) .await .unwrap(); } - -// we can extract the connection pool with `State` -async fn using_connection_pool_extractor( - State(pool): State, -) -> Result { - sqlx::query_scalar("select 'hello world from sqlite get'") - .fetch_one(&pool) - .await - .map_err(internal_error) -} - -// we can also write a custom extractor that grabs a connection from the pool -// which setup is appropriate depends on your application -struct DatabaseConnection(sqlx::pool::PoolConnection); - -#[async_trait] -impl FromRequestParts for DatabaseConnection -where - SqlitePool: FromRef, - S: Send + Sync, -{ - type Rejection = (StatusCode, String); - - async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { - let pool = SqlitePool::from_ref(state); - - let conn = pool.acquire().await.map_err(internal_error)?; - - Ok(Self(conn)) - } -} - -async fn using_connection_extractor( - DatabaseConnection(conn): DatabaseConnection, -) -> Result { - let mut conn = conn; - sqlx::query_scalar("select 'hello world from sqlite post'") - .fetch_one(&mut conn) - .await - .map_err(internal_error) -} - -/// Utility function for mapping any error into a `500 Internal Server Error` -/// response. -fn internal_error(err: E) -> (StatusCode, String) -where - E: std::error::Error, -{ - (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) -}