feat(downloads): Convert DownloadThreadControlFlag to AtomicBool

Also ran cargo fmt & cargo clipy

Signed-off-by: quexeky <git@quexeky.dev>
This commit is contained in:
quexeky
2024-11-11 09:39:25 +11:00
parent b47b7ea935
commit f25bfed336
7 changed files with 29 additions and 42 deletions

View File

@ -1,17 +1,13 @@
use std::{ use std::{
borrow::BorrowMut,
collections::HashMap, collections::HashMap,
fmt::format,
fs::{self, create_dir_all}, fs::{self, create_dir_all},
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::{LazyLock, Mutex}, sync::{LazyLock, Mutex},
}; };
use directories::BaseDirs; use directories::BaseDirs;
use log::info;
use rustbreak::{deser::Bincode, PathDatabase}; use rustbreak::{deser::Bincode, PathDatabase};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::fs::metadata;
use url::Url; use url::Url;
use crate::DB; use crate::DB;
@ -102,19 +98,19 @@ pub fn add_new_download_dir(new_dir: String) -> Result<(), String> {
if new_dir_path.exists() { if new_dir_path.exists() {
let metadata = new_dir_path let metadata = new_dir_path
.metadata() .metadata()
.map_err(|e| format!("Unable to access file or directory: {}", e.to_string()))?; .map_err(|e| format!("Unable to access file or directory: {}", e))?;
if !metadata.is_dir() { if !metadata.is_dir() {
return Err("Invalid path: not a directory".to_string()); return Err("Invalid path: not a directory".to_string());
} }
let dir_contents = new_dir_path let dir_contents = new_dir_path
.read_dir() .read_dir()
.map_err(|e| format!("Unable to check directory contents: {}", e.to_string()))?; .map_err(|e| format!("Unable to check directory contents: {}", e))?;
if dir_contents.count() == 0 { if dir_contents.count() == 0 {
return Err("Path is not empty".to_string()); return Err("Path is not empty".to_string());
} }
} else { } else {
create_dir_all(new_dir_path) create_dir_all(new_dir_path)
.map_err(|e| format!("Unable to create directories to path: {}", e.to_string()))?; .map_err(|e| format!("Unable to create directories to path: {}", e))?;
} }
// Add it to the dictionary // Add it to the dictionary

View File

@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::fs::{create_dir_all, File}; use std::fs::{create_dir_all, File};
use std::path::Path; use std::path::Path;
use std::sync::atomic::AtomicU64; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock}; use std::sync::{Arc, Mutex, RwLock};
use urlencoding::encode; use urlencoding::encode;
@ -22,17 +22,17 @@ use super::download_logic::download_game_chunk;
pub struct GameDownloadAgent { pub struct GameDownloadAgent {
pub id: String, pub id: String,
pub version: String, pub version: String,
pub control_flag: Arc<RwLock<DownloadThreadControlFlag>>, pub control_flag: Arc<DownloadThreadControlFlag>,
pub target_download_dir: usize, pub target_download_dir: usize,
contexts: Mutex<Vec<DropDownloadContext>>, contexts: Mutex<Vec<DropDownloadContext>>,
pub manifest: Mutex<Option<DropManifest>>, pub manifest: Mutex<Option<DropManifest>>,
pub progress: ProgressObject, pub progress: ProgressObject,
} }
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq)]
pub enum DownloadThreadControlFlag { /// Faster alternative to a RwLock Enum.
Go, /// true = Go
Stop, /// false = Stop
} pub type DownloadThreadControlFlag = AtomicBool;
#[derive(Debug)] #[derive(Debug)]
pub enum GameDownloadError { pub enum GameDownloadError {
@ -61,7 +61,7 @@ pub struct ProgressObject {
impl GameDownloadAgent { impl GameDownloadAgent {
pub fn new(id: String, version: String, target_download_dir: usize) -> Self { pub fn new(id: String, version: String, target_download_dir: usize) -> Self {
// Don't run by default // Don't run by default
let status = Arc::new(RwLock::new(DownloadThreadControlFlag::Stop)); let status = Arc::new(DownloadThreadControlFlag::new(false));
Self { Self {
id, id,
version, version,
@ -75,13 +75,11 @@ impl GameDownloadAgent {
}, },
} }
} }
pub fn set_control_flag(&self, flag: DownloadThreadControlFlag) { pub fn set_control_flag(&self, flag: bool) {
let mut lock = self.control_flag.write().unwrap(); self.control_flag.store(flag, Ordering::Relaxed);
*lock = flag;
} }
pub fn get_control_flag(&self) -> DownloadThreadControlFlag { pub fn get_control_flag(&self) -> bool {
let lock = self.control_flag.read().unwrap(); self.control_flag.load(Ordering::Relaxed)
lock.clone()
} }
// Blocking // Blocking
@ -91,7 +89,7 @@ impl GameDownloadAgent {
self.generate_contexts()?; self.generate_contexts()?;
self.set_control_flag(DownloadThreadControlFlag::Go); self.set_control_flag(true);
Ok(()) Ok(())
} }
@ -110,7 +108,7 @@ impl GameDownloadAgent {
} }
// Explicitly propagate error // Explicitly propagate error
Ok(self.download_manifest()?) self.download_manifest()
} }
fn download_manifest(&mut self) -> Result<(), GameDownloadError> { fn download_manifest(&mut self) -> Result<(), GameDownloadError> {
@ -159,7 +157,7 @@ impl GameDownloadAgent {
return Ok(()); return Ok(());
} }
return Err(GameDownloadError::LockError); Err(GameDownloadError::LockError)
} }
pub fn generate_contexts(&self) -> Result<(), GameDownloadError> { pub fn generate_contexts(&self) -> Result<(), GameDownloadError> {
@ -210,9 +208,9 @@ impl GameDownloadAgent {
return Ok(()); return Ok(());
} }
return Err(GameDownloadError::SetupError( Err(GameDownloadError::SetupError(
"Failed to generate download contexts".to_owned(), "Failed to generate download contexts".to_owned(),
)); ))
} }
pub fn run(&self) { pub fn run(&self) {

View File

@ -1,7 +1,4 @@
use std::{ use std::sync::{Arc, Mutex};
borrow::Borrow,
sync::{Arc, Mutex},
};
use log::info; use log::info;
use rayon::spawn; use rayon::spawn;

View File

@ -3,15 +3,12 @@ use crate::db::DatabaseImpls;
use crate::downloads::manifest::DropDownloadContext; use crate::downloads::manifest::DropDownloadContext;
use crate::remote::RemoteAccessError; use crate::remote::RemoteAccessError;
use crate::DB; use crate::DB;
use atomic_counter::{AtomicCounter, RelaxedCounter};
use log::{error, info}; use log::{error, info};
use md5::{Context, Digest}; use md5::{Context, Digest};
use reqwest::blocking::Response; use reqwest::blocking::Response;
use serde::de::Error;
use std::io::Read; use std::io::Read;
use std::sync::atomic::AtomicU64; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use std::{ use std::{
fs::{File, OpenOptions}, fs::{File, OpenOptions},
io::{self, BufWriter, ErrorKind, Seek, SeekFrom, Write}, io::{self, BufWriter, ErrorKind, Seek, SeekFrom, Write},
@ -66,7 +63,7 @@ impl Seek for DropWriter<File> {
pub struct DropDownloadPipeline<R: Read, W: Write> { pub struct DropDownloadPipeline<R: Read, W: Write> {
pub source: R, pub source: R,
pub destination: DropWriter<W>, pub destination: DropWriter<W>,
pub control_flag: Arc<RwLock<DownloadThreadControlFlag>>, pub control_flag: Arc<DownloadThreadControlFlag>,
pub progress: Arc<AtomicU64>, pub progress: Arc<AtomicU64>,
pub size: usize, pub size: usize,
} }
@ -74,7 +71,7 @@ impl DropDownloadPipeline<Response, File> {
fn new( fn new(
source: Response, source: Response,
destination: DropWriter<File>, destination: DropWriter<File>,
control_flag: Arc<RwLock<DownloadThreadControlFlag>>, control_flag: Arc<DownloadThreadControlFlag>,
progress: Arc<AtomicU64>, progress: Arc<AtomicU64>,
size: usize, size: usize,
) -> Self { ) -> Self {
@ -94,7 +91,7 @@ impl DropDownloadPipeline<Response, File> {
let mut current_size = 0; let mut current_size = 0;
loop { loop {
if *self.control_flag.read().unwrap() == DownloadThreadControlFlag::Stop { if self.control_flag.load(Ordering::Relaxed) == false {
return Ok(false); return Ok(false);
} }
@ -123,11 +120,11 @@ impl DropDownloadPipeline<Response, File> {
pub fn download_game_chunk( pub fn download_game_chunk(
ctx: DropDownloadContext, ctx: DropDownloadContext,
control_flag: Arc<RwLock<DownloadThreadControlFlag>>, control_flag: Arc<DownloadThreadControlFlag>,
progress: Arc<AtomicU64>, progress: Arc<AtomicU64>,
) -> Result<bool, GameDownloadError> { ) -> Result<bool, GameDownloadError> {
// If we're paused // If we're paused
if *control_flag.read().unwrap() == DownloadThreadControlFlag::Stop { if control_flag.load(Ordering::Relaxed) {
return Ok(false); return Ok(false);
} }

View File

@ -1,4 +1,4 @@
pub mod download_agent; pub mod download_agent;
pub mod download_commands; pub mod download_commands;
mod download_logic; mod download_logic;
mod manifest; mod manifest;

View File

@ -20,7 +20,6 @@ use log::info;
use remote::{gen_drop_url, use_remote}; use remote::{gen_drop_url, use_remote};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use std::{ use std::{
collections::HashMap, collections::HashMap,
sync::{LazyLock, Mutex}, sync::{LazyLock, Mutex},

View File

@ -1 +1 @@
pub const DOWNLOAD_MAX_THREADS: usize = 4; pub const DOWNLOAD_MAX_THREADS: usize = 4;