try to shutdown nicely

This commit is contained in:
Joe Ardent 2025-07-06 16:02:11 -07:00
parent f5af07f860
commit 342b634388
9 changed files with 164 additions and 138 deletions

View file

@ -9,7 +9,7 @@ use axum::{
};
use tokio::net::UdpSocket;
use crate::{Config, JoecalState, RunningState, models::device::DeviceInfo};
use crate::{Config, JoecalState, RunningState, models::Device};
impl JoecalState {
pub async fn announce(
@ -48,7 +48,7 @@ impl JoecalState {
}
async fn process_device(&self, message: &str, src: SocketAddr, config: &Config) {
if let Ok(device) = serde_json::from_str::<DeviceInfo>(message) {
if let Ok(device) = serde_json::from_str::<Device>(message) {
if device.fingerprint == self.device.fingerprint {
return;
}
@ -83,8 +83,8 @@ impl JoecalState {
pub async fn register_device(
State(state): State<JoecalState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Json(device): Json<DeviceInfo>,
) -> Json<DeviceInfo> {
Json(device): Json<Device>,
) -> Json<Device> {
let mut addr = addr;
addr.set_port(state.device.port);
state
@ -99,9 +99,9 @@ pub async fn register_device(
// private helpers
//-************************************************************************
async fn announce_http(
device: &DeviceInfo,
device: &Device,
ip: Option<SocketAddr>,
client: Arc<reqwest::Client>,
client: reqwest::Client,
) -> crate::error::Result<()> {
if let Some(ip) = ip {
let url = format!("http://{ip}/api/localsend/v2/register");
@ -111,7 +111,7 @@ async fn announce_http(
}
async fn announce_multicast(
device: &DeviceInfo,
device: &Device,
addr: SocketAddrV4,
socket: Arc<UdpSocket>,
) -> crate::error::Result<()> {

View file

@ -5,7 +5,7 @@ use axum::{
extract::DefaultBodyLimit,
routing::{get, post},
};
use tokio::net::TcpListener;
use tokio::{net::TcpListener, sync::mpsc};
use tower_http::limit::RequestBodyLimitLayer;
use crate::{
@ -15,8 +15,12 @@ use crate::{
};
impl JoecalState {
pub async fn start_http_server(&self, config: &Config) -> crate::error::Result<()> {
let app = self.create_router(&config);
pub async fn start_http_server(
&self,
stop_rx: mpsc::Receiver<()>,
config: &Config,
) -> crate::error::Result<()> {
let app = self.create_router(config);
let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
let listener = TcpListener::bind(&addr).await?;
@ -26,6 +30,7 @@ impl JoecalState {
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown(stop_rx))
.await?;
Ok(())
}
@ -52,3 +57,8 @@ impl JoecalState {
.with_state(self.clone())
}
}
async fn shutdown(mut rx: mpsc::Receiver<()>) {
println!("shutting down");
rx.recv().await.unwrap_or_default()
}

View file

@ -8,11 +8,16 @@ use std::{
collections::HashMap,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc,
time::Duration,
};
use models::device::DeviceInfo;
use models::Device;
use serde::{Deserialize, Serialize};
use tokio::{net::UdpSocket, sync::Mutex, task::JoinHandle};
use tokio::{
net::UdpSocket,
sync::{Mutex, mpsc},
task::JoinHandle,
};
use transfer::Session;
pub const DEFAULT_PORT: u16 = 53317;
@ -20,18 +25,20 @@ pub const MULTICAST_IP: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 167);
pub const LISTENING_SOCKET_ADDR: SocketAddrV4 =
SocketAddrV4::new(Ipv4Addr::from_bits(0), DEFAULT_PORT);
/// Contains the main network and application state for an application session.
#[derive(Clone)]
pub struct JoecalState {
pub device: DeviceInfo,
pub peers: Arc<Mutex<HashMap<String, (SocketAddr, DeviceInfo)>>>,
pub device: Device,
pub peers: Arc<Mutex<HashMap<String, (SocketAddr, Device)>>>,
pub sessions: Arc<Mutex<HashMap<String, Session>>>, // Session ID to Session
pub running_state: Arc<Mutex<RunningState>>,
pub socket: Arc<UdpSocket>,
pub client: Arc<reqwest::Client>,
pub client: reqwest::Client,
stop_tx: std::sync::OnceLock<mpsc::Sender<()>>,
}
impl JoecalState {
pub async fn new(device: DeviceInfo) -> crate::error::Result<Self> {
pub async fn new(device: Device) -> crate::error::Result<Self> {
let socket = UdpSocket::bind(LISTENING_SOCKET_ADDR).await?;
socket.set_multicast_loop_v4(true)?;
socket.set_multicast_ttl_v4(2)?; // one hop out from localnet
@ -43,45 +50,11 @@ impl JoecalState {
sessions: Default::default(),
running_state: Default::default(),
socket: socket.into(),
client: reqwest::Client::new().into(),
client: reqwest::Client::new(),
stop_tx: Default::default(),
})
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunningState {
Running,
Sending,
Receiving,
Stopping,
}
impl Default for RunningState {
fn default() -> Self {
Self::Running
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub multicast_addr: SocketAddrV4,
pub port: u16,
pub download_dir: String,
}
impl Default for Config {
fn default() -> Self {
let home = std::env::home_dir().unwrap_or("/tmp".into());
let dd = home.join("joecalsend-downloads");
Self {
multicast_addr: SocketAddrV4::new(MULTICAST_IP, DEFAULT_PORT),
port: DEFAULT_PORT,
download_dir: dd.to_string_lossy().into(),
}
}
}
impl JoecalState {
pub async fn start(
&self,
config: &Config,
@ -89,8 +62,10 @@ impl JoecalState {
let state = self.clone();
let konfig = config.clone();
let server_handle = {
let (tx, rx) = mpsc::channel(1);
self.stop_tx.get_or_init(|| tx);
tokio::spawn(async move {
if let Err(e) = state.start_http_server(&konfig).await {
if let Err(e) = state.start_http_server(rx, &konfig).await {
eprintln!("HTTP server error: {e}");
}
})
@ -126,8 +101,59 @@ impl JoecalState {
Ok((server_handle, udp_handle, announcement_handle))
}
pub async fn stop(&self) {
loop {
if let Ok(mut lock) = self.running_state.try_lock() {
*lock = RunningState::Stopping;
if self
.stop_tx
.get()
.expect("Could not get stop signal transmitter")
.send(())
.await
.is_ok()
{
break;
}
}
}
}
pub async fn refresh_peers(&self) {
let mut peers = self.peers.lock().await;
peers.clear();
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunningState {
Running,
Sending,
Receiving,
Stopping,
}
impl Default for RunningState {
fn default() -> Self {
Self::Running
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub multicast_addr: SocketAddrV4,
pub port: u16,
pub download_dir: String,
}
impl Default for Config {
fn default() -> Self {
let home = std::env::home_dir().unwrap_or("/tmp".into());
let dd = home.join("joecalsend-downloads");
Self {
multicast_addr: SocketAddrV4::new(MULTICAST_IP, DEFAULT_PORT),
port: DEFAULT_PORT,
download_dir: dd.to_string_lossy().into(),
}
}
}

View file

@ -1,6 +1,6 @@
use std::io;
use joecalsend::{Config, JoecalState, RunningState, error, models::device::DeviceInfo};
use joecalsend::{Config, JoecalState, RunningState, error, models::Device};
use local_ip_address::local_ip;
use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig, V4IfAddr};
use ratatui::{
@ -16,7 +16,7 @@ use ratatui::{
#[tokio::main]
async fn main() -> error::Result<()> {
let device = DeviceInfo::default();
let device = Device::default();
dbg!(&device);
let std::net::IpAddr::V4(ip) = local_ip()? else {
@ -48,9 +48,9 @@ async fn main() -> error::Result<()> {
let mut app = App::new(state.clone());
let mut terminal = ratatui::init();
let result = app.run(&mut terminal);
let result = app.run(&mut terminal).await;
ratatui::restore();
//let _ = tokio::join!(h1, h2, h3);
let _ = tokio::join!(h1, h2, h3);
Ok(result?)
}
@ -64,10 +64,10 @@ impl App {
App { state }
}
pub fn run(&mut self, terminal: &mut DefaultTerminal) -> io::Result<()> {
pub async fn run(&mut self, terminal: &mut DefaultTerminal) -> io::Result<()> {
loop {
terminal.draw(|frame| self.draw(frame))?;
self.handle_events()?;
self.handle_events().await?;
if let Ok(lock) = self.state.running_state.try_lock()
&& *lock == RunningState::Stopping
{
@ -81,21 +81,21 @@ impl App {
frame.render_widget(self, frame.area());
}
fn handle_events(&mut self) -> io::Result<()> {
async fn handle_events(&mut self) -> io::Result<()> {
match event::read()? {
// it's important to check that the event is a key press event as
// crossterm also emits key release and repeat events on Windows.
Event::Key(key_event) if key_event.kind == KeyEventKind::Press => {
self.handle_key_event(key_event)
self.handle_key_event(key_event).await
}
_ => {}
};
Ok(())
}
fn handle_key_event(&mut self, key_event: KeyEvent) {
async fn handle_key_event(&mut self, key_event: KeyEvent) {
match key_event.code {
KeyCode::Char('q') => self.exit(),
KeyCode::Char('q') => self.exit().await,
KeyCode::Char('s') => {}
KeyCode::Char('r') => {}
KeyCode::Char('d') => {}
@ -103,13 +103,8 @@ impl App {
}
}
fn exit(&mut self) {
loop {
if let Ok(mut lock) = self.state.running_state.try_lock() {
*lock = RunningState::Stopping;
break;
}
}
async fn exit(&self) {
self.state.stop().await;
}
}

View file

@ -1,6 +1,7 @@
use std::{path::Path, time::SystemTime};
use chrono::{DateTime, Utc};
use julid::Julid;
use serde::{Deserialize, Serialize};
use crate::error::LocalSendError;
@ -63,6 +64,64 @@ impl FileMetadata {
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum DeviceType {
Mobile,
Desktop,
Web,
Headless,
Server,
Unknown,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Device {
pub alias: String,
pub version: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub device_model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub device_type: Option<DeviceType>,
pub fingerprint: String,
pub port: u16,
pub protocol: String,
#[serde(default)]
pub download: bool,
#[serde(default)]
pub announce: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Protocol {
Http,
Https,
}
impl Default for Device {
fn default() -> Self {
Self {
alias: "RustSend".to_string(),
version: "2.1".to_string(),
device_model: None,
device_type: Some(DeviceType::Headless),
fingerprint: Julid::new().to_string(),
port: 53317,
protocol: "http".to_string(),
download: false,
announce: Some(true),
}
}
}
impl Device {
pub fn to_json(&self) -> crate::error::Result<String> {
Ok(serde_json::to_string(self)?)
}
}
fn format_datetime(system_time: SystemTime) -> String {
let datetime: DateTime<Utc> = system_time.into();
datetime.to_rfc3339()

View file

@ -1,60 +0,0 @@
use julid::Julid;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum DeviceType {
Mobile,
Desktop,
Web,
Headless,
Server,
Unknown,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct DeviceInfo {
pub alias: String,
pub version: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub device_model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub device_type: Option<DeviceType>,
pub fingerprint: String,
pub port: u16,
pub protocol: String,
#[serde(default)]
pub download: bool,
#[serde(default)]
pub announce: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Protocol {
Http,
Https,
}
impl Default for DeviceInfo {
fn default() -> Self {
Self {
alias: "RustSend".to_string(),
version: "2.1".to_string(),
device_model: None,
device_type: Some(DeviceType::Headless),
fingerprint: Julid::new().to_string(),
port: 53317,
protocol: "http".to_string(),
download: false,
announce: Some(true),
}
}
}
impl DeviceInfo {
pub fn to_json(&self) -> crate::error::Result<String> {
Ok(serde_json::to_string(self)?)
}
}

View file

@ -1,3 +0,0 @@
pub mod device;
pub mod file;
pub mod session;

View file

@ -1 +0,0 @@

View file

@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};
use crate::{
JoecalState,
error::{LocalSendError, Result},
models::{device::DeviceInfo, file::FileMetadata},
models::{Device, FileMetadata},
};
#[derive(Deserialize, Serialize)]
@ -22,8 +22,8 @@ pub struct Session {
pub session_id: String,
pub files: HashMap<String, FileMetadata>,
pub file_tokens: HashMap<String, String>,
pub receiver: DeviceInfo,
pub sender: DeviceInfo,
pub receiver: Device,
pub sender: Device,
pub status: SessionStatus,
pub addr: SocketAddr,
}
@ -47,7 +47,7 @@ pub struct PrepareUploadResponse {
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PrepareUploadRequest {
pub info: DeviceInfo,
pub info: Device,
pub files: HashMap<String, FileMetadata>,
}