diff --git a/Cargo.lock b/Cargo.lock index b96a0ca..56ccb4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1293,6 +1293,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -1503,6 +1504,16 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -2528,6 +2539,7 @@ dependencies = [ "optional_optional_user", "password-hash", "rand", + "rand_distr", "serde", "serde_test", "sqlx", diff --git a/Cargo.toml b/Cargo.toml index 43a7758..c218213 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } ulid = { version = "1", features = ["rand"] } unicode-segmentation = "1" +rand_distr = "0.4.3" [dev-dependencies] axum-test = "9.0.0" diff --git a/src/bin/import_omega.rs b/src/bin/import_omega.rs index 7eaf2c1..268cb3d 100644 --- a/src/bin/import_omega.rs +++ b/src/bin/import_omega.rs @@ -2,12 +2,7 @@ use std::{ffi::OsString, time::Duration}; use clap::Parser; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; -use witch_watch::{ - get_db_pool, - import_utils::{add_omega_watches, ImportMovieOmega}, -}; - -const MOVIE_QUERY: &str = "select * from movies order by random() limit 10000"; +use witch_watch::{get_db_pool, import_utils::add_omega_watches}; #[derive(Debug, Parser)] struct Cli { @@ -29,10 +24,5 @@ async fn main() { let ww_db = get_db_pool().await; - let movies: Vec = sqlx::query_as(MOVIE_QUERY) - .fetch_all(&movie_db) - .await - .unwrap(); - - add_omega_watches(&ww_db, movies).await; + add_omega_watches(&ww_db, &movie_db).await; } diff --git a/src/bin/import_users.rs b/src/bin/import_users.rs index be6f968..f2290fe 100644 --- a/src/bin/import_users.rs +++ b/src/bin/import_users.rs @@ -1,22 +1,20 @@ -use std::{ffi::OsString, time::Duration}; +use std::{collections::BTreeSet, ffi::OsString, time::Duration}; use clap::Parser; use rand::{thread_rng, Rng}; +use rand_distr::Normal; use sqlx::{ sqlite::{SqliteConnectOptions, SqlitePoolOptions}, - FromRow, SqlitePool, + SqlitePool, }; use tokio::task::JoinSet; use tokio_retry::Retry; -use tokio_stream::{Stream, StreamExt}; use witch_watch::{ get_db_pool, - import_utils::{add_user, add_watch_user, ImportMovieOmega}, - DbId, User, Watch, WatchQuest, + import_utils::{add_omega_watches, add_user, add_watch_quest}, + DbId, WatchQuest, }; -const MOVIE_QUERY: &str = "select * from movies order by random() limit ?"; - #[derive(Debug, Parser)] struct Cli { /// path to the movie database @@ -42,7 +40,7 @@ async fn main() { let cli = Cli::parse(); let path = cli.db_path; let num_users = cli.users; - let mpu = cli.movies_per_user; + let mpu = cli.movies_per_user as f32; let dict = if let Some(dict) = cli.words { dict } else { @@ -58,46 +56,47 @@ async fn main() { .connect_with(opts) .await .expect("could not open movies db"); - let ww_db = get_db_pool().await; + let users = &gen_users(num_users, &words, &ww_db).await; + let movies = &add_omega_watches(&ww_db, &movie_db).await; + + let normal = Normal::new(mpu, mpu / 10.0).unwrap(); let rng = &mut thread_rng(); + for user in users { + let mut joinset = JoinSet::new(); - let users = gen_users(num_users, &words, &ww_db).await; + let mut mset = BTreeSet::new(); + let num_movies = rng.sample(normal) as usize; - for _ in 0..num_users { - let mut movies = sqlx::query(MOVIE_QUERY).bind(mpu).fetch(&movie_db); - - let mut set = JoinSet::new(); + while mset.len() < num_movies { + let idx = rng.gen_range(0..10_000usize); + mset.insert(idx); + } + dbg!("done with mset pop"); let retry_strategy = tokio_retry::strategy::ExponentialBackoff::from_millis(100) .map(tokio_retry::strategy::jitter) .take(4); - while let Ok(Some(movie)) = movies.try_next().await { - let movie = ImportMovieOmega::from_row(&movie).unwrap(); - let mut watch: Watch = movie.into(); - let db = ww_db.clone(); - let retry_strategy = retry_strategy.clone(); - - let user = rng.gen_range(0..num_users); - let user = users[user]; - watch.added_by = user; - + for movie in mset.iter() { + let movie = movies[*movie]; let quest = WatchQuest { id: DbId::new(), - user, - watch: watch.id, + user: *user, + watch: movie, is_public: true, already_watched: false, }; + let retry_strategy = retry_strategy.clone(); + let db = ww_db.clone(); let key = quest.id.as_string(); - set.spawn(async move { + joinset.spawn(async move { ( key, Retry::spawn(retry_strategy, || async { - add_watch_user(&db, &watch, quest).await + add_watch_quest(&db, quest).await }) .await, ) @@ -105,7 +104,7 @@ async fn main() { } // stragglers - while (set.join_next().await).is_some() {} + while (joinset.join_next().await).is_some() {} } } diff --git a/src/import_utils.rs b/src/import_utils.rs index 884d3f8..730d2f7 100644 --- a/src/import_utils.rs +++ b/src/import_utils.rs @@ -1,14 +1,20 @@ +use std::sync::{Arc, Mutex}; + use sqlx::{query, query_scalar, SqlitePool}; use tokio::task::JoinSet; use tokio_retry::Retry; use crate::{ - db_id::DbId, util::year_to_epoch, watches::handlers::add_new_watch_impl, ShowKind, Watch, - WatchQuest, + db_id::DbId, + util::year_to_epoch, + watches::handlers::{add_new_watch_impl, add_watch_quest_impl}, + ShowKind, Watch, WatchQuest, }; const USER_EXISTS_QUERY: &str = "select count(*) from witches where id = $1"; +const MOVIE_QUERY: &str = "select * from movies order by random() limit 10000"; + //-************************************************************************ // the omega user is the system ID, but has no actual power in the app //-************************************************************************ @@ -52,30 +58,21 @@ impl From<&ImportMovieOmega> for Watch { //-************************************************************************ // utility functions for building CLI tools, currently just for benchmarking //-************************************************************************ -pub async fn add_watch_omega(db_pool: &SqlitePool, movie: &ImportMovieOmega) -> Result<(), ()> { +pub async fn add_watch_omega(db_pool: &SqlitePool, movie: &ImportMovieOmega) -> Result { let watch: Watch = movie.into(); if add_new_watch_impl(db_pool, &watch, None).await.is_ok() { - println!("{}", watch.id); - Ok(()) + Ok(watch.id) } else { eprintln!("failed to add \"{}\"", watch.title); Err(()) } } -pub async fn add_watch_user( - db_pool: &SqlitePool, - watch: &Watch, - quest: WatchQuest, -) -> Result<(), ()> { - if add_new_watch_impl(db_pool, watch, Some(quest)) - .await - .is_ok() - { - println!("{}", watch.id); +pub async fn add_watch_quest(db_pool: &SqlitePool, quest: WatchQuest) -> Result<(), ()> { + if add_watch_quest_impl(db_pool, &quest).await.is_ok() { Ok(()) } else { - eprintln!("failed to add \"{}\"", watch.title); + eprintln!("failed to add {}", quest.id); Err(()) } } @@ -105,10 +102,17 @@ pub async fn add_user( } } -pub async fn add_omega_watches(ww_db: &SqlitePool, movies: Vec) { +pub async fn add_omega_watches(ww_db: &SqlitePool, movie_db: &SqlitePool) -> Vec { ensure_omega(ww_db).await; + let movies: Vec = sqlx::query_as(MOVIE_QUERY) + .fetch_all(movie_db) + .await + .unwrap(); + let mut set = JoinSet::new(); + let movie_set = Vec::with_capacity(10_000); + let movie_set = Arc::new(Mutex::new(movie_set)); let retry_strategy = tokio_retry::strategy::ExponentialBackoff::from_millis(100) .map(tokio_retry::strategy::jitter) @@ -120,13 +124,20 @@ pub async fn add_omega_watches(ww_db: &SqlitePool, movies: Vec let year = movie.year.clone().unwrap(); let len = movie.length.clone().unwrap(); let retry_strategy = retry_strategy.clone(); + let movie_set = movie_set.clone(); let key = format!("{title}{year}{len}"); set.spawn(async move { ( key, Retry::spawn(retry_strategy, || async { - add_watch_omega(&db, &movie).await + if let Ok(id) = add_watch_omega(&db, &movie).await { + let mut mset = movie_set.lock().unwrap(); + mset.push(id); + Ok(()) + } else { + Err(()) + } }) .await, ) @@ -134,6 +145,8 @@ pub async fn add_omega_watches(ww_db: &SqlitePool, movies: Vec } // stragglers while (set.join_next().await).is_some() {} + let movies = movie_set.lock().unwrap().clone(); + movies } pub async fn ensure_omega(db_pool: &SqlitePool) -> DbId { diff --git a/src/lib.rs b/src/lib.rs index d658431..f1b65b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,8 +42,8 @@ pub async fn app(db_pool: sqlx::SqlitePool, session_secret: &[u8]) -> axum::Rout use login::{get_login, get_logout, post_login, post_logout}; use signup::{get_create_user, get_signup_success, post_create_user}; use watches::handlers::{ - get_add_new_watch, get_search_watch, get_watch, get_watches, post_add_existing_watch, - post_add_new_watch, + get_add_new_watch, get_search_watch, get_watch, get_watches, post_add_new_watch, + post_add_watch_quest, }; axum::Router::new() @@ -59,7 +59,7 @@ pub async fn app(db_pool: sqlx::SqlitePool, session_secret: &[u8]) -> axum::Rout .route("/add", get(get_add_new_watch).post(post_add_new_watch)) .route( "/add/watch", - get(get_search_watch).post(post_add_existing_watch), + get(get_search_watch).post(post_add_watch_quest), ) .fallback(handle_slash_redir) .layer(middleware::from_fn_with_state( diff --git a/src/watches/handlers.rs b/src/watches/handlers.rs index f1fb141..28a38ef 100644 --- a/src/watches/handlers.rs +++ b/src/watches/handlers.rs @@ -194,11 +194,27 @@ pub(crate) async fn add_new_watch_impl( } /// Add a Watch to your watchlist by selecting it with a checkbox -pub async fn post_add_existing_watch( +pub async fn post_add_watch_quest( _auth: AuthContext, State(_pool): State, Form(_form): Form, ) -> impl IntoResponse { + todo!() +} + +pub async fn add_watch_quest_impl(pool: &SqlitePool, quest: &WatchQuest) -> Result<(), ()> { + query(ADD_WITCH_WATCH_QUERY) + .bind(quest.id) + .bind(quest.user) + .bind(quest.watch) + .bind(quest.is_public) + .bind(quest.already_watched) + .execute(pool) + .await + .map_err(|err| { + tracing::error!("Got error: {err}"); + })?; + Ok(()) } /// A single Watch