try to shutdown nicely

This commit is contained in:
Joe Ardent 2025-07-06 16:02:11 -07:00
parent f5af07f860
commit 342b634388
9 changed files with 164 additions and 138 deletions

View file

@ -9,7 +9,7 @@ use axum::{
}; };
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use crate::{Config, JoecalState, RunningState, models::device::DeviceInfo}; use crate::{Config, JoecalState, RunningState, models::Device};
impl JoecalState { impl JoecalState {
pub async fn announce( pub async fn announce(
@ -48,7 +48,7 @@ impl JoecalState {
} }
async fn process_device(&self, message: &str, src: SocketAddr, config: &Config) { async fn process_device(&self, message: &str, src: SocketAddr, config: &Config) {
if let Ok(device) = serde_json::from_str::<DeviceInfo>(message) { if let Ok(device) = serde_json::from_str::<Device>(message) {
if device.fingerprint == self.device.fingerprint { if device.fingerprint == self.device.fingerprint {
return; return;
} }
@ -83,8 +83,8 @@ impl JoecalState {
pub async fn register_device( pub async fn register_device(
State(state): State<JoecalState>, State(state): State<JoecalState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
Json(device): Json<DeviceInfo>, Json(device): Json<Device>,
) -> Json<DeviceInfo> { ) -> Json<Device> {
let mut addr = addr; let mut addr = addr;
addr.set_port(state.device.port); addr.set_port(state.device.port);
state state
@ -99,9 +99,9 @@ pub async fn register_device(
// private helpers // private helpers
//-************************************************************************ //-************************************************************************
async fn announce_http( async fn announce_http(
device: &DeviceInfo, device: &Device,
ip: Option<SocketAddr>, ip: Option<SocketAddr>,
client: Arc<reqwest::Client>, client: reqwest::Client,
) -> crate::error::Result<()> { ) -> crate::error::Result<()> {
if let Some(ip) = ip { if let Some(ip) = ip {
let url = format!("http://{ip}/api/localsend/v2/register"); let url = format!("http://{ip}/api/localsend/v2/register");
@ -111,7 +111,7 @@ async fn announce_http(
} }
async fn announce_multicast( async fn announce_multicast(
device: &DeviceInfo, device: &Device,
addr: SocketAddrV4, addr: SocketAddrV4,
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
) -> crate::error::Result<()> { ) -> crate::error::Result<()> {

View file

@ -5,7 +5,7 @@ use axum::{
extract::DefaultBodyLimit, extract::DefaultBodyLimit,
routing::{get, post}, routing::{get, post},
}; };
use tokio::net::TcpListener; use tokio::{net::TcpListener, sync::mpsc};
use tower_http::limit::RequestBodyLimitLayer; use tower_http::limit::RequestBodyLimitLayer;
use crate::{ use crate::{
@ -15,8 +15,12 @@ use crate::{
}; };
impl JoecalState { impl JoecalState {
pub async fn start_http_server(&self, config: &Config) -> crate::error::Result<()> { pub async fn start_http_server(
let app = self.create_router(&config); &self,
stop_rx: mpsc::Receiver<()>,
config: &Config,
) -> crate::error::Result<()> {
let app = self.create_router(config);
let addr = SocketAddr::from(([0, 0, 0, 0], config.port)); let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
let listener = TcpListener::bind(&addr).await?; let listener = TcpListener::bind(&addr).await?;
@ -26,6 +30,7 @@ impl JoecalState {
listener, listener,
app.into_make_service_with_connect_info::<SocketAddr>(), app.into_make_service_with_connect_info::<SocketAddr>(),
) )
.with_graceful_shutdown(shutdown(stop_rx))
.await?; .await?;
Ok(()) Ok(())
} }
@ -52,3 +57,8 @@ impl JoecalState {
.with_state(self.clone()) .with_state(self.clone())
} }
} }
async fn shutdown(mut rx: mpsc::Receiver<()>) {
println!("shutting down");
rx.recv().await.unwrap_or_default()
}

View file

@ -8,11 +8,16 @@ use std::{
collections::HashMap, collections::HashMap,
net::{Ipv4Addr, SocketAddr, SocketAddrV4}, net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc, sync::Arc,
time::Duration,
}; };
use models::device::DeviceInfo; use models::Device;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::{net::UdpSocket, sync::Mutex, task::JoinHandle}; use tokio::{
net::UdpSocket,
sync::{Mutex, mpsc},
task::JoinHandle,
};
use transfer::Session; use transfer::Session;
pub const DEFAULT_PORT: u16 = 53317; pub const DEFAULT_PORT: u16 = 53317;
@ -20,18 +25,20 @@ pub const MULTICAST_IP: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 167);
pub const LISTENING_SOCKET_ADDR: SocketAddrV4 = pub const LISTENING_SOCKET_ADDR: SocketAddrV4 =
SocketAddrV4::new(Ipv4Addr::from_bits(0), DEFAULT_PORT); SocketAddrV4::new(Ipv4Addr::from_bits(0), DEFAULT_PORT);
/// Contains the main network and application state for an application session.
#[derive(Clone)] #[derive(Clone)]
pub struct JoecalState { pub struct JoecalState {
pub device: DeviceInfo, pub device: Device,
pub peers: Arc<Mutex<HashMap<String, (SocketAddr, DeviceInfo)>>>, pub peers: Arc<Mutex<HashMap<String, (SocketAddr, Device)>>>,
pub sessions: Arc<Mutex<HashMap<String, Session>>>, // Session ID to Session pub sessions: Arc<Mutex<HashMap<String, Session>>>, // Session ID to Session
pub running_state: Arc<Mutex<RunningState>>, pub running_state: Arc<Mutex<RunningState>>,
pub socket: Arc<UdpSocket>, pub socket: Arc<UdpSocket>,
pub client: Arc<reqwest::Client>, pub client: reqwest::Client,
stop_tx: std::sync::OnceLock<mpsc::Sender<()>>,
} }
impl JoecalState { impl JoecalState {
pub async fn new(device: DeviceInfo) -> crate::error::Result<Self> { pub async fn new(device: Device) -> crate::error::Result<Self> {
let socket = UdpSocket::bind(LISTENING_SOCKET_ADDR).await?; let socket = UdpSocket::bind(LISTENING_SOCKET_ADDR).await?;
socket.set_multicast_loop_v4(true)?; socket.set_multicast_loop_v4(true)?;
socket.set_multicast_ttl_v4(2)?; // one hop out from localnet socket.set_multicast_ttl_v4(2)?; // one hop out from localnet
@ -43,45 +50,11 @@ impl JoecalState {
sessions: Default::default(), sessions: Default::default(),
running_state: Default::default(), running_state: Default::default(),
socket: socket.into(), socket: socket.into(),
client: reqwest::Client::new().into(), client: reqwest::Client::new(),
stop_tx: Default::default(),
}) })
} }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunningState {
Running,
Sending,
Receiving,
Stopping,
}
impl Default for RunningState {
fn default() -> Self {
Self::Running
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub multicast_addr: SocketAddrV4,
pub port: u16,
pub download_dir: String,
}
impl Default for Config {
fn default() -> Self {
let home = std::env::home_dir().unwrap_or("/tmp".into());
let dd = home.join("joecalsend-downloads");
Self {
multicast_addr: SocketAddrV4::new(MULTICAST_IP, DEFAULT_PORT),
port: DEFAULT_PORT,
download_dir: dd.to_string_lossy().into(),
}
}
}
impl JoecalState {
pub async fn start( pub async fn start(
&self, &self,
config: &Config, config: &Config,
@ -89,8 +62,10 @@ impl JoecalState {
let state = self.clone(); let state = self.clone();
let konfig = config.clone(); let konfig = config.clone();
let server_handle = { let server_handle = {
let (tx, rx) = mpsc::channel(1);
self.stop_tx.get_or_init(|| tx);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = state.start_http_server(&konfig).await { if let Err(e) = state.start_http_server(rx, &konfig).await {
eprintln!("HTTP server error: {e}"); eprintln!("HTTP server error: {e}");
} }
}) })
@ -126,8 +101,59 @@ impl JoecalState {
Ok((server_handle, udp_handle, announcement_handle)) Ok((server_handle, udp_handle, announcement_handle))
} }
pub async fn stop(&self) {
loop {
if let Ok(mut lock) = self.running_state.try_lock() {
*lock = RunningState::Stopping;
if self
.stop_tx
.get()
.expect("Could not get stop signal transmitter")
.send(())
.await
.is_ok()
{
break;
}
}
}
}
pub async fn refresh_peers(&self) { pub async fn refresh_peers(&self) {
let mut peers = self.peers.lock().await; let mut peers = self.peers.lock().await;
peers.clear(); peers.clear();
} }
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunningState {
Running,
Sending,
Receiving,
Stopping,
}
impl Default for RunningState {
fn default() -> Self {
Self::Running
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub multicast_addr: SocketAddrV4,
pub port: u16,
pub download_dir: String,
}
impl Default for Config {
fn default() -> Self {
let home = std::env::home_dir().unwrap_or("/tmp".into());
let dd = home.join("joecalsend-downloads");
Self {
multicast_addr: SocketAddrV4::new(MULTICAST_IP, DEFAULT_PORT),
port: DEFAULT_PORT,
download_dir: dd.to_string_lossy().into(),
}
}
}

View file

@ -1,6 +1,6 @@
use std::io; use std::io;
use joecalsend::{Config, JoecalState, RunningState, error, models::device::DeviceInfo}; use joecalsend::{Config, JoecalState, RunningState, error, models::Device};
use local_ip_address::local_ip; use local_ip_address::local_ip;
use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig, V4IfAddr}; use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig, V4IfAddr};
use ratatui::{ use ratatui::{
@ -16,7 +16,7 @@ use ratatui::{
#[tokio::main] #[tokio::main]
async fn main() -> error::Result<()> { async fn main() -> error::Result<()> {
let device = DeviceInfo::default(); let device = Device::default();
dbg!(&device); dbg!(&device);
let std::net::IpAddr::V4(ip) = local_ip()? else { let std::net::IpAddr::V4(ip) = local_ip()? else {
@ -48,9 +48,9 @@ async fn main() -> error::Result<()> {
let mut app = App::new(state.clone()); let mut app = App::new(state.clone());
let mut terminal = ratatui::init(); let mut terminal = ratatui::init();
let result = app.run(&mut terminal); let result = app.run(&mut terminal).await;
ratatui::restore(); ratatui::restore();
//let _ = tokio::join!(h1, h2, h3); let _ = tokio::join!(h1, h2, h3);
Ok(result?) Ok(result?)
} }
@ -64,10 +64,10 @@ impl App {
App { state } App { state }
} }
pub fn run(&mut self, terminal: &mut DefaultTerminal) -> io::Result<()> { pub async fn run(&mut self, terminal: &mut DefaultTerminal) -> io::Result<()> {
loop { loop {
terminal.draw(|frame| self.draw(frame))?; terminal.draw(|frame| self.draw(frame))?;
self.handle_events()?; self.handle_events().await?;
if let Ok(lock) = self.state.running_state.try_lock() if let Ok(lock) = self.state.running_state.try_lock()
&& *lock == RunningState::Stopping && *lock == RunningState::Stopping
{ {
@ -81,21 +81,21 @@ impl App {
frame.render_widget(self, frame.area()); frame.render_widget(self, frame.area());
} }
fn handle_events(&mut self) -> io::Result<()> { async fn handle_events(&mut self) -> io::Result<()> {
match event::read()? { match event::read()? {
// it's important to check that the event is a key press event as // it's important to check that the event is a key press event as
// crossterm also emits key release and repeat events on Windows. // crossterm also emits key release and repeat events on Windows.
Event::Key(key_event) if key_event.kind == KeyEventKind::Press => { Event::Key(key_event) if key_event.kind == KeyEventKind::Press => {
self.handle_key_event(key_event) self.handle_key_event(key_event).await
} }
_ => {} _ => {}
}; };
Ok(()) Ok(())
} }
fn handle_key_event(&mut self, key_event: KeyEvent) { async fn handle_key_event(&mut self, key_event: KeyEvent) {
match key_event.code { match key_event.code {
KeyCode::Char('q') => self.exit(), KeyCode::Char('q') => self.exit().await,
KeyCode::Char('s') => {} KeyCode::Char('s') => {}
KeyCode::Char('r') => {} KeyCode::Char('r') => {}
KeyCode::Char('d') => {} KeyCode::Char('d') => {}
@ -103,13 +103,8 @@ impl App {
} }
} }
fn exit(&mut self) { async fn exit(&self) {
loop { self.state.stop().await;
if let Ok(mut lock) = self.state.running_state.try_lock() {
*lock = RunningState::Stopping;
break;
}
}
} }
} }

View file

@ -1,6 +1,7 @@
use std::{path::Path, time::SystemTime}; use std::{path::Path, time::SystemTime};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use julid::Julid;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::LocalSendError; use crate::error::LocalSendError;
@ -63,6 +64,64 @@ impl FileMetadata {
} }
} }
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum DeviceType {
Mobile,
Desktop,
Web,
Headless,
Server,
Unknown,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Device {
pub alias: String,
pub version: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub device_model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub device_type: Option<DeviceType>,
pub fingerprint: String,
pub port: u16,
pub protocol: String,
#[serde(default)]
pub download: bool,
#[serde(default)]
pub announce: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Protocol {
Http,
Https,
}
impl Default for Device {
fn default() -> Self {
Self {
alias: "RustSend".to_string(),
version: "2.1".to_string(),
device_model: None,
device_type: Some(DeviceType::Headless),
fingerprint: Julid::new().to_string(),
port: 53317,
protocol: "http".to_string(),
download: false,
announce: Some(true),
}
}
}
impl Device {
pub fn to_json(&self) -> crate::error::Result<String> {
Ok(serde_json::to_string(self)?)
}
}
fn format_datetime(system_time: SystemTime) -> String { fn format_datetime(system_time: SystemTime) -> String {
let datetime: DateTime<Utc> = system_time.into(); let datetime: DateTime<Utc> = system_time.into();
datetime.to_rfc3339() datetime.to_rfc3339()

View file

@ -1,60 +0,0 @@
use julid::Julid;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum DeviceType {
Mobile,
Desktop,
Web,
Headless,
Server,
Unknown,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct DeviceInfo {
pub alias: String,
pub version: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub device_model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub device_type: Option<DeviceType>,
pub fingerprint: String,
pub port: u16,
pub protocol: String,
#[serde(default)]
pub download: bool,
#[serde(default)]
pub announce: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Protocol {
Http,
Https,
}
impl Default for DeviceInfo {
fn default() -> Self {
Self {
alias: "RustSend".to_string(),
version: "2.1".to_string(),
device_model: None,
device_type: Some(DeviceType::Headless),
fingerprint: Julid::new().to_string(),
port: 53317,
protocol: "http".to_string(),
download: false,
announce: Some(true),
}
}
}
impl DeviceInfo {
pub fn to_json(&self) -> crate::error::Result<String> {
Ok(serde_json::to_string(self)?)
}
}

View file

@ -1,3 +0,0 @@
pub mod device;
pub mod file;
pub mod session;

View file

@ -1 +0,0 @@

View file

@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
JoecalState, JoecalState,
error::{LocalSendError, Result}, error::{LocalSendError, Result},
models::{device::DeviceInfo, file::FileMetadata}, models::{Device, FileMetadata},
}; };
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
@ -22,8 +22,8 @@ pub struct Session {
pub session_id: String, pub session_id: String,
pub files: HashMap<String, FileMetadata>, pub files: HashMap<String, FileMetadata>,
pub file_tokens: HashMap<String, String>, pub file_tokens: HashMap<String, String>,
pub receiver: DeviceInfo, pub receiver: Device,
pub sender: DeviceInfo, pub sender: Device,
pub status: SessionStatus, pub status: SessionStatus,
pub addr: SocketAddr, pub addr: SocketAddr,
} }
@ -47,7 +47,7 @@ pub struct PrepareUploadResponse {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct PrepareUploadRequest { pub struct PrepareUploadRequest {
pub info: DeviceInfo, pub info: Device,
pub files: HashMap<String, FileMetadata>, pub files: HashMap<String, FileMetadata>,
} }