
Grad.get leads to panic due to unwrap() instead of returning the None option #2924


Open
VirtualNonsense opened this issue Mar 19, 2025 · 9 comments
Labels
enhancement Enhance existing features

Comments

@VirtualNonsense

Hey everyone!
I'm currently trying to visualize gradients, which means accessing the inner workings of GradientsParams.
During this I noticed that accessing an ID that is not present within the instance of GradientsParams leads to a panic.

The issue seems to come from this line in TensorContainer::get.
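
For reference, the relevant lookup seems to boil down to something like this (a paraphrased sketch from my side, not the exact burn-tensor source); the unwrap() is the call that panics:

    // Paraphrased sketch of the current lookup (not the exact burn-tensor source);
    // the unwrap() is the call that panics.
    self.tensors
        .get(id)
        .map(|grad| grad.downcast_ref::<TensorPrimitive<B>>().unwrap().clone())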

If I understand this section correctly, something like this should do the trick:

    /// Get a tensor with the given ID.
    pub fn get<B>(&self, id: &ID) -> Option<TensorPrimitive<B>>
    where
        B: Backend,
    {
        let grad = self.tensors.get(id)?;

        match grad.downcast_ref::<TensorPrimitive<B>>() {
            Some(tensor) => Some(tensor.clone()),
            None => None,
        }
    }
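
As a side note, the match could probably be collapsed with Option::cloned, which should behave the same:

    // Equivalent shorthand for the match above (untested sketch):
    let grad = self.tensors.get(id)?;
    grad.downcast_ref::<TensorPrimitive<B>>().cloned()
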
@laggui
Member

laggui commented Mar 19, 2025

> During this I noticed that accessing an ID that is not present within the instance of GradientsParams leads to a panic.

Hmm I'm not sure this is the cause of your issue. If the ID is not present, then self.tensors.get(id)? would propagate None to the calling function. So the downcast should not happen on None.

The downcast would only fail if the ID exists in the map but the stored value is not a tensor primitive.

Where/how does this happen in your use case?

VirtualNonsense added a commit to VirtualNonsense/sketchy_pix2pix that referenced this issue Mar 19, 2025
@VirtualNonsense
Author

VirtualNonsense commented Mar 19, 2025

First of all: thanks for the response!

> Hmm I'm not sure this is the cause of your issue. If the ID is not present, then self.tensors.get(id)? would propagate None to the calling function. So the downcast should not happen on None.

Yeah, of course.
Sorry, maybe I was a bit quick to jump to a conclusion 😬
The issue seems to be the downcasting operation below self.tensors.get(id)?.

First of all, here is the full stack backtrace:

cargo run --package sketchy_pix2pix --bin train_pix2pix --release
    Finished `release` profile [optimized] target(s) in 1.54s
     Running `target\release\train_pix2pix.exe`
[00:00:00] ###########-----------------------------       1/4       Epochs
[00:00:00] ----------------------------------------       0/61262   Training:                                                                          
thread 'main' panicked at C:\Users\anachtmann\.cargo\registry\src\index.crates.io-1949cf8c6b5b557f\burn-tensor-0.16.0\src\tensor\container.rs:48:14:
called `Option::unwrap()` on a `None` value
stack backtrace:
   0:     0x7ff65b2cbd73 - <std::sys::backtrace::BacktraceLock::print::DisplayBacktrace as core::fmt::Display>::fmt::h985b9d5f3de9f12f
   1:     0x7ff65b319b3a - core::fmt::write::h9a96aa6da4d5ad38
   2:     0x7ff65b2c1762 - std::io::Write::write_fmt::h74a921bbf668ad6e
   3:     0x7ff65b2cbbe5 - std::sys::backtrace::BacktraceLock::print::h774267aea0893176
   4:     0x7ff65b2ce531 - std::panicking::default_hook::{{closure}}::h6eab04a97e765ba6
   5:     0x7ff65b2ce2f4 - std::panicking::default_hook::h9064ebb8ce389b9d
   6:     0x7ff65b2ceea7 - std::panicking::rust_panic_with_hook::h38478ad21fdfd52f
   7:     0x7ff65b2cebfb - std::panicking::begin_panic_handler::{{closure}}::h872718a2180c2512
   8:     0x7ff65b2cc5df - std::sys::backtrace::__rust_end_short_backtrace::h9a03fe215a9e4fe2
   9:     0x7ff65b2ce8ae - rust_begin_unwind
  10:     0x7ff65b315b51 - core::panicking::panic_fmt::hc670a2896bf4a9e4
  11:     0x7ff65b315bed - core::panicking::panic::h364ba447b524b63e
  12:     0x7ff65b3159be - core::option::unwrap_failed::h1526b6bcec421e86
  13:     0x7ff658c3909d - burn_tensor::tensor::container::TensorContainer<ID>::get::h2257f4f9651f6ad6
  14:     0x7ff658939b6c - sketchy_pix2pix::pix2pix::gan::train_gan::h6983b010f23efafa
  15:     0x7ff65896bac3 - train_pix2pix::main::hc9b04c07bf7b975a
  16:     0x7ff658a90f46 - std::sys::backtrace::__rust_begin_short_backtrace::h46742e96e7401c4d
  17:     0x7ff658b2b3cc - std::rt::lang_start::{{closure}}::hb715ca1031083444
  18:     0x7ff65b2b2e35 - std::rt::lang_start_internal::hadc476be3f8121d5
  19:     0x7ff65896bc1d - main
  20:     0x7ff6585b1340 - __tmainCRTStartup
  21:     0x7ff6585b1146 - mainCRTStartup
  22:     0x7ffcb495e8d7 - <unknown>
  23:     0x7ffcb53bbf6c - <unknown>
error: process didn't exit successfully: `target\release\train_pix2pix.exe` (exit code: 101)

Here is what I'm trying to do.
The TL;DR is:

let grad_option: Option<Tensor<B, 4>> = grad.get((self.conv1).weight.id);

is causing the panic above.

@laggui
Member

laggui commented Mar 19, 2025

Ahhhh I see the problem.

The tensor container does not store the gradients on the autodiff backend; they are stored on the inner backend.

let grad_option: Option<Tensor<B::InnerBackend, 4>> = grad.get((self.conv1).weight.id);

Your log_graph implementation is for any B: Backend, but you're calling it from your model on the autodiff backend (in the training loop).

So the implementation should be:

impl<B: AutodiffBackend> Pix2PixDiscriminator<B> {
    pub fn log_graph(&self, grad: &GradientsParams) {
        let grad_option: Option<Tensor<B::InnerBackend, 4>> = grad.get((self.conv1).weight.id);
        if let Some(grad) = grad_option {
            let labels = ["channels_out", "channels_in", "width", "height"];

            if let Ok(c) = LogContainer::from_burn_tensorf32(grad, labels.clone()) {
                let _ = c.log_to_stream(&log, "discriminator/grad/");
            }
        }
    }
}

The .unwrap() after the downcast fails because the id exists, but the stored tensor is not on the same backend. In other words (or should I say, in code):

grad.type_id() != TypeId::of::<TensorPrimitive<B>>()
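
To illustrate the mechanism with plain std::any (nothing burn-specific, just an analogy):

// Plain std::any analogy of the failure mode (not burn code):
use std::any::Any;

let stored: Box<dyn Any> = Box::new(1.0f32); // think: TensorPrimitive<InnerBackend>
assert!(stored.downcast_ref::<f32>().is_some()); // matching type -> Some
assert!(stored.downcast_ref::<f64>().is_none()); // different type -> None, so unwrap() panics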

Lmk if that solves your issue!

@VirtualNonsense
Author

VirtualNonsense commented Mar 24, 2025

Sorry for the late reply!
It works like a charm!
Thank you!

How would you like to proceed?
Should I leave this issue open? It is a little confusing that the program panics because of the backend mismatch (or rather, because of the unwrap on the None variant).
As a user I would prefer some result type as the return value, to communicate the reason for a failure more clearly.

@laggui
Member

laggui commented Mar 24, 2025

I would leave this issue open, I agree that the error should be communicated more clearly.

Probably a Result where the only error case at the TensorContainer level is a DowncastMismatch, and at the GradientsParams level we could have an explicit error message for this scenario, e.g.:

.expect("Downcast mismatch when retrieving tensor. If you are trying to retrieve the gradients for a given parameter id, make sure to use the inner backend. Gradients are not stored on the autodiff backend.")

Or propagate the result at this level, but we can't really give a hint in this case. Since you've encountered this issue specifically, what do you think would have been most helpful?

@VirtualNonsense
Author

For my use case, your proposed solution would be fine.

However, originally I was thinking about something along the lines of the following, to avoid panicking the program:

/// Error type for tensor container operations.
#[derive(Debug)]
pub enum TensorContainerError {
    /// The tensor with the given ID was not found.
    KeyError,
    /// Downcast mismatch when retrieving tensor. 
    /// If you are trying to retrieve the gradients for a given parameter id, make sure to use the inner backend.
    /// Gradients are not stored on the autodiff backend.
    DowncastError,
}
impl<ID> TensorContainer<ID>
where
    ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
{
// .....
    /// Get a tensor with the given ID.
    pub fn get<B>(&self, id: &ID) -> Result<TensorPrimitive<B>, TensorContainerError>
    where
        B: Backend,
    {
        match self.tensors.get(id) {
            Some(grad) => match grad.downcast_ref::<TensorPrimitive<B>>() {
                Some(tensor) => Ok(tensor.clone()),
                None => Err(TensorContainerError::DowncastError),
            },
            None => Err(TensorContainerError::KeyError),
        }
    }

// ....

This could be applied to the remove method as well since it seems to do the same thing:

// ....
    /// Remove a tensor for the given ID and returns it.
    pub fn remove<B>(&mut self, id: &ID) -> Result<TensorPrimitive<B>, TensorContainerError>
    where
        B: Backend,
    {
        // self.tensors
        //     .remove(id)
        //     .map(|item| *item.downcast::<TensorPrimitive<B>>().unwrap())

        match self.tensors.remove(id) {
            Some(tensor) => match tensor.downcast::<TensorPrimitive<B>>() {
                Ok(tensor) => Ok(*tensor),
                Err(_uncast) => Err(TensorContainerError::DowncastError),
            },
            None => Err(TensorContainerError::KeyError),
        }
    }

// ....
}

to stay consistent.

But this may be beyond the scope of this issue.
Also, even though the tests within burn-tensor still pass, I do not know whether this breaks anything further down the line.
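
For illustration, a caller could then handle both failure modes explicitly instead of hitting a panic (hypothetical usage; container, param_id and MyBackend are just placeholder names):

// Hypothetical caller-side handling of the proposed API (placeholder names):
match container.get::<MyBackend>(&param_id) {
    Ok(tensor) => {
        // use the gradient tensor
    }
    Err(TensorContainerError::KeyError) => {
        // no tensor stored for this id
    }
    Err(TensorContainerError::DowncastError) => {
        // id exists, but the stored tensor is for a different backend
    }
}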

laggui added the enhancement label Mar 25, 2025
@laggui
Member

laggui commented Mar 25, 2025

If the param id is not in the container, I still think it makes sense to return None. This behavior is in line with other standard containers like hash maps.

But propagating the error to the user makes sense to avoid panicking the program at this stage.

@VirtualNonsense
Author

Hm, makes sense.
Do you prefer the Option type with the improved error message, or should I create a pull request using the error enum and rename KeyError to None?

@laggui
Member

laggui commented Mar 27, 2025

> If the param id is not in the container, I still think it makes sense to return None

Hmm right, after thinking about it, it wouldn't be as ergonomic to return a Result<Option<TensorPrimitive<B>>, ErrorType> (with a better enum name for the error type, obviously).

I think we can go with your second suggestion, but I wouldn't name the variant None 🤔 Maybe NotFound, and for the failed downcast either DowncastError or TypeMismatch.
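
So the enum would probably end up looking roughly like this (just a sketch, final names still open):

// Sketch of the renamed error enum (final variant names still to be decided):
#[derive(Debug)]
pub enum TensorContainerError {
    /// No tensor is stored for the given ID.
    NotFound,
    /// The stored value could not be downcast to the requested TensorPrimitive<B>,
    /// e.g. when requesting gradients on the autodiff backend instead of the inner backend.
    DowncastError,
}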
