Rename Context to Provider

This commit is contained in:
Nicole Tietz-Sokolskaya 2024-06-02 10:55:30 -04:00
parent af51f3f38f
commit c848037dcb
9 changed files with 74 additions and 107 deletions

View File

@ -12,18 +12,21 @@ use crate::{
}; };
pub async fn documents_page( pub async fn documents_page(
State(ctx): State<Context>, State(provider): State<Provider>,
auth_session: AuthSession<Context>, auth_session: AuthSession<Provider>,
) -> Result<Response, (StatusCode, String)> { ) -> Result<Response, (StatusCode, String)> {
if let Some(user) = auth_session.user { if let Some(user) = auth_session.user {
render_documents_page(ctx, user).await render_documents_page(provider, user).await
} else { } else {
Ok(Redirect::to("/login").into_response()) Ok(Redirect::to("/login").into_response())
} }
} }
async fn render_documents_page(ctx: Context, user: User) -> Result<Response, (StatusCode, String)> { async fn render_documents_page(
let mut db = ctx.db_pool.get().map_err(internal_error)?; provider: Provider,
user: User,
) -> Result<Response, (StatusCode, String)> {
let mut db = provider.db_pool.get().map_err(internal_error)?;
let documents = let documents =
permissions::query::accessible_documents(&mut db, &user.id).map_err(internal_error)?; permissions::query::accessible_documents(&mut db, &user.id).map_err(internal_error)?;
let projects = let projects =
@ -35,19 +38,19 @@ async fn render_documents_page(ctx: Context, user: User) -> Result<Response, (St
projects => projects, projects => projects,
}; };
Ok(ctx.render_resp("documents/list_documents.html", values)) Ok(provider.render_resp("documents/list_documents.html", values))
} }
pub async fn create_document_page( pub async fn create_document_page(
State(ctx): State<Context>, State(provider): State<Provider>,
auth_session: AuthSession<Context>, auth_session: AuthSession<Provider>,
) -> Result<Response, (StatusCode, String)> { ) -> Result<Response, (StatusCode, String)> {
let user = match auth_session.user { let user = match auth_session.user {
Some(user) => user, Some(user) => user,
None => return Ok(Redirect::to("/login").into_response()), None => return Ok(Redirect::to("/login").into_response()),
}; };
let mut db = ctx.db_pool.get().map_err(internal_error)?; let mut db = provider.db_pool.get().map_err(internal_error)?;
let projects = let projects =
permissions::query::accessible_projects(&mut db, &user.id).map_err(internal_error)?; permissions::query::accessible_projects(&mut db, &user.id).map_err(internal_error)?;
@ -56,7 +59,7 @@ pub async fn create_document_page(
user => user, user => user,
projects => projects, projects => projects,
}; };
Ok(ctx.render_resp("documents/create_document.html", values)) Ok(provider.render_resp("documents/create_document.html", values))
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -66,15 +69,15 @@ pub struct CreateDocumentSubmission {
} }
pub async fn create_document_submit( pub async fn create_document_submit(
State(ctx): State<Context>, State(provider): State<Provider>,
auth_session: AuthSession<Context>, auth_session: AuthSession<Provider>,
form: Form<CreateDocumentSubmission>, form: Form<CreateDocumentSubmission>,
) -> Result<Response, (StatusCode, String)> { ) -> Result<Response, (StatusCode, String)> {
let user = match auth_session.user { let user = match auth_session.user {
Some(user) => user, Some(user) => user,
None => return Ok(Redirect::to("/login").into_response()), None => return Ok(Redirect::to("/login").into_response()),
}; };
let mut db = ctx.db_pool.get().map_err(internal_error)?; let mut db = provider.db_pool.get().map_err(internal_error)?;
let project_allowed = permissions::query::check_user_project( let project_allowed = permissions::query::check_user_project(
&mut db, &mut db,
@ -102,8 +105,8 @@ pub async fn create_document_submit(
} }
pub async fn edit_document_page( pub async fn edit_document_page(
State(ctx): State<Context>, State(provider): State<Provider>,
auth_session: AuthSession<Context>, auth_session: AuthSession<Provider>,
Path((id,)): Path<(Uuid,)>, Path((id,)): Path<(Uuid,)>,
) -> Result<Response, (StatusCode, String)> { ) -> Result<Response, (StatusCode, String)> {
let user = match auth_session.user { let user = match auth_session.user {
@ -111,7 +114,7 @@ pub async fn edit_document_page(
None => return Ok(Redirect::to("/login").into_response()), None => return Ok(Redirect::to("/login").into_response()),
}; };
let mut db = ctx.db_pool.get().map_err(internal_error)?; let mut db = provider.db_pool.get().map_err(internal_error)?;
let document_allowed = permissions::query::check_user_document( let document_allowed = permissions::query::check_user_document(
&mut db, &mut db,
@ -135,7 +138,7 @@ pub async fn edit_document_page(
projects => projects, projects => projects,
}; };
Ok(ctx.render_resp("documents/edit_document.html", values)) Ok(provider.render_resp("documents/edit_document.html", values))
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -145,8 +148,8 @@ pub struct EditDocumentSubmission {
} }
pub async fn edit_document_submit( pub async fn edit_document_submit(
State(ctx): State<Context>, State(provider): State<Provider>,
auth_session: AuthSession<Context>, auth_session: AuthSession<Provider>,
Path((document_id,)): Path<(Uuid,)>, Path((document_id,)): Path<(Uuid,)>,
form: Form<EditDocumentSubmission>, form: Form<EditDocumentSubmission>,
) -> Result<Response, (StatusCode, String)> { ) -> Result<Response, (StatusCode, String)> {
@ -155,7 +158,7 @@ pub async fn edit_document_submit(
None => return Ok(Redirect::to("/login").into_response()), None => return Ok(Redirect::to("/login").into_response()),
}; };
let mut db = ctx.db_pool.get().map_err(internal_error)?; let mut db = provider.db_pool.get().map_err(internal_error)?;
let document_allowed = permissions::query::check_user_document( let document_allowed = permissions::query::check_user_document(
&mut db, &mut db,

View File

@ -5,9 +5,12 @@ use crate::models::projects::Project;
use {crate::permissions, crate::prelude::*}; use {crate::permissions, crate::prelude::*};
pub async fn home_page(State(ctx): State<Context>, auth_session: AuthSession<Context>) -> Response { pub async fn home_page(
State(provider): State<Provider>,
auth_session: AuthSession<Provider>,
) -> Response {
if let Some(user) = auth_session.user { if let Some(user) = auth_session.user {
let mut db = ctx.db_pool.get().unwrap(); let mut db = provider.db_pool.get().unwrap();
let projects: Vec<Project> = let projects: Vec<Project> =
permissions::query::accessible_projects(&mut db, &user.id).unwrap(); permissions::query::accessible_projects(&mut db, &user.id).unwrap();
@ -16,7 +19,7 @@ pub async fn home_page(State(ctx): State<Context>, auth_session: AuthSession<Con
projects => projects, projects => projects,
}; };
ctx.render_resp("home.html", values) provider.render_resp("home.html", values)
} else { } else {
Redirect::to("/login").into_response() Redirect::to("/login").into_response()
} }

View File

@ -10,23 +10,23 @@ pub struct LoginTemplate {
} }
pub async fn login_page( pub async fn login_page(
State(ctx): State<Context>, State(provider): State<Provider>,
auth_session: AuthSession<Context>, auth_session: AuthSession<Provider>,
) -> Response { ) -> Response {
if auth_session.user.is_some() { if auth_session.user.is_some() {
return Redirect::to("/").into_response(); return Redirect::to("/").into_response();
} }
render_login_page(&ctx, "", "", None) render_login_page(&provider, "", "", None)
} }
fn render_login_page( fn render_login_page(
ctx: &Context, provider: &Provider,
username: &str, username: &str,
password: &str, password: &str,
error: Option<&'static str>, error: Option<&'static str>,
) -> Response { ) -> Response {
ctx.render_resp( provider.render_resp(
"login.html", "login.html",
context! { context! {
username => username, username => username,
@ -39,8 +39,8 @@ fn render_login_page(
const LOGIN_ERROR_MSG: &str = "Invalid username or password"; const LOGIN_ERROR_MSG: &str = "Invalid username or password";
pub async fn login_submit( pub async fn login_submit(
State(ctx): State<Context>, State(provider): State<Provider>,
mut auth_session: AuthSession<Context>, mut auth_session: AuthSession<Provider>,
Form(creds): Form<Credentials>, Form(creds): Form<Credentials>,
) -> Response { ) -> Response {
match auth_session.authenticate(creds).await { match auth_session.authenticate(creds).await {
@ -52,7 +52,7 @@ pub async fn login_submit(
Redirect::to("/").into_response() Redirect::to("/").into_response()
} }
Ok(None) => render_login_page(&ctx, "", "", Some(LOGIN_ERROR_MSG)), Ok(None) => render_login_page(&provider, "", "", Some(LOGIN_ERROR_MSG)),
Err(err) => { Err(err) => {
error!(?err, "error while authenticating user"); error!(?err, "error while authenticating user");
internal_server_error() internal_server_error()
@ -60,62 +60,10 @@ pub async fn login_submit(
} }
} }
pub async fn logout(mut auth_session: AuthSession<Context>) -> Response { pub async fn logout(mut auth_session: AuthSession<Provider>) -> Response {
if let Err(err) = auth_session.logout().await { if let Err(err) = auth_session.logout().await {
error!(?err, "error while logging out user"); error!(?err, "error while logging out user");
} }
Redirect::to("/login").into_response() Redirect::to("/login").into_response()
} }
//const INVALID_LOGIN_MESSAGE: &str = "Invalid username/password, please try again.";
//
//pub async fn login_submission(
// request: HttpRequest,
// context: web::Data<Context>,
// form: web::Form<LoginForm>,
//) -> impl Responder {
// let mut conn = match context.pool.get() {
// Ok(conn) => conn,
// Err(_) => return internal_server_error(),
// };
//
// let user = match fetch_user_by_username(&mut conn, &form.username) {
// Ok(Some(user)) => user,
// Ok(None) => {
// return LoginTemplate {
// username: form.username.clone(),
// password: String::new(),
// error: Some(INVALID_LOGIN_MESSAGE.into()),
// }
// .to_response()
// }
// Err(_) => return internal_server_error(),
// };
//
// if !user.check_password(&form.password) {
// return LoginTemplate {
// username: form.username.clone(),
// password: String::new(),
// error: Some(INVALID_LOGIN_MESSAGE.into()),
// }
// .to_response();
// }
//
// if Identity::login(&request.extensions(), user.id.to_string()).is_err() {
// return internal_server_error();
// }
//
// return HttpResponse::Found()
// .append_header(("Location", "/"))
// .finish();
//}
//
//#[get("/logout")]
//pub async fn logout(user: Option<Identity>) -> impl Responder {
// if let Some(user) = user {
// user.logout();
// }
//
// redirect_to_login()
//}

View File

@ -15,18 +15,18 @@ use crate::{
use super::internal_error; use super::internal_error;
pub async fn projects_page( pub async fn projects_page(
State(ctx): State<Context>, State(provider): State<Provider>,
auth_session: AuthSession<Context>, auth_session: AuthSession<Provider>,
) -> Response { ) -> Response {
if let Some(user) = auth_session.user { if let Some(user) = auth_session.user {
render_projects_page(ctx, user).await render_projects_page(provider, user).await
} else { } else {
Redirect::to("/login").into_response() Redirect::to("/login").into_response()
} }
} }
async fn render_projects_page(ctx: Context, user: User) -> Response { async fn render_projects_page(provider: Provider, user: User) -> Response {
let mut db = match ctx.db_pool.get() { let mut db = match provider.db_pool.get() {
Ok(db) => db, Ok(db) => db,
Err(err) => { Err(err) => {
error!(?err, "failed to get db connection"); error!(?err, "failed to get db connection");
@ -40,12 +40,12 @@ async fn render_projects_page(ctx: Context, user: User) -> Response {
projects => projects, projects => projects,
}; };
ctx.render_resp("projects/list_projects.html", values) provider.render_resp("projects/list_projects.html", values)
} }
pub async fn create_project_page( pub async fn create_project_page(
State(ctx): State<Context>, State(provider): State<Provider>,
auth_session: AuthSession<Context>, auth_session: AuthSession<Provider>,
) -> Response { ) -> Response {
let user = match auth_session.user { let user = match auth_session.user {
Some(user) => user, Some(user) => user,
@ -55,7 +55,7 @@ pub async fn create_project_page(
let values = context! { let values = context! {
user => user, user => user,
}; };
ctx.render_resp("projects/create_project.html", values) provider.render_resp("projects/create_project.html", values)
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -66,11 +66,11 @@ pub struct CreateProjectSubmission {
} }
pub async fn create_project_submit( pub async fn create_project_submit(
State(ctx): State<Context>, State(provider): State<Provider>,
auth_session: AuthSession<Context>, auth_session: AuthSession<Provider>,
form: Form<CreateProjectSubmission>, form: Form<CreateProjectSubmission>,
) -> Result<Response, (StatusCode, String)> { ) -> Result<Response, (StatusCode, String)> {
let mut db = ctx.db_pool.get().map_err(internal_error)?; let mut db = provider.db_pool.get().map_err(internal_error)?;
let user = match auth_session.user { let user = match auth_session.user {
Some(user) => user, Some(user) => user,

View File

@ -1,5 +1,4 @@
pub mod config; pub mod config;
pub mod context;
pub mod db; pub mod db;
pub mod handler; pub mod handler;
pub mod logging; pub mod logging;
@ -7,6 +6,7 @@ pub mod models;
pub mod password; pub mod password;
pub mod permissions; pub mod permissions;
pub mod prelude; pub mod prelude;
pub mod provider;
pub mod schema; pub mod schema;
pub mod serialize; pub mod serialize;
pub mod server; pub mod server;

View File

@ -1,4 +1,4 @@
pub use crate::context::Context; pub use crate::provider::Provider;
pub use axum::extract::State; pub use axum::extract::State;
pub use axum::response::{Html, IntoResponse, Response}; pub use axum::response::{Html, IntoResponse, Response};
pub use minijinja::context; pub use minijinja::context;

View File

@ -3,27 +3,40 @@ use diesel::{
SqliteConnection, SqliteConnection,
}; };
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error;
use minijinja_autoreload::AutoReloader; use minijinja_autoreload::AutoReloader;
use crate::{handler::internal_server_error, prelude::*}; use crate::{handler::internal_server_error, prelude::*};
pub type ConnectionPool = Pool<ConnectionManager<SqliteConnection>>; pub type ConnectionPool = Pool<ConnectionManager<SqliteConnection>>;
pub type PooledConnection = diesel::r2d2::PooledConnection<ConnectionManager<SqliteConnection>>;
#[derive(Clone)] #[derive(Clone)]
pub struct Context { pub struct Provider {
pub db_pool: ConnectionPool, pub db_pool: ConnectionPool,
template_loader: Arc<AutoReloader>, template_loader: Arc<AutoReloader>,
} }
impl Context { #[derive(Error, Debug)]
pub fn new(db: ConnectionPool, template_loader: AutoReloader) -> Context { pub enum ProviderError {
Context { #[error("Error while using the connection pool: {0}")]
R2D2Error(#[from] diesel::r2d2::PoolError),
}
impl Provider {
pub fn new(db: ConnectionPool, template_loader: AutoReloader) -> Provider {
Provider {
db_pool: db, db_pool: db,
template_loader: Arc::new(template_loader), template_loader: Arc::new(template_loader),
} }
} }
pub fn db_conn(&self) -> Result<PooledConnection, ProviderError> {
let conn = self.db_pool.get()?;
Ok(conn)
}
pub fn render<T: Serialize>(&self, path: &str, data: T) -> anyhow::Result<String> { pub fn render<T: Serialize>(&self, path: &str, data: T) -> anyhow::Result<String> {
// TODO: more graceful handling of the potential errors here; this should not use anyhow // TODO: more graceful handling of the potential errors here; this should not use anyhow
let env = self.template_loader.acquire_env().unwrap(); let env = self.template_loader.acquire_env().unwrap();

View File

@ -18,7 +18,7 @@ use tracing::Level;
use crate::{ use crate::{
config::CommandLineOptions, config::CommandLineOptions,
context::Context, provider::Provider,
db, db,
handler::{ handler::{
documents::{ documents::{
@ -52,9 +52,9 @@ pub async fn run() -> Result<()> {
let session_layer = create_session_manager_layer().await?; let session_layer = create_session_manager_layer().await?;
let context = Context::new(db_pool, template_loader); let provider = Provider::new(db_pool, template_loader);
let auth_backend = context.clone(); let auth_backend = provider.clone();
let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer.clone()).build(); let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer.clone()).build();
let trace_layer = TraceLayer::new_for_http() let trace_layer = TraceLayer::new_for_http()
@ -78,7 +78,7 @@ pub async fn run() -> Result<()> {
.layer(trace_layer) .layer(trace_layer)
.layer(session_layer) .layer(session_layer)
.layer(auth_layer) .layer(auth_layer)
.with_state(context); .with_state(provider);
let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap(); let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
axum::serve(listener, app).await.unwrap(); axum::serve(listener, app).await.unwrap();

View File

@ -26,7 +26,7 @@ impl AuthUser for models::users::User {
} }
#[async_trait] #[async_trait]
impl AuthnBackend for Context { impl AuthnBackend for Provider {
type User = models::users::User; type User = models::users::User;
type Credentials = Credentials; type Credentials = Credentials;
type Error = DbError; type Error = DbError;