refactoring

This commit is contained in:
Nicole Tietz-Sokolskaya 2024-06-02 11:21:50 -04:00
parent 3bf0f8de74
commit 012a3175bb
18 changed files with 75 additions and 163 deletions

View file

@ -11,12 +11,7 @@ pub async fn main() -> Result<()> {
let db_url = dotenvy::var("DATABASE_URL")?; let db_url = dotenvy::var("DATABASE_URL")?;
match AdminCli::parse().command { match AdminCli::parse().command {
AdminCommand::CreateUser { AdminCommand::CreateUser { name, email, username, password } => {
name,
email,
username,
password,
} => {
let password = match password { let password = match password {
Some(p) => p, Some(p) => p,
None => { None => {
@ -43,12 +38,7 @@ struct AdminCli {
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
pub enum AdminCommand { pub enum AdminCommand {
CreateUser { CreateUser { name: String, email: String, username: String, password: Option<String> },
name: String,
email: String,
username: String,
password: Option<String>,
},
ListUsers, ListUsers,
} }

View file

@ -26,9 +26,7 @@ pub fn establish_connection(url: &str) -> SqliteConnection {
/// Panics if the connection pool cannot be created. /// Panics if the connection pool cannot be created.
pub fn build_connection_pool(url: &str) -> Pool<ConnectionManager<SqliteConnection>> { pub fn build_connection_pool(url: &str) -> Pool<ConnectionManager<SqliteConnection>> {
let manager = ConnectionManager::<SqliteConnection>::new(url); let manager = ConnectionManager::<SqliteConnection>::new(url);
Pool::builder() Pool::builder().build(manager).expect("Failed to create connection pool.")
.build(manager)
.expect("Failed to create connection pool.")
} }
/// Runs any pending migrations. /// Runs any pending migrations.

View file

@ -4,24 +4,13 @@ pub mod login;
pub mod projects; pub mod projects;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::Response;
pub use login::{login_page, login_submit}; pub use login::{login_page, login_submit};
use tracing::error; use tracing::error;
pub fn internal_server_error() -> Response {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body("Internal Server Error".into())
.unwrap()
}
pub fn internal_error<E>(err: E) -> (StatusCode, String) pub fn internal_error<E>(err: E) -> (StatusCode, String)
where where
E: std::error::Error, E: std::error::Error,
{ {
error!(?err, "internal error"); error!(?err, "internal error");
( (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error".into())
StatusCode::INTERNAL_SERVER_ERROR,
"Internal Server Error".into(),
)
} }

View file

@ -38,7 +38,7 @@ async fn render_documents_page(
projects => projects, projects => projects,
}; };
Ok(provider.render_resp("documents/list_documents.html", values)) provider.render_resp("documents/list_documents.html", values)
} }
pub async fn create_document_page( pub async fn create_document_page(
@ -59,7 +59,7 @@ pub async fn create_document_page(
user => user, user => user,
projects => projects, projects => projects,
}; };
Ok(provider.render_resp("documents/create_document.html", values)) provider.render_resp("documents/create_document.html", values)
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -138,7 +138,7 @@ pub async fn edit_document_page(
projects => projects, projects => projects,
}; };
Ok(provider.render_resp("documents/edit_document.html", values)) provider.render_resp("documents/edit_document.html", values)
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]

View file

@ -1,3 +1,4 @@
use axum::http::StatusCode;
use axum::response::Redirect; use axum::response::Redirect;
use axum_login::AuthSession; use axum_login::AuthSession;
@ -8,7 +9,7 @@ use crate::prelude::*;
pub async fn home_page( pub async fn home_page(
State(provider): State<Provider>, State(provider): State<Provider>,
auth_session: AuthSession<Provider>, auth_session: AuthSession<Provider>,
) -> Response { ) -> Result<Response, (StatusCode, String)> {
if let Some(user) = auth_session.user { if let Some(user) = auth_session.user {
let mut db = provider.db_pool.get().unwrap(); let mut db = provider.db_pool.get().unwrap();
let projects: Vec<Project> = let projects: Vec<Project> =
@ -21,6 +22,6 @@ pub async fn home_page(
provider.render_resp("home.html", values) provider.render_resp("home.html", values)
} else { } else {
Redirect::to("/login").into_response() Ok(Redirect::to("/login").into_response())
} }
} }

View file

@ -1,8 +1,9 @@
use axum::http::StatusCode;
use axum::response::Redirect; use axum::response::Redirect;
use axum::Form; use axum::Form;
use axum_login::AuthSession; use axum_login::AuthSession;
use crate::handler::internal_server_error; use super::internal_error;
use crate::prelude::*; use crate::prelude::*;
use crate::session::Credentials; use crate::session::Credentials;
@ -15,12 +16,12 @@ pub struct LoginTemplate {
pub async fn login_page( pub async fn login_page(
State(provider): State<Provider>, State(provider): State<Provider>,
auth_session: AuthSession<Provider>, auth_session: AuthSession<Provider>,
) -> Response { ) -> Result<Response, (StatusCode, String)> {
if auth_session.user.is_some() { if let Some(_user) = auth_session.user {
return Redirect::to("/").into_response(); Ok(Redirect::to("/").into_response())
} else {
render_login_page(&provider, "", "", None)
} }
render_login_page(&provider, "", "", None)
} }
fn render_login_page( fn render_login_page(
@ -28,7 +29,7 @@ fn render_login_page(
username: &str, username: &str,
password: &str, password: &str,
error: Option<&'static str>, error: Option<&'static str>,
) -> Response { ) -> Result<Response, (StatusCode, String)> {
provider.render_resp( provider.render_resp(
"login.html", "login.html",
context! { context! {
@ -45,21 +46,12 @@ pub async fn login_submit(
State(provider): State<Provider>, State(provider): State<Provider>,
mut auth_session: AuthSession<Provider>, mut auth_session: AuthSession<Provider>,
Form(creds): Form<Credentials>, Form(creds): Form<Credentials>,
) -> Response { ) -> Result<Response, (StatusCode, String)> {
match auth_session.authenticate(creds).await { if let Some(user) = auth_session.authenticate(creds).await.map_err(internal_error)? {
Ok(Some(user)) => { let _ = auth_session.login(&user).await.map_err(internal_error)?;
if let Err(err) = auth_session.login(&user).await { Ok(Redirect::to("/").into_response())
error!(?err, "error while logging in user"); } else {
return internal_server_error(); render_login_page(&provider, "", "", Some(LOGIN_ERROR_MSG))
}
Redirect::to("/").into_response()
}
Ok(None) => render_login_page(&provider, "", "", Some(LOGIN_ERROR_MSG)),
Err(err) => {
error!(?err, "error while authenticating user");
internal_server_error()
}
} }
} }

View file

@ -4,7 +4,6 @@ use axum::Form;
use axum_login::AuthSession; use axum_login::AuthSession;
use super::internal_error; use super::internal_error;
use crate::handler::internal_server_error;
use crate::models::project_memberships::{self, ProjectRole}; use crate::models::project_memberships::{self, ProjectRole};
use crate::models::projects::{self, NewProject}; use crate::models::projects::{self, NewProject};
use crate::models::users::User; use crate::models::users::User;
@ -14,22 +13,19 @@ use crate::prelude::*;
pub async fn projects_page( pub async fn projects_page(
State(provider): State<Provider>, State(provider): State<Provider>,
auth_session: AuthSession<Provider>, auth_session: AuthSession<Provider>,
) -> Response { ) -> Result<Response, (StatusCode, String)> {
if let Some(user) = auth_session.user { if let Some(user) = auth_session.user {
render_projects_page(provider, user).await render_projects_page(provider, user).await
} else { } else {
Redirect::to("/login").into_response() Ok(Redirect::to("/login").into_response())
} }
} }
async fn render_projects_page(provider: Provider, user: User) -> Response { async fn render_projects_page(
let mut db = match provider.db_pool.get() { provider: Provider,
Ok(db) => db, user: User,
Err(err) => { ) -> Result<Response, (StatusCode, String)> {
error!(?err, "failed to get db connection"); let mut db = provider.db_pool.get().map_err(internal_error)?;
return internal_server_error();
}
};
let projects = permissions::query::accessible_projects(&mut db, &user.id).unwrap_or_default(); let projects = permissions::query::accessible_projects(&mut db, &user.id).unwrap_or_default();
let values = context! { let values = context! {
@ -43,10 +39,10 @@ async fn render_projects_page(provider: Provider, user: User) -> Response {
pub async fn create_project_page( pub async fn create_project_page(
State(provider): State<Provider>, State(provider): State<Provider>,
auth_session: AuthSession<Provider>, auth_session: AuthSession<Provider>,
) -> Response { ) -> Result<Response, (StatusCode, String)> {
let user = match auth_session.user { let user = match auth_session.user {
Some(user) => user, Some(user) => user,
None => return Redirect::to("/login").into_response(), None => return Ok(Redirect::to("/login").into_response()),
}; };
let values = context! { let values = context! {

View file

@ -1,7 +1,5 @@
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
pub fn setup_logging() { pub fn setup_logging() {
tracing_subscriber::fmt() tracing_subscriber::fmt().with_env_filter(EnvFilter::from_default_env()).init();
.with_env_filter(EnvFilter::from_default_env())
.init();
} }

View file

@ -45,13 +45,9 @@ pub mod query {
db: &mut SqliteConnection, db: &mut SqliteConnection,
new_document: NewDocument, new_document: NewDocument,
) -> Result<Document, DbError> { ) -> Result<Document, DbError> {
diesel::insert_into(dsl::documents) diesel::insert_into(dsl::documents).values(&new_document).execute(db)?;
.values(&new_document)
.execute(db)?;
let document = dsl::documents let document = dsl::documents.filter(dsl::id.eq(&new_document.id)).first(db)?;
.filter(dsl::id.eq(&new_document.id))
.first(db)?;
Ok(document) Ok(document)
} }
@ -75,10 +71,8 @@ pub mod query {
db: &mut SqliteConnection, db: &mut SqliteConnection,
document_id: &str, document_id: &str,
) -> Result<Option<Document>, DbError> { ) -> Result<Option<Document>, DbError> {
let document = dsl::documents let document =
.filter(dsl::id.eq(document_id)) dsl::documents.filter(dsl::id.eq(document_id)).first::<Document>(db).optional()?;
.first::<Document>(db)
.optional()?;
Ok(document) Ok(document)
} }

View file

@ -84,9 +84,8 @@ pub mod query {
role, role,
}; };
let membership = diesel::insert_into(pm::project_memberships) let membership =
.values(new_membership) diesel::insert_into(pm::project_memberships).values(new_membership).get_result(db)?;
.get_result(db)?;
Ok(membership) Ok(membership)
} }

View file

@ -28,13 +28,7 @@ pub struct NewProject {
impl NewProject { impl NewProject {
pub fn new(creator_id: String, name: String, description: String, key: String) -> Self { pub fn new(creator_id: String, name: String, description: String, key: String) -> Self {
Self { Self { id: Uuid::now_v7().to_string(), creator_id, name, description, key }
id: Uuid::now_v7().to_string(),
creator_id,
name,
description,
key,
}
} }
} }
@ -42,9 +36,8 @@ pub mod query {
use super::*; use super::*;
pub fn for_user(db: &mut SqliteConnection, user_id: String) -> Result<Vec<Project>, DbError> { pub fn for_user(db: &mut SqliteConnection, user_id: String) -> Result<Vec<Project>, DbError> {
let projects = dsl::projects let projects =
.filter(dsl::creator_id.eq(user_id.to_string())) dsl::projects.filter(dsl::creator_id.eq(user_id.to_string())).load::<Project>(db)?;
.load::<Project>(db)?;
Ok(projects) Ok(projects)
} }
@ -54,9 +47,7 @@ pub mod query {
) -> Result<Project, diesel::result::Error> { ) -> Result<Project, diesel::result::Error> {
use crate::schema::projects::dsl as p; use crate::schema::projects::dsl as p;
let project = diesel::insert_into(p::projects) let project = diesel::insert_into(p::projects).values(new_project).get_result(db)?;
.values(new_project)
.get_result(db)?;
Ok(project) Ok(project)
} }

View file

@ -3,9 +3,9 @@ use serde::Serialize;
use uuid::Uuid; use uuid::Uuid;
use super::DbError; use super::DbError;
use crate::db::ValidationError;
use crate::password; use crate::password;
use crate::schema::users::dsl; use crate::schema::users::dsl;
use crate::validation::ValidationError;
#[derive(Queryable, Selectable, Debug, Clone, Serialize)] #[derive(Queryable, Selectable, Debug, Clone, Serialize)]
#[diesel(table_name = crate::schema::users)] #[diesel(table_name = crate::schema::users)]
@ -31,13 +31,7 @@ pub struct NewUser {
impl NewUser { impl NewUser {
pub fn new(name: String, username: String, email: String, password: String) -> Self { pub fn new(name: String, username: String, email: String, password: String) -> Self {
let password_hash = password::hash(&password); let password_hash = password::hash(&password);
Self { Self { id: Uuid::now_v7().to_string(), name, username, email, password_hash }
id: Uuid::now_v7().to_string(),
name,
username,
email,
password_hash,
}
} }
pub fn validate(&self) -> Result<(), Vec<ValidationError>> { pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
@ -82,20 +76,14 @@ impl<'a> Query<'a> {
} }
pub fn by_username(&mut self, username: &str) -> Result<User, DbError> { pub fn by_username(&mut self, username: &str) -> Result<User, DbError> {
let user = dsl::users let user = dsl::users.filter(dsl::username.eq(username)).first::<User>(self.db)?;
.filter(dsl::username.eq(username))
.first::<User>(self.db)?;
Ok(user) Ok(user)
} }
pub fn create(&mut self, new_user: NewUser) -> Result<User, DbError> { pub fn create(&mut self, new_user: NewUser) -> Result<User, DbError> {
let _ = diesel::insert_into(dsl::users) let _ = diesel::insert_into(dsl::users).values(&new_user).execute(self.db)?;
.values(&new_user)
.execute(self.db)?;
let new_user = dsl::users let new_user = dsl::users.filter(dsl::id.eq(&new_user.id)).first::<User>(self.db)?;
.filter(dsl::id.eq(&new_user.id))
.first::<User>(self.db)?;
Ok(new_user) Ok(new_user)
} }

View file

@ -9,18 +9,14 @@ pub fn verify(hash: &str, password: &str) -> bool {
Err(_) => return false, // TODO: log an error Err(_) => return false, // TODO: log an error
}; };
Argon2::default() Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok()
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok()
} }
/// Hashes the given password. /// Hashes the given password.
pub fn hash(password: &str) -> String { pub fn hash(password: &str) -> String {
let salt = SaltString::generate(&mut OsRng); let salt = SaltString::generate(&mut OsRng);
let hash = Argon2::default() let hash = Argon2::default().hash_password(password.as_bytes(), &salt).unwrap();
.hash_password(password.as_bytes(), &salt)
.unwrap();
hash.to_string() hash.to_string()
} }

View file

@ -49,10 +49,8 @@ pub mod query {
) -> Result<bool, diesel::result::Error> { ) -> Result<bool, diesel::result::Error> {
use crate::schema::documents::dsl as d; use crate::schema::documents::dsl as d;
let document = d::documents let document =
.filter(d::id.eq(document_id)) d::documents.filter(d::id.eq(document_id)).first::<Document>(db).optional()?;
.first::<Document>(db)
.optional()?;
match document { match document {
Some(doc) => check_user_project(db, user_id, &doc.project_id, permission), Some(doc) => check_user_project(db, user_id, &doc.project_id, permission),
@ -83,9 +81,7 @@ pub mod query {
use crate::schema::projects::dsl as p; use crate::schema::projects::dsl as p;
let project_ids = accessible_project_ids(db, user_id)?; let project_ids = accessible_project_ids(db, user_id)?;
let projects = p::projects let projects = p::projects.filter(p::id.eq_any(project_ids)).load::<Project>(db)?;
.filter(p::id.eq_any(project_ids))
.load::<Project>(db)?;
Ok(projects) Ok(projects)
} }
@ -111,10 +107,7 @@ pub mod query {
.select(d::id) .select(d::id)
.load::<String>(db)?; .load::<String>(db)?;
let document_ids = direct_documents let document_ids = direct_documents.into_iter().chain(project_documents).collect();
.into_iter()
.chain(project_documents)
.collect();
Ok(document_ids) Ok(document_ids)
} }
@ -128,9 +121,7 @@ pub mod query {
use crate::schema::documents::dsl as d; use crate::schema::documents::dsl as d;
let document_ids = accessible_document_ids(db, user_id)?; let document_ids = accessible_document_ids(db, user_id)?;
let documents = d::documents let documents = d::documents.filter(d::id.eq_any(document_ids)).load::<Document>(db)?;
.filter(d::id.eq_any(document_ids))
.load::<Document>(db)?;
Ok(documents) Ok(documents)
} }

View file

@ -1,11 +1,12 @@
use std::sync::Arc; use std::sync::Arc;
use axum::http::StatusCode;
use diesel::r2d2::{ConnectionManager, Pool}; use diesel::r2d2::{ConnectionManager, Pool};
use diesel::SqliteConnection; use diesel::SqliteConnection;
use minijinja_autoreload::AutoReloader; use minijinja_autoreload::AutoReloader;
use thiserror::Error; use thiserror::Error;
use crate::handler::internal_server_error; use crate::handler::internal_error;
use crate::prelude::*; use crate::prelude::*;
pub type ConnectionPool = Pool<ConnectionManager<SqliteConnection>>; pub type ConnectionPool = Pool<ConnectionManager<SqliteConnection>>;
@ -21,14 +22,14 @@ pub struct Provider {
pub enum ProviderError { pub enum ProviderError {
#[error("Error while using the connection pool: {0}")] #[error("Error while using the connection pool: {0}")]
R2D2Error(#[from] diesel::r2d2::PoolError), R2D2Error(#[from] diesel::r2d2::PoolError),
#[error("Error while rendering template: {0}")]
TemplateError(#[from] minijinja::Error),
} }
impl Provider { impl Provider {
pub fn new(db: ConnectionPool, template_loader: AutoReloader) -> Provider { pub fn new(db: ConnectionPool, template_loader: AutoReloader) -> Provider {
Provider { Provider { db_pool: db, template_loader: Arc::new(template_loader) }
db_pool: db,
template_loader: Arc::new(template_loader),
}
} }
pub fn db_conn(&self) -> Result<PooledConnection, ProviderError> { pub fn db_conn(&self) -> Result<PooledConnection, ProviderError> {
@ -36,23 +37,19 @@ impl Provider {
Ok(conn) Ok(conn)
} }
pub fn render<T: Serialize>(&self, path: &str, data: T) -> anyhow::Result<String> { pub fn render<T: Serialize>(&self, path: &str, data: T) -> Result<String, ProviderError> {
// TODO: more graceful handling of the potential errors here; this should not
// use anyhow
let env = self.template_loader.acquire_env().unwrap(); let env = self.template_loader.acquire_env().unwrap();
let template = env.get_template(path)?; let template = env.get_template(path)?;
let rendered = template.render(data)?; let rendered = template.render(data)?;
Ok(rendered) Ok(rendered)
} }
pub fn render_resp<T: Serialize>(&self, path: &str, data: T) -> Response { pub fn render_resp<T: Serialize>(
let rendered = self.render(path, data); &self,
match rendered { path: &str,
Ok(rendered) => Html(rendered).into_response(), data: T,
Err(err) => { ) -> Result<Response, (StatusCode, String)> {
error!(?err, "error while rendering template"); let rendered = self.render(path, data).map_err(internal_error)?;
internal_server_error() Ok(Html(rendered).into_response())
}
}
} }
} }

View file

@ -5,7 +5,7 @@ use crate::models::{self, users, DbError};
use crate::password; use crate::password;
use crate::prelude::*; use crate::prelude::*;
#[derive(Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub struct Credentials { pub struct Credentials {
pub username: String, pub username: String,
pub password: String, pub password: String,

View file

@ -24,13 +24,8 @@ pub fn setup_filters(env: &mut Environment) {
pub fn heroicon_filter(name: String, classes: Option<String>) -> Result<String, Error> { pub fn heroicon_filter(name: String, classes: Option<String>) -> Result<String, Error> {
let class = classes.unwrap_or_else(|| "".to_owned()); let class = classes.unwrap_or_else(|| "".to_owned());
let attrs = IconAttrs::default() let attrs = IconAttrs::default().class(&class).fill("none").stroke_color("currentColor");
.class(&class)
.fill("none")
.stroke_color("currentColor");
free_icons::heroicons(&name, true, attrs).ok_or(Error::new( free_icons::heroicons(&name, true, attrs)
ErrorKind::TemplateNotFound, .ok_or(Error::new(ErrorKind::TemplateNotFound, "cannot find template for requested icon"))
"cannot find template for requested icon",
))
} }

View file

@ -8,9 +8,6 @@ pub struct ValidationError {
impl ValidationError { impl ValidationError {
pub fn on(field: &str, message: &str) -> ValidationError { pub fn on(field: &str, message: &str) -> ValidationError {
ValidationError { ValidationError { field: field.to_owned(), message: message.to_owned() }
field: field.to_owned(),
message: message.to_owned(),
}
} }
} }