diff --git a/Cargo.lock b/Cargo.lock index 8fae242..56ccb4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1293,6 +1293,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -1503,6 +1504,16 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -2527,7 +2538,8 @@ dependencies = [ "justerror", "optional_optional_user", "password-hash", - "rand_core", + "rand", + "rand_distr", "serde", "serde_test", "sqlx", diff --git a/Cargo.toml b/Cargo.toml index 28da39b..c218213 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,33 +5,35 @@ edition = "2021" default-run = "witch_watch" [dependencies] -axum = { version = "0.6", features = ["macros", "headers"] } +# local proc macro +optional_optional_user = {path = "optional_optional_user"} + +# regular external deps +argon2 = "0.5" askama = { version = "0.12", features = ["with-axum"] } askama_axum = "0.3" -axum-macros = "0.3" -tokio = { version = "1", features = ["full", "tracing"], default-features = false } -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -tower = { version = "0.4", features = ["util", "timeout"], default-features = false } -tower-http = { version = "0.4", features = ["add-extension", "trace"] } -serde = { version = "1", features = ["derive"] } -sqlx = { version = "0.6", default-features = false, features = ["runtime-tokio-rustls", "any", "sqlite", "chrono", "time"] } -argon2 = "0.5" -rand_core = { version = "0.6", features = ["getrandom"] } -thiserror = "1" -justerror = "1" -password-hash = { version = "0.5", features = ["std", "getrandom"] } -axum-login = { version = "0.5", features = ["sqlite", "sqlx"] } -unicode-segmentation = "1" async-session = "3" -ulid = { version = "1", features = ["rand"] } - -# proc macros: -optional_optional_user = {path = "optional_optional_user"} +axum = { version = "0.6", features = ["macros", "headers"] } +axum-login = { version = "0.5", features = ["sqlite", "sqlx"] } +axum-macros = "0.3" chrono = { version = "0.4", default-features = false, features = ["std", "clock"] } clap = { version = "4.3.10", features = ["derive", "env", "unicode", "suggestions", "usage"] } +justerror = "1" +password-hash = { version = "0.5", features = ["std", "getrandom"] } +rand = "0.8" +serde = { version = "1", features = ["derive"] } +sqlx = { version = "0.6", default-features = false, features = ["runtime-tokio-rustls", "any", "sqlite", "chrono", "time"] } +thiserror = "1" +tokio = { version = "1", features = ["full", "tracing"], default-features = false } tokio-retry = "0.3.0" tokio-stream = "0.1.14" +tower = { version = "0.4", features = ["util", "timeout"], default-features = false } +tower-http = { version = "0.4", features = ["add-extension", "trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +ulid = { version = "1", features = ["rand"] } +unicode-segmentation = "1" +rand_distr = "0.4.3" [dev-dependencies] axum-test = "9.0.0" diff --git a/results.txt b/results.txt index d3171b2..c8d9d82 100644 --- a/results.txt +++ b/results.txt @@ -1,5 +1,3 @@ --rw-r--r-- 1 ardent ardent 1.6M Jul 4 12:27 .witch-watch.db --rw-r--r-- 1 ardent ardent 161K Jul 4 12:29 .witch-watch.db-wal --rw-r--r-- 1 ardent ardent 32K Jul 4 12:29 .witch-watch.db-shm +-rw-r--r-- 1 ardent ardent 17M Jul 6 10:05 /home/ardent/.witch-watch.db -4 seconds wall to add 10k movies, added by the omega user. +6 seconds to add 98,713 watch quests. diff --git a/src/bin/import_omega.rs b/src/bin/import_omega.rs index 893ca7b..268cb3d 100644 --- a/src/bin/import_omega.rs +++ b/src/bin/import_omega.rs @@ -1,16 +1,8 @@ -use std::{ffi::OsString, pin::Pin, time::Duration}; +use std::{ffi::OsString, time::Duration}; use clap::Parser; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; -use tokio::task::JoinSet; -use tokio_retry::Retry; -use tokio_stream::{Stream, StreamExt}; -use witch_watch::{ - get_db_pool, - import_utils::{add_watch_omega, ensure_omega, ImportMovieOmega}, -}; - -const MOVIE_QUERY: &str = "select * from movies order by random() limit 10000"; +use witch_watch::{get_db_pool, import_utils::add_omega_watches}; #[derive(Debug, Parser)] struct Cli { @@ -32,35 +24,5 @@ async fn main() { let ww_db = get_db_pool().await; - let mut movies: Pin> + Send>> = - sqlx::query_as(MOVIE_QUERY).fetch(&movie_db); - - ensure_omega(&ww_db).await; - - let mut set = JoinSet::new(); - - let retry_strategy = tokio_retry::strategy::ExponentialBackoff::from_millis(100) - .map(tokio_retry::strategy::jitter) - .take(4); - - while let Ok(Some(movie)) = movies.try_next().await { - let db = ww_db.clone(); - let title = movie.title.as_str(); - let year = movie.year.clone().unwrap(); - let len = movie.length.clone().unwrap(); - let retry_strategy = retry_strategy.clone(); - - let key = format!("{title}{year}{len}"); - set.spawn(async move { - ( - key, - Retry::spawn(retry_strategy, || async { - add_watch_omega(&db, &movie).await - }) - .await, - ) - }); - } - // stragglers - while (set.join_next().await).is_some() {} + add_omega_watches(&ww_db, &movie_db).await; } diff --git a/src/bin/import_users.rs b/src/bin/import_users.rs new file mode 100644 index 0000000..20a374d --- /dev/null +++ b/src/bin/import_users.rs @@ -0,0 +1,152 @@ +use std::{ffi::OsString, time::Duration}; + +use clap::Parser; +use rand::{rngs::ThreadRng, seq::SliceRandom, thread_rng, Rng}; +use rand_distr::Normal; +use sqlx::{ + sqlite::{SqliteConnectOptions, SqlitePoolOptions}, + SqlitePool, +}; +use tokio::task::JoinSet; +use tokio_retry::Retry; +use witch_watch::{ + get_db_pool, + import_utils::{add_omega_watches, add_user, add_watch_quests}, + DbId, WatchQuest, +}; + +#[derive(Debug, Parser)] +struct Cli { + /// path to the movie database + #[clap(long = "database", short)] + pub db_path: OsString, + + /// number of users to create + #[clap(long, short, default_value_t = 1000)] + pub users: usize, + + /// expected gaussian value for number of movies per use + #[clap(long = "movies", short, default_value_t = 100)] + pub movies_per_user: u32, + + /// path to the dictionary to be used for usernames [default: + /// /usr/share/dict/words] + #[clap(long, short)] + pub words: Option, +} + +#[tokio::main] +async fn main() { + let cli = Cli::parse(); + let path = cli.db_path; + let num_users = cli.users; + let mpu = cli.movies_per_user as f32; + let dict = if let Some(dict) = cli.words { + dict + } else { + "/usr/share/dict/words".into() + }; + + let words = std::fs::read_to_string(dict).expect("tried to open {dict:?}"); + let words: Vec<&str> = words.split('\n').collect(); + + let opts = SqliteConnectOptions::new().filename(&path).read_only(true); + let movie_db = SqlitePoolOptions::new() + .idle_timeout(Duration::from_secs(90)) + .connect_with(opts) + .await + .expect("could not open movies db"); + let ww_db = get_db_pool().await; + + let users = &gen_users(num_users, &words, &ww_db).await; + let movies = &add_omega_watches(&ww_db, &movie_db).await; + + let rng = &mut thread_rng(); + + let normal = Normal::new(mpu, mpu / 10.0).unwrap(); + for &user in users { + add_quests(user, movies, &ww_db, rng, normal).await; + } +} + +//-************************************************************************ +// add the users +//-************************************************************************ +async fn gen_users(num: usize, words: &[&str], pool: &SqlitePool) -> Vec { + let mut rng = thread_rng(); + let rng = &mut rng; + let range = 0usize..(words.len()); + let mut users = Vec::with_capacity(num); + for _ in 0..num { + let n1 = rng.gen_range(range.clone()); + let n2 = rng.gen_range(range.clone()); + let n3 = rng.gen_range(range.clone()); + let nn = rng.gen_range(0..200); + + let n1 = words[n1].replace('\'', ""); + let n2 = words[n2].replace('\'', ""); + let email_domain = words[n3].replace('\'', ""); + + let username = format!("{n1}_{n2}{nn}"); + let displayname = Some(format!("{n1} {n2}")); + let email = Some(format!("{username}@{email_domain}")); + let id = add_user( + pool, + &username, + displayname.as_deref(), + email.as_deref(), + None, + ) + .await; + users.push(id); + } + + users +} + +//-************************************************************************ +// batch add quests +//-************************************************************************ +async fn add_quests( + user: DbId, + movies: &[DbId], + ww_db: &SqlitePool, + rng: &mut ThreadRng, + normal: Normal, +) { + let mut tasks = JoinSet::new(); + let num_movies = rng.sample(normal) as usize; + let quests: Vec = movies + .choose_multiple(rng, num_movies) + .cloned() + .map(|movie| { + let id = DbId::new(); + WatchQuest { + id, + user, + watch: movie, + is_public: true, + already_watched: false, + } + }) + .collect(); + + let retry_strategy = tokio_retry::strategy::ExponentialBackoff::from_millis(100) + .map(tokio_retry::strategy::jitter) + .take(4); + + let db = ww_db.clone(); + tasks.spawn(async move { + let movies = quests; + ( + user, + Retry::spawn(retry_strategy, || async { + add_watch_quests(&db, &movies).await + }) + .await, + ) + }); + + // get the stragglers + while (tasks.join_next().await).is_some() {} +} diff --git a/src/db.rs b/src/db.rs index 926f3c4..f2ce37d 100644 --- a/src/db.rs +++ b/src/db.rs @@ -30,8 +30,8 @@ pub async fn get_db_pool() -> SqlitePool { } #[cfg(test)] { - use rand_core::RngCore; - let mut rng = rand_core::OsRng; + use rand::RngCore; + let mut rng = rand::thread_rng(); let id = rng.next_u64(); // see https://www.sqlite.org/inmemorydb.html for meaning of the string; // it allows each separate test to have its own dedicated memory-backed db that diff --git a/src/import_utils.rs b/src/import_utils.rs index d10c018..30a7b2c 100644 --- a/src/import_utils.rs +++ b/src/import_utils.rs @@ -1,11 +1,23 @@ +use std::sync::Arc; + use sqlx::{query, query_scalar, SqlitePool}; +use tokio::task::JoinSet; +use tokio_retry::Retry; use crate::{ - db_id::DbId, util::year_to_epoch, watches::handlers::add_new_watch_impl, ShowKind, Watch, + db_id::DbId, + util::year_to_epoch, + watches::handlers::{add_new_watch_impl, add_watch_quest_impl}, + ShowKind, Watch, WatchQuest, }; const USER_EXISTS_QUERY: &str = "select count(*) from witches where id = $1"; +const MOVIE_QUERY: &str = "select * from movies order by random() limit 10000"; + +//-************************************************************************ +// the omega user is the system ID, but has no actual power in the app +//-************************************************************************ const OMEGA_ID: u128 = u128::MAX; #[derive(Debug, sqlx::FromRow, Clone)] @@ -46,24 +58,47 @@ impl From<&ImportMovieOmega> for Watch { //-************************************************************************ // utility functions for building CLI tools, currently just for benchmarking //-************************************************************************ -pub async fn add_watch_omega(db_pool: &SqlitePool, movie: &ImportMovieOmega) -> Result<(), ()> { +pub async fn add_watch_omega(db_pool: &SqlitePool, movie: &ImportMovieOmega) -> Result { let watch: Watch = movie.into(); if add_new_watch_impl(db_pool, &watch, None).await.is_ok() { - println!("{}", watch.id); - Ok(()) + Ok(watch.id) } else { eprintln!("failed to add \"{}\"", watch.title); Err(()) } } +pub async fn add_watch_quest(db_pool: &SqlitePool, quest: WatchQuest) -> Result<(), ()> { + if add_watch_quest_impl(db_pool, &quest).await.is_ok() { + Ok(()) + } else { + eprintln!("failed to add {}", quest.id); + Err(()) + } +} + +pub async fn add_watch_quests(pool: &SqlitePool, quests: &[WatchQuest]) -> Result<(), ()> { + let mut builder = + sqlx::QueryBuilder::new("insert into witch_watch (id, witch, watch, public, watched) "); + builder.push_values(quests.iter(), |mut b, quest| { + b.push_bind(quest.id) + .push_bind(quest.user) + .push_bind(quest.watch) + .push_bind(quest.is_public) + .push_bind(quest.already_watched); + }); + let q = builder.build(); + q.execute(pool).await.map_err(|_| ())?; + Ok(()) +} + pub async fn add_user( db_pool: &SqlitePool, username: &str, displayname: Option<&str>, email: Option<&str>, id: Option, -) { +) -> DbId { let pwhash = "you shall not password"; let id: DbId = id.unwrap_or_else(DbId::new); if query(crate::signup::CREATE_QUERY) @@ -74,12 +109,58 @@ pub async fn add_user( .bind(pwhash) .execute(db_pool) .await - .is_ok() + .is_err() { - println!("{id}"); - } else { eprintln!("failed to add user \"{username}\""); } + id +} + +pub async fn add_omega_watches(ww_db: &SqlitePool, movie_db: &SqlitePool) -> Vec { + ensure_omega(ww_db).await; + + let movies: Vec = sqlx::query_as(MOVIE_QUERY) + .fetch_all(movie_db) + .await + .unwrap(); + + let mut set = JoinSet::new(); + let movie_set = Vec::with_capacity(10_000); + let movie_set = Arc::new(std::sync::Mutex::new(movie_set)); + + let retry_strategy = tokio_retry::strategy::ExponentialBackoff::from_millis(100) + .map(tokio_retry::strategy::jitter) + .take(4); + + for movie in movies { + let db = ww_db.clone(); + let title = movie.title.as_str(); + let year = movie.year.clone().unwrap(); + let len = movie.length.clone().unwrap(); + let retry_strategy = retry_strategy.clone(); + let movie_set = movie_set.clone(); + + let key = format!("{title}{year}{len}"); + set.spawn(async move { + ( + key, + Retry::spawn(retry_strategy, || async { + if let Ok(id) = add_watch_omega(&db, &movie).await { + let mut mset = movie_set.lock().unwrap(); + mset.push(id); + Ok(()) + } else { + Err(()) + } + }) + .await, + ) + }); + } + // stragglers + while (set.join_next().await).is_some() {} + let movies = movie_set.lock().unwrap().clone(); + movies } pub async fn ensure_omega(db_pool: &SqlitePool) -> DbId { @@ -91,7 +172,7 @@ pub async fn ensure_omega(db_pool: &SqlitePool) -> DbId { None, Some(OMEGA_ID.into()), ) - .await + .await; } OMEGA_ID.into() } diff --git a/src/lib.rs b/src/lib.rs index d658431..f1b65b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,8 +42,8 @@ pub async fn app(db_pool: sqlx::SqlitePool, session_secret: &[u8]) -> axum::Rout use login::{get_login, get_logout, post_login, post_logout}; use signup::{get_create_user, get_signup_success, post_create_user}; use watches::handlers::{ - get_add_new_watch, get_search_watch, get_watch, get_watches, post_add_existing_watch, - post_add_new_watch, + get_add_new_watch, get_search_watch, get_watch, get_watches, post_add_new_watch, + post_add_watch_quest, }; axum::Router::new() @@ -59,7 +59,7 @@ pub async fn app(db_pool: sqlx::SqlitePool, session_secret: &[u8]) -> axum::Rout .route("/add", get(get_add_new_watch).post(post_add_new_watch)) .route( "/add/watch", - get(get_search_watch).post(post_add_existing_watch), + get(get_search_watch).post(post_add_watch_quest), ) .fallback(handle_slash_redir) .layer(middleware::from_fn_with_state( diff --git a/src/main.rs b/src/main.rs index d9c1062..4c07a0f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; -use rand_core::{OsRng, RngCore}; +use rand::{thread_rng, RngCore}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use witch_watch::get_db_pool; @@ -18,7 +18,7 @@ async fn main() { let secret = { let mut bytes = [0u8; 64]; - let mut rng = OsRng; + let mut rng = thread_rng(); rng.fill_bytes(&mut bytes); bytes }; diff --git a/src/watches/handlers.rs b/src/watches/handlers.rs index f1fb141..28a38ef 100644 --- a/src/watches/handlers.rs +++ b/src/watches/handlers.rs @@ -194,11 +194,27 @@ pub(crate) async fn add_new_watch_impl( } /// Add a Watch to your watchlist by selecting it with a checkbox -pub async fn post_add_existing_watch( +pub async fn post_add_watch_quest( _auth: AuthContext, State(_pool): State, Form(_form): Form, ) -> impl IntoResponse { + todo!() +} + +pub async fn add_watch_quest_impl(pool: &SqlitePool, quest: &WatchQuest) -> Result<(), ()> { + query(ADD_WITCH_WATCH_QUERY) + .bind(quest.id) + .bind(quest.user) + .bind(quest.watch) + .bind(quest.is_public) + .bind(quest.already_watched) + .execute(pool) + .await + .map_err(|err| { + tracing::error!("Got error: {err}"); + })?; + Ok(()) } /// A single Watch