Skip to content

shred: remove usage of Vec::set_len() #1738

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 23, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 45 additions & 25 deletions src/uu/shred/src/shred.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ struct BytesGenerator<'a> {
exact: bool, // if false, every block's size is block_size
gen_type: PassType<'a>,
rng: Option<RefCell<ThreadRng>>,
bytes: [u8; BLOCK_SIZE],
}

impl<'a> BytesGenerator<'a> {
Expand All @@ -128,32 +129,44 @@ impl<'a> BytesGenerator<'a> {
_ => None,
};

let bytes = [0; BLOCK_SIZE];

BytesGenerator {
total_bytes,
bytes_generated: Cell::new(0u64),
block_size: BLOCK_SIZE,
exact,
gen_type,
rng,
bytes,
}
}
}

impl<'a> Iterator for BytesGenerator<'a> {
type Item = Box<[u8]>;
pub fn reset(&mut self, total_bytes: u64, gen_type: PassType<'a>) {
if let PassType::Random = gen_type {
if self.rng.is_none() {
self.rng = Some(RefCell::new(rand::thread_rng()));
}
}

self.total_bytes = total_bytes;
self.gen_type = gen_type;

self.bytes_generated.set(0);
}

fn next(&mut self) -> Option<Box<[u8]>> {
pub fn next(&mut self) -> Option<&[u8]> {
// We go over the total_bytes limit when !self.exact and total_bytes isn't a multiple
// of self.block_size
if self.bytes_generated.get() >= self.total_bytes {
return None;
}

let this_block_size: usize = {
let this_block_size = {
if !self.exact {
self.block_size
} else {
let bytes_left: u64 = self.total_bytes - self.bytes_generated.get();
let bytes_left = self.total_bytes - self.bytes_generated.get();
if bytes_left >= self.block_size as u64 {
self.block_size
} else {
Expand All @@ -162,17 +175,12 @@ impl<'a> Iterator for BytesGenerator<'a> {
}
};

let mut bytes: Vec<u8> = Vec::with_capacity(this_block_size);
let bytes = &mut self.bytes[..this_block_size];

match self.gen_type {
PassType::Random => {
// This is ok because the vector was
// allocated with the same capacity
unsafe {
bytes.set_len(this_block_size);
}
let mut rng = self.rng.as_ref().unwrap().borrow_mut();
rng.fill(&mut bytes[..]);
rng.fill(bytes);
}
PassType::Pattern(pattern) => {
let skip = {
Expand All @@ -182,18 +190,25 @@ impl<'a> Iterator for BytesGenerator<'a> {
(pattern.len() as u64 % self.bytes_generated.get()) as usize
}
};
// Same range as 0..this_block_size but we start with the right index
for i in skip..this_block_size + skip {
let index = i % pattern.len();
bytes.push(pattern[index]);

// Copy the pattern in chunks rather than simply one byte at a time
let mut i = 0;
while i < this_block_size {
let start = (i + skip) % pattern.len();
let end = (this_block_size - i).min(pattern.len());
let len = end - start;

bytes[i..i + len].copy_from_slice(&pattern[start..end]);

i += len;
}
}
};

let new_bytes_generated = self.bytes_generated.get() + this_block_size as u64;
self.bytes_generated.set(new_bytes_generated);

Some(bytes.into_boxed_slice())
Some(bytes)
}
}

Expand Down Expand Up @@ -443,6 +458,10 @@ fn wipe_file(
.open(path)
.expect("Failed to open file for writing");

// NOTE: it does not really matter what we set for total_bytes and gen_type here, so just
// use bogus values
let mut generator = BytesGenerator::new(0, PassType::Pattern(&[]), exact);

for (i, pass_type) in pass_sequence.iter().enumerate() {
if verbose {
let pass_name: String = pass_name(*pass_type);
Expand All @@ -467,7 +486,8 @@ fn wipe_file(
}
}
// size is an optional argument for exactly how many bytes we want to shred
do_pass(&mut file, path, *pass_type, size, exact).expect("File write pass failed");
do_pass(&mut file, path, &mut generator, *pass_type, size)
.expect("File write pass failed");
// Ignore failed writes; just keep trying
}
}
Expand All @@ -477,22 +497,22 @@ fn wipe_file(
}
}

fn do_pass(
fn do_pass<'a>(
file: &mut File,
path: &Path,
generator_type: PassType,
generator: &mut BytesGenerator<'a>,
generator_type: PassType<'a>,
given_file_size: Option<u64>,
exact: bool,
) -> Result<(), io::Error> {
file.seek(SeekFrom::Start(0))?;

// Use the given size or the whole file if not specified
let size: u64 = given_file_size.unwrap_or(get_file_size(path)?);

let generator = BytesGenerator::new(size, generator_type, exact);
generator.reset(size, generator_type);

for block in generator {
file.write_all(&*block)?;
while let Some(block) = generator.next() {
file.write_all(block)?;
}

file.sync_data()?;
Expand Down