From c848037dcbcfcd86be68ac3177517e97351837fa Mon Sep 17 00:00:00 2001 From: Nicole Tietz-Sokolskaya Date: Sun, 2 Jun 2024 10:55:30 -0400 Subject: [PATCH] Rename Context to Provider --- src/handler/documents.rs | 43 ++++++++++---------- src/handler/home.rs | 9 +++-- src/handler/login.rs | 70 +++++---------------------------- src/handler/projects.rs | 24 +++++------ src/lib.rs | 2 +- src/prelude.rs | 2 +- src/{context.rs => provider.rs} | 21 ++++++++-- src/server.rs | 8 ++-- src/session.rs | 2 +- 9 files changed, 74 insertions(+), 107 deletions(-) rename src/{context.rs => provider.rs} (72%) diff --git a/src/handler/documents.rs b/src/handler/documents.rs index ac46dd2..d4d11d8 100644 --- a/src/handler/documents.rs +++ b/src/handler/documents.rs @@ -12,18 +12,21 @@ use crate::{ }; pub async fn documents_page( - State(ctx): State, - auth_session: AuthSession, + State(provider): State, + auth_session: AuthSession, ) -> Result { if let Some(user) = auth_session.user { - render_documents_page(ctx, user).await + render_documents_page(provider, user).await } else { Ok(Redirect::to("/login").into_response()) } } -async fn render_documents_page(ctx: Context, user: User) -> Result { - let mut db = ctx.db_pool.get().map_err(internal_error)?; +async fn render_documents_page( + provider: Provider, + user: User, +) -> Result { + let mut db = provider.db_pool.get().map_err(internal_error)?; let documents = permissions::query::accessible_documents(&mut db, &user.id).map_err(internal_error)?; let projects = @@ -35,19 +38,19 @@ async fn render_documents_page(ctx: Context, user: User) -> Result projects, }; - Ok(ctx.render_resp("documents/list_documents.html", values)) + Ok(provider.render_resp("documents/list_documents.html", values)) } pub async fn create_document_page( - State(ctx): State, - auth_session: AuthSession, + State(provider): State, + auth_session: AuthSession, ) -> Result { let user = match auth_session.user { Some(user) => user, None => return Ok(Redirect::to("/login").into_response()), }; - let mut db = ctx.db_pool.get().map_err(internal_error)?; + let mut db = provider.db_pool.get().map_err(internal_error)?; let projects = permissions::query::accessible_projects(&mut db, &user.id).map_err(internal_error)?; @@ -56,7 +59,7 @@ pub async fn create_document_page( user => user, projects => projects, }; - Ok(ctx.render_resp("documents/create_document.html", values)) + Ok(provider.render_resp("documents/create_document.html", values)) } #[derive(Debug, Deserialize)] @@ -66,15 +69,15 @@ pub struct CreateDocumentSubmission { } pub async fn create_document_submit( - State(ctx): State, - auth_session: AuthSession, + State(provider): State, + auth_session: AuthSession, form: Form, ) -> Result { let user = match auth_session.user { Some(user) => user, None => return Ok(Redirect::to("/login").into_response()), }; - let mut db = ctx.db_pool.get().map_err(internal_error)?; + let mut db = provider.db_pool.get().map_err(internal_error)?; let project_allowed = permissions::query::check_user_project( &mut db, @@ -102,8 +105,8 @@ pub async fn create_document_submit( } pub async fn edit_document_page( - State(ctx): State, - auth_session: AuthSession, + State(provider): State, + auth_session: AuthSession, Path((id,)): Path<(Uuid,)>, ) -> Result { let user = match auth_session.user { @@ -111,7 +114,7 @@ pub async fn edit_document_page( None => return Ok(Redirect::to("/login").into_response()), }; - let mut db = ctx.db_pool.get().map_err(internal_error)?; + let mut db = provider.db_pool.get().map_err(internal_error)?; let document_allowed = permissions::query::check_user_document( &mut db, @@ -135,7 +138,7 @@ pub async fn edit_document_page( projects => projects, }; - Ok(ctx.render_resp("documents/edit_document.html", values)) + Ok(provider.render_resp("documents/edit_document.html", values)) } #[derive(Debug, Deserialize)] @@ -145,8 +148,8 @@ pub struct EditDocumentSubmission { } pub async fn edit_document_submit( - State(ctx): State, - auth_session: AuthSession, + State(provider): State, + auth_session: AuthSession, Path((document_id,)): Path<(Uuid,)>, form: Form, ) -> Result { @@ -155,7 +158,7 @@ pub async fn edit_document_submit( None => return Ok(Redirect::to("/login").into_response()), }; - let mut db = ctx.db_pool.get().map_err(internal_error)?; + let mut db = provider.db_pool.get().map_err(internal_error)?; let document_allowed = permissions::query::check_user_document( &mut db, diff --git a/src/handler/home.rs b/src/handler/home.rs index 2470d75..f25fef2 100644 --- a/src/handler/home.rs +++ b/src/handler/home.rs @@ -5,9 +5,12 @@ use crate::models::projects::Project; use {crate::permissions, crate::prelude::*}; -pub async fn home_page(State(ctx): State, auth_session: AuthSession) -> Response { +pub async fn home_page( + State(provider): State, + auth_session: AuthSession, +) -> Response { if let Some(user) = auth_session.user { - let mut db = ctx.db_pool.get().unwrap(); + let mut db = provider.db_pool.get().unwrap(); let projects: Vec = permissions::query::accessible_projects(&mut db, &user.id).unwrap(); @@ -16,7 +19,7 @@ pub async fn home_page(State(ctx): State, auth_session: AuthSession projects, }; - ctx.render_resp("home.html", values) + provider.render_resp("home.html", values) } else { Redirect::to("/login").into_response() } diff --git a/src/handler/login.rs b/src/handler/login.rs index 61472c3..eebdb69 100644 --- a/src/handler/login.rs +++ b/src/handler/login.rs @@ -10,23 +10,23 @@ pub struct LoginTemplate { } pub async fn login_page( - State(ctx): State, - auth_session: AuthSession, + State(provider): State, + auth_session: AuthSession, ) -> Response { if auth_session.user.is_some() { return Redirect::to("/").into_response(); } - render_login_page(&ctx, "", "", None) + render_login_page(&provider, "", "", None) } fn render_login_page( - ctx: &Context, + provider: &Provider, username: &str, password: &str, error: Option<&'static str>, ) -> Response { - ctx.render_resp( + provider.render_resp( "login.html", context! { username => username, @@ -39,8 +39,8 @@ fn render_login_page( const LOGIN_ERROR_MSG: &str = "Invalid username or password"; pub async fn login_submit( - State(ctx): State, - mut auth_session: AuthSession, + State(provider): State, + mut auth_session: AuthSession, Form(creds): Form, ) -> Response { match auth_session.authenticate(creds).await { @@ -52,7 +52,7 @@ pub async fn login_submit( Redirect::to("/").into_response() } - Ok(None) => render_login_page(&ctx, "", "", Some(LOGIN_ERROR_MSG)), + Ok(None) => render_login_page(&provider, "", "", Some(LOGIN_ERROR_MSG)), Err(err) => { error!(?err, "error while authenticating user"); internal_server_error() @@ -60,62 +60,10 @@ pub async fn login_submit( } } -pub async fn logout(mut auth_session: AuthSession) -> Response { +pub async fn logout(mut auth_session: AuthSession) -> Response { if let Err(err) = auth_session.logout().await { error!(?err, "error while logging out user"); } Redirect::to("/login").into_response() } - -//const INVALID_LOGIN_MESSAGE: &str = "Invalid username/password, please try again."; -// -//pub async fn login_submission( -// request: HttpRequest, -// context: web::Data, -// form: web::Form, -//) -> impl Responder { -// let mut conn = match context.pool.get() { -// Ok(conn) => conn, -// Err(_) => return internal_server_error(), -// }; -// -// let user = match fetch_user_by_username(&mut conn, &form.username) { -// Ok(Some(user)) => user, -// Ok(None) => { -// return LoginTemplate { -// username: form.username.clone(), -// password: String::new(), -// error: Some(INVALID_LOGIN_MESSAGE.into()), -// } -// .to_response() -// } -// Err(_) => return internal_server_error(), -// }; -// -// if !user.check_password(&form.password) { -// return LoginTemplate { -// username: form.username.clone(), -// password: String::new(), -// error: Some(INVALID_LOGIN_MESSAGE.into()), -// } -// .to_response(); -// } -// -// if Identity::login(&request.extensions(), user.id.to_string()).is_err() { -// return internal_server_error(); -// } -// -// return HttpResponse::Found() -// .append_header(("Location", "/")) -// .finish(); -//} -// -//#[get("/logout")] -//pub async fn logout(user: Option) -> impl Responder { -// if let Some(user) = user { -// user.logout(); -// } -// -// redirect_to_login() -//} diff --git a/src/handler/projects.rs b/src/handler/projects.rs index 33c3b2d..2103145 100644 --- a/src/handler/projects.rs +++ b/src/handler/projects.rs @@ -15,18 +15,18 @@ use crate::{ use super::internal_error; pub async fn projects_page( - State(ctx): State, - auth_session: AuthSession, + State(provider): State, + auth_session: AuthSession, ) -> Response { if let Some(user) = auth_session.user { - render_projects_page(ctx, user).await + render_projects_page(provider, user).await } else { Redirect::to("/login").into_response() } } -async fn render_projects_page(ctx: Context, user: User) -> Response { - let mut db = match ctx.db_pool.get() { +async fn render_projects_page(provider: Provider, user: User) -> Response { + let mut db = match provider.db_pool.get() { Ok(db) => db, Err(err) => { error!(?err, "failed to get db connection"); @@ -40,12 +40,12 @@ async fn render_projects_page(ctx: Context, user: User) -> Response { projects => projects, }; - ctx.render_resp("projects/list_projects.html", values) + provider.render_resp("projects/list_projects.html", values) } pub async fn create_project_page( - State(ctx): State, - auth_session: AuthSession, + State(provider): State, + auth_session: AuthSession, ) -> Response { let user = match auth_session.user { Some(user) => user, @@ -55,7 +55,7 @@ pub async fn create_project_page( let values = context! { user => user, }; - ctx.render_resp("projects/create_project.html", values) + provider.render_resp("projects/create_project.html", values) } #[derive(Debug, Deserialize)] @@ -66,11 +66,11 @@ pub struct CreateProjectSubmission { } pub async fn create_project_submit( - State(ctx): State, - auth_session: AuthSession, + State(provider): State, + auth_session: AuthSession, form: Form, ) -> Result { - let mut db = ctx.db_pool.get().map_err(internal_error)?; + let mut db = provider.db_pool.get().map_err(internal_error)?; let user = match auth_session.user { Some(user) => user, diff --git a/src/lib.rs b/src/lib.rs index 5705e0a..ecc092d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,4 @@ pub mod config; -pub mod context; pub mod db; pub mod handler; pub mod logging; @@ -7,6 +6,7 @@ pub mod models; pub mod password; pub mod permissions; pub mod prelude; +pub mod provider; pub mod schema; pub mod serialize; pub mod server; diff --git a/src/prelude.rs b/src/prelude.rs index db0341b..9074208 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,4 +1,4 @@ -pub use crate::context::Context; +pub use crate::provider::Provider; pub use axum::extract::State; pub use axum::response::{Html, IntoResponse, Response}; pub use minijinja::context; diff --git a/src/context.rs b/src/provider.rs similarity index 72% rename from src/context.rs rename to src/provider.rs index ce7ca63..268ab58 100644 --- a/src/context.rs +++ b/src/provider.rs @@ -3,27 +3,40 @@ use diesel::{ SqliteConnection, }; use std::sync::Arc; +use thiserror::Error; use minijinja_autoreload::AutoReloader; use crate::{handler::internal_server_error, prelude::*}; pub type ConnectionPool = Pool>; +pub type PooledConnection = diesel::r2d2::PooledConnection>; #[derive(Clone)] -pub struct Context { +pub struct Provider { pub db_pool: ConnectionPool, template_loader: Arc, } -impl Context { - pub fn new(db: ConnectionPool, template_loader: AutoReloader) -> Context { - Context { +#[derive(Error, Debug)] +pub enum ProviderError { + #[error("Error while using the connection pool: {0}")] + R2D2Error(#[from] diesel::r2d2::PoolError), +} + +impl Provider { + pub fn new(db: ConnectionPool, template_loader: AutoReloader) -> Provider { + Provider { db_pool: db, template_loader: Arc::new(template_loader), } } + pub fn db_conn(&self) -> Result { + let conn = self.db_pool.get()?; + Ok(conn) + } + pub fn render(&self, path: &str, data: T) -> anyhow::Result { // TODO: more graceful handling of the potential errors here; this should not use anyhow let env = self.template_loader.acquire_env().unwrap(); diff --git a/src/server.rs b/src/server.rs index aec6cdd..dd4c2bd 100644 --- a/src/server.rs +++ b/src/server.rs @@ -18,7 +18,7 @@ use tracing::Level; use crate::{ config::CommandLineOptions, - context::Context, + provider::Provider, db, handler::{ documents::{ @@ -52,9 +52,9 @@ pub async fn run() -> Result<()> { let session_layer = create_session_manager_layer().await?; - let context = Context::new(db_pool, template_loader); + let provider = Provider::new(db_pool, template_loader); - let auth_backend = context.clone(); + let auth_backend = provider.clone(); let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer.clone()).build(); let trace_layer = TraceLayer::new_for_http() @@ -78,7 +78,7 @@ pub async fn run() -> Result<()> { .layer(trace_layer) .layer(session_layer) .layer(auth_layer) - .with_state(context); + .with_state(provider); let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap(); axum::serve(listener, app).await.unwrap(); diff --git a/src/session.rs b/src/session.rs index dfce60d..f6f5c27 100644 --- a/src/session.rs +++ b/src/session.rs @@ -26,7 +26,7 @@ impl AuthUser for models::users::User { } #[async_trait] -impl AuthnBackend for Context { +impl AuthnBackend for Provider { type User = models::users::User; type Credentials = Credentials; type Error = DbError;