diff --git a/src/discovery.rs b/src/discovery.rs index d1eb83d..1584d7c 100644 --- a/src/discovery.rs +++ b/src/discovery.rs @@ -9,7 +9,7 @@ use axum::{ }; use tokio::net::UdpSocket; -use crate::{Config, JoecalState, RunningState, models::device::DeviceInfo}; +use crate::{Config, JoecalState, RunningState, models::Device}; impl JoecalState { pub async fn announce( @@ -48,7 +48,7 @@ impl JoecalState { } async fn process_device(&self, message: &str, src: SocketAddr, config: &Config) { - if let Ok(device) = serde_json::from_str::(message) { + if let Ok(device) = serde_json::from_str::(message) { if device.fingerprint == self.device.fingerprint { return; } @@ -83,8 +83,8 @@ impl JoecalState { pub async fn register_device( State(state): State, ConnectInfo(addr): ConnectInfo, - Json(device): Json, -) -> Json { + Json(device): Json, +) -> Json { let mut addr = addr; addr.set_port(state.device.port); state @@ -99,9 +99,9 @@ pub async fn register_device( // private helpers //-************************************************************************ async fn announce_http( - device: &DeviceInfo, + device: &Device, ip: Option, - client: Arc, + client: reqwest::Client, ) -> crate::error::Result<()> { if let Some(ip) = ip { let url = format!("http://{ip}/api/localsend/v2/register"); @@ -111,7 +111,7 @@ async fn announce_http( } async fn announce_multicast( - device: &DeviceInfo, + device: &Device, addr: SocketAddrV4, socket: Arc, ) -> crate::error::Result<()> { diff --git a/src/http_server.rs b/src/http_server.rs index e89a50b..0fdc72c 100644 --- a/src/http_server.rs +++ b/src/http_server.rs @@ -5,7 +5,7 @@ use axum::{ extract::DefaultBodyLimit, routing::{get, post}, }; -use tokio::net::TcpListener; +use tokio::{net::TcpListener, sync::mpsc}; use tower_http::limit::RequestBodyLimitLayer; use crate::{ @@ -15,8 +15,12 @@ use crate::{ }; impl JoecalState { - pub async fn start_http_server(&self, config: &Config) -> crate::error::Result<()> { - let app = self.create_router(&config); + pub async fn start_http_server( + &self, + stop_rx: mpsc::Receiver<()>, + config: &Config, + ) -> crate::error::Result<()> { + let app = self.create_router(config); let addr = SocketAddr::from(([0, 0, 0, 0], config.port)); let listener = TcpListener::bind(&addr).await?; @@ -26,6 +30,7 @@ impl JoecalState { listener, app.into_make_service_with_connect_info::(), ) + .with_graceful_shutdown(shutdown(stop_rx)) .await?; Ok(()) } @@ -52,3 +57,8 @@ impl JoecalState { .with_state(self.clone()) } } + +async fn shutdown(mut rx: mpsc::Receiver<()>) { + println!("shutting down"); + rx.recv().await.unwrap_or_default() +} diff --git a/src/lib.rs b/src/lib.rs index 117eda6..2bfa2d0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,11 +8,16 @@ use std::{ collections::HashMap, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, sync::Arc, + time::Duration, }; -use models::device::DeviceInfo; +use models::Device; use serde::{Deserialize, Serialize}; -use tokio::{net::UdpSocket, sync::Mutex, task::JoinHandle}; +use tokio::{ + net::UdpSocket, + sync::{Mutex, mpsc}, + task::JoinHandle, +}; use transfer::Session; pub const DEFAULT_PORT: u16 = 53317; @@ -20,18 +25,20 @@ pub const MULTICAST_IP: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 167); pub const LISTENING_SOCKET_ADDR: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::from_bits(0), DEFAULT_PORT); +/// Contains the main network and application state for an application session. #[derive(Clone)] pub struct JoecalState { - pub device: DeviceInfo, - pub peers: Arc>>, + pub device: Device, + pub peers: Arc>>, pub sessions: Arc>>, // Session ID to Session pub running_state: Arc>, pub socket: Arc, - pub client: Arc, + pub client: reqwest::Client, + stop_tx: std::sync::OnceLock>, } impl JoecalState { - pub async fn new(device: DeviceInfo) -> crate::error::Result { + pub async fn new(device: Device) -> crate::error::Result { let socket = UdpSocket::bind(LISTENING_SOCKET_ADDR).await?; socket.set_multicast_loop_v4(true)?; socket.set_multicast_ttl_v4(2)?; // one hop out from localnet @@ -43,45 +50,11 @@ impl JoecalState { sessions: Default::default(), running_state: Default::default(), socket: socket.into(), - client: reqwest::Client::new().into(), + client: reqwest::Client::new(), + stop_tx: Default::default(), }) } -} -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum RunningState { - Running, - Sending, - Receiving, - Stopping, -} - -impl Default for RunningState { - fn default() -> Self { - Self::Running - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Config { - pub multicast_addr: SocketAddrV4, - pub port: u16, - pub download_dir: String, -} - -impl Default for Config { - fn default() -> Self { - let home = std::env::home_dir().unwrap_or("/tmp".into()); - let dd = home.join("joecalsend-downloads"); - Self { - multicast_addr: SocketAddrV4::new(MULTICAST_IP, DEFAULT_PORT), - port: DEFAULT_PORT, - download_dir: dd.to_string_lossy().into(), - } - } -} - -impl JoecalState { pub async fn start( &self, config: &Config, @@ -89,8 +62,10 @@ impl JoecalState { let state = self.clone(); let konfig = config.clone(); let server_handle = { + let (tx, rx) = mpsc::channel(1); + self.stop_tx.get_or_init(|| tx); tokio::spawn(async move { - if let Err(e) = state.start_http_server(&konfig).await { + if let Err(e) = state.start_http_server(rx, &konfig).await { eprintln!("HTTP server error: {e}"); } }) @@ -126,8 +101,59 @@ impl JoecalState { Ok((server_handle, udp_handle, announcement_handle)) } + pub async fn stop(&self) { + loop { + if let Ok(mut lock) = self.running_state.try_lock() { + *lock = RunningState::Stopping; + if self + .stop_tx + .get() + .expect("Could not get stop signal transmitter") + .send(()) + .await + .is_ok() + { + break; + } + } + } + } + pub async fn refresh_peers(&self) { let mut peers = self.peers.lock().await; peers.clear(); } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RunningState { + Running, + Sending, + Receiving, + Stopping, +} + +impl Default for RunningState { + fn default() -> Self { + Self::Running + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + pub multicast_addr: SocketAddrV4, + pub port: u16, + pub download_dir: String, +} + +impl Default for Config { + fn default() -> Self { + let home = std::env::home_dir().unwrap_or("/tmp".into()); + let dd = home.join("joecalsend-downloads"); + Self { + multicast_addr: SocketAddrV4::new(MULTICAST_IP, DEFAULT_PORT), + port: DEFAULT_PORT, + download_dir: dd.to_string_lossy().into(), + } + } +} diff --git a/src/main.rs b/src/main.rs index ff62d58..0be89e3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ use std::io; -use joecalsend::{Config, JoecalState, RunningState, error, models::device::DeviceInfo}; +use joecalsend::{Config, JoecalState, RunningState, error, models::Device}; use local_ip_address::local_ip; use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig, V4IfAddr}; use ratatui::{ @@ -16,7 +16,7 @@ use ratatui::{ #[tokio::main] async fn main() -> error::Result<()> { - let device = DeviceInfo::default(); + let device = Device::default(); dbg!(&device); let std::net::IpAddr::V4(ip) = local_ip()? else { @@ -48,9 +48,9 @@ async fn main() -> error::Result<()> { let mut app = App::new(state.clone()); let mut terminal = ratatui::init(); - let result = app.run(&mut terminal); + let result = app.run(&mut terminal).await; ratatui::restore(); - //let _ = tokio::join!(h1, h2, h3); + let _ = tokio::join!(h1, h2, h3); Ok(result?) } @@ -64,10 +64,10 @@ impl App { App { state } } - pub fn run(&mut self, terminal: &mut DefaultTerminal) -> io::Result<()> { + pub async fn run(&mut self, terminal: &mut DefaultTerminal) -> io::Result<()> { loop { terminal.draw(|frame| self.draw(frame))?; - self.handle_events()?; + self.handle_events().await?; if let Ok(lock) = self.state.running_state.try_lock() && *lock == RunningState::Stopping { @@ -81,21 +81,21 @@ impl App { frame.render_widget(self, frame.area()); } - fn handle_events(&mut self) -> io::Result<()> { + async fn handle_events(&mut self) -> io::Result<()> { match event::read()? { // it's important to check that the event is a key press event as // crossterm also emits key release and repeat events on Windows. Event::Key(key_event) if key_event.kind == KeyEventKind::Press => { - self.handle_key_event(key_event) + self.handle_key_event(key_event).await } _ => {} }; Ok(()) } - fn handle_key_event(&mut self, key_event: KeyEvent) { + async fn handle_key_event(&mut self, key_event: KeyEvent) { match key_event.code { - KeyCode::Char('q') => self.exit(), + KeyCode::Char('q') => self.exit().await, KeyCode::Char('s') => {} KeyCode::Char('r') => {} KeyCode::Char('d') => {} @@ -103,13 +103,8 @@ impl App { } } - fn exit(&mut self) { - loop { - if let Ok(mut lock) = self.state.running_state.try_lock() { - *lock = RunningState::Stopping; - break; - } - } + async fn exit(&self) { + self.state.stop().await; } } diff --git a/src/models/file.rs b/src/models.rs similarity index 59% rename from src/models/file.rs rename to src/models.rs index 9be7818..0e80b5c 100644 --- a/src/models/file.rs +++ b/src/models.rs @@ -1,6 +1,7 @@ use std::{path::Path, time::SystemTime}; use chrono::{DateTime, Utc}; +use julid::Julid; use serde::{Deserialize, Serialize}; use crate::error::LocalSendError; @@ -63,6 +64,64 @@ impl FileMetadata { } } +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum DeviceType { + Mobile, + Desktop, + Web, + Headless, + Server, + Unknown, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct Device { + pub alias: String, + pub version: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub device_model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub device_type: Option, + pub fingerprint: String, + pub port: u16, + pub protocol: String, + #[serde(default)] + pub download: bool, + #[serde(default)] + pub announce: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum Protocol { + Http, + Https, +} + +impl Default for Device { + fn default() -> Self { + Self { + alias: "RustSend".to_string(), + version: "2.1".to_string(), + device_model: None, + device_type: Some(DeviceType::Headless), + fingerprint: Julid::new().to_string(), + port: 53317, + protocol: "http".to_string(), + download: false, + announce: Some(true), + } + } +} + +impl Device { + pub fn to_json(&self) -> crate::error::Result { + Ok(serde_json::to_string(self)?) + } +} + fn format_datetime(system_time: SystemTime) -> String { let datetime: DateTime = system_time.into(); datetime.to_rfc3339() diff --git a/src/models/device.rs b/src/models/device.rs deleted file mode 100644 index f728f6b..0000000 --- a/src/models/device.rs +++ /dev/null @@ -1,60 +0,0 @@ -use julid::Julid; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize, Clone)] -#[serde(rename_all = "lowercase")] -pub enum DeviceType { - Mobile, - Desktop, - Web, - Headless, - Server, - Unknown, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -#[serde(rename_all = "camelCase")] -pub struct DeviceInfo { - pub alias: String, - pub version: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub device_model: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub device_type: Option, - pub fingerprint: String, - pub port: u16, - pub protocol: String, - #[serde(default)] - pub download: bool, - #[serde(default)] - pub announce: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -#[serde(rename_all = "lowercase")] -pub enum Protocol { - Http, - Https, -} - -impl Default for DeviceInfo { - fn default() -> Self { - Self { - alias: "RustSend".to_string(), - version: "2.1".to_string(), - device_model: None, - device_type: Some(DeviceType::Headless), - fingerprint: Julid::new().to_string(), - port: 53317, - protocol: "http".to_string(), - download: false, - announce: Some(true), - } - } -} - -impl DeviceInfo { - pub fn to_json(&self) -> crate::error::Result { - Ok(serde_json::to_string(self)?) - } -} diff --git a/src/models/mod.rs b/src/models/mod.rs deleted file mode 100644 index 5a9d6eb..0000000 --- a/src/models/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod device; -pub mod file; -pub mod session; diff --git a/src/models/session.rs b/src/models/session.rs deleted file mode 100644 index 8b13789..0000000 --- a/src/models/session.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/transfer.rs b/src/transfer.rs index 9c4f8bc..de88f30 100644 --- a/src/transfer.rs +++ b/src/transfer.rs @@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize}; use crate::{ JoecalState, error::{LocalSendError, Result}, - models::{device::DeviceInfo, file::FileMetadata}, + models::{Device, FileMetadata}, }; #[derive(Deserialize, Serialize)] @@ -22,8 +22,8 @@ pub struct Session { pub session_id: String, pub files: HashMap, pub file_tokens: HashMap, - pub receiver: DeviceInfo, - pub sender: DeviceInfo, + pub receiver: Device, + pub sender: Device, pub status: SessionStatus, pub addr: SocketAddr, } @@ -47,7 +47,7 @@ pub struct PrepareUploadResponse { #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PrepareUploadRequest { - pub info: DeviceInfo, + pub info: Device, pub files: HashMap, }