From 99beca4dbef212e473e4b963f443c0372fb749c4 Mon Sep 17 00:00:00 2001 From: quexeky Date: Mon, 28 Oct 2024 22:06:44 +1100 Subject: [PATCH] Some progress on thread terminations --- src-tauri/src/downloads/download_agent.rs | 6 +++++- src-tauri/src/downloads/download_logic.rs | 24 +++++++++++++++-------- src-tauri/src/downloads/progress.rs | 23 +++++++++++++--------- src-tauri/src/tests/progress_tests.rs | 8 +++++--- 4 files changed, 40 insertions(+), 21 deletions(-) diff --git a/src-tauri/src/downloads/download_agent.rs b/src-tauri/src/downloads/download_agent.rs index 13b5bbb..b8fc0d5 100644 --- a/src-tauri/src/downloads/download_agent.rs +++ b/src-tauri/src/downloads/download_agent.rs @@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize}; use urlencoding::encode; use std::fs::{create_dir_all, File}; use std::path::Path; -use std::sync::atomic::AtomicUsize; +use std::sync::atomic::{AtomicBool, AtomicUsize}; use std::sync::{Arc, Mutex}; pub struct GameDownloadAgent { @@ -20,6 +20,7 @@ pub struct GameDownloadAgent { contexts: Mutex>, progress: ProgressChecker, pub manifest: Mutex>, + pub callback: Arc } #[derive(Serialize, Deserialize, Clone, Eq, PartialEq)] pub enum GameDownloadState { @@ -47,14 +48,17 @@ pub enum SystemError { impl GameDownloadAgent { pub fn new(id: String, version: String) -> Self { + let callback = Arc::new(AtomicBool::new(false)); Self { id, version, state: Mutex::from(GameDownloadState::Uninitialised), manifest: Mutex::new(None), + callback: callback.clone(), progress: ProgressChecker::new( Box::new(download_logic::download_game_chunk), Arc::new(AtomicUsize::new(0)), + callback ), contexts: Mutex::new(Vec::new()), } diff --git a/src-tauri/src/downloads/download_logic.rs b/src-tauri/src/downloads/download_logic.rs index 170ee71..7d3848b 100644 --- a/src-tauri/src/downloads/download_logic.rs +++ b/src-tauri/src/downloads/download_logic.rs @@ -5,18 +5,20 @@ use crate::DB; use gxhash::{gxhash128, GxHasher}; use log::info; use md5::{Context, Digest}; -use std::{fs::{File, OpenOptions}, hash::Hasher, io::{self, Seek, SeekFrom, Write}, path::PathBuf}; +use std::{fs::{File, OpenOptions}, hash::Hasher, io::{self, Seek, SeekFrom, Write}, path::PathBuf, sync::{atomic::{AtomicBool, Ordering}, Arc}}; use urlencoding::encode; -pub struct FileWriter { +pub struct DropFileWriter { file: File, hasher: Context, + callback: Arc } -impl FileWriter { - fn new(path: PathBuf) -> Self { +impl DropFileWriter { + fn new(path: PathBuf, callback: Arc) -> Self { Self { file: OpenOptions::new().write(true).open(path).unwrap(), hasher: Context::new(), + callback } } fn finish(mut self) -> io::Result { @@ -24,8 +26,11 @@ impl FileWriter { Ok(self.hasher.compute()) } } -impl Write for FileWriter { +impl Write for DropFileWriter { fn write(&mut self, buf: &[u8]) -> std::io::Result { + if self.callback.load(Ordering::Acquire) { + + } self.hasher.write_all(buf).unwrap(); self.file.write(buf) } @@ -35,12 +40,15 @@ impl Write for FileWriter { self.file.flush() } } -impl Seek for FileWriter { +impl Seek for DropFileWriter { fn seek(&mut self, pos: SeekFrom) -> std::io::Result { self.file.seek(pos) } } -pub fn download_game_chunk(ctx: DropDownloadContext) { +pub fn download_game_chunk(ctx: DropDownloadContext, callback: Arc) { + if callback.load(Ordering::Acquire) { + return; + } let base_url = DB.fetch_base_url(); let client = reqwest::blocking::Client::new(); @@ -63,7 +71,7 @@ pub fn download_game_chunk(ctx: DropDownloadContext) { .send() .unwrap(); - let mut file: FileWriter = FileWriter::new(ctx.path); + let mut file: DropFileWriter = DropFileWriter::new(ctx.path); if ctx.offset != 0 { file diff --git a/src-tauri/src/downloads/progress.rs b/src-tauri/src/downloads/progress.rs index 7e1c582..78d56ed 100644 --- a/src-tauri/src/downloads/progress.rs +++ b/src-tauri/src/downloads/progress.rs @@ -1,7 +1,7 @@ use rayon::ThreadPoolBuilder; use uuid::timestamp::context; use std::os::unix::thread; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; pub struct ProgressChecker @@ -9,7 +9,8 @@ where T: 'static + Send + Sync, { counter: Arc, - f: Arc>, + f: Arc) + Send + Sync + 'static>>, + callback: Arc } impl ProgressChecker @@ -17,24 +18,26 @@ where T: Send + Sync, { pub fn new( - f: Box, + f: Box) + Send + Sync + 'static>, counter_reference: Arc, + callback: Arc ) -> Self { Self { f: f.into(), counter: counter_reference, + callback } } pub async fn run_contexts_sequentially_async(&self, contexts: Vec) { for context in contexts { - (self.f)(context); - self.counter.fetch_add(1, Ordering::Relaxed); + (self.f)(context, self.callback.clone()); + self.counter.fetch_add(1, Ordering::Release); } } pub fn run_contexts_sequentially(&self, contexts: Vec) { for context in contexts { - (self.f)(context); - self.counter.fetch_add(1, Ordering::Relaxed); + (self.f)(context, self.callback.clone()); + self.counter.fetch_add(1, Ordering::Release); } } pub fn run_contexts_parallel_background(&self, contexts: Vec, max_threads: usize) { @@ -46,8 +49,9 @@ where .unwrap(); for context in contexts { + let callback = self.callback.clone(); let f = self.f.clone(); - threads.spawn(move || f(context)); + threads.spawn(move || f(context, callback)); } } pub async fn run_context_parallel(&self, contexts: Vec, max_threads: usize) { @@ -58,8 +62,9 @@ where threads.scope(|s| { for context in contexts { + let callback = self.callback.clone(); let f = self.f.clone(); - s.spawn(move |_| f(context)); + s.spawn(move |_| f(context, callback)); } }); diff --git a/src-tauri/src/tests/progress_tests.rs b/src-tauri/src/tests/progress_tests.rs index 95b1b0a..9c41b15 100644 --- a/src-tauri/src/tests/progress_tests.rs +++ b/src-tauri/src/tests/progress_tests.rs @@ -1,18 +1,20 @@ use std::sync::Arc; -use std::sync::atomic::AtomicUsize; +use std::sync::atomic::{AtomicBool, AtomicUsize}; use crate::downloads::progress::ProgressChecker; #[test] fn test_progress_sequentially() { let counter = Arc::new(AtomicUsize::new(0)); - let p = ProgressChecker::new(Box::new(test_fn), counter.clone()); + let callback = Arc::new(AtomicBool::new(false)); + let p = ProgressChecker::new(Box::new(test_fn), counter.clone(), callback); p.run_contexts_sequentially((1..100).collect()); println!("Progress: {}", p.get_progress_percentage(100)); } #[test] fn test_progress_parallel() { let counter = Arc::new(AtomicUsize::new(0)); - let p = ProgressChecker::new(Box::new(test_fn), counter.clone()); + let callback = Arc::new(AtomicBool::new(false)); + let p = ProgressChecker::new(Box::new(test_fn), counter.clone(), callback); p.run_contexts_parallel_background((1..100).collect(), 10); }