Adding PyTorch functionality to SIRF #1305
base: master
Conversation
Some detailed comments, but overall, I think we still need to decide on how to wrap `AcquisitionModel`, `AcquisitionSensitivityModel`, `Resampler` etc., i.e. anything that has `forward` and `backward` members. This should be written only once. You had https://github.com/SyneRBI/SIRF-Exercises/blob/71df538fe892ce621440544f14fa992c230fd120/notebooks/Deep_Learning_PET/sirf_torch.py#L38-L39, which seems completely generic, aside from the naming of variables. So, something like this:
class OperatorModule(torch.nn.Module):
    def __init__(self, sirf_operator, sirf_src, sirf_dest):
        """constructs wrapper. WARNING: operations will overwrite content of `sirf_src` and `sirf_dest`"""
        ...
with usage
acq_model = pet.AcquisitionModelParallelProj()
# etc etc
torch_acq_model = sirf.torch.OperatorModule(acq_model, image, acq_data)
Ideally, we also provide for being able to call `save_for_backward`.
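For concreteness, a minimal sketch of how such a generic wrapper could pair a `torch.autograd.Function` with the `torch.nn.Module` above, assuming the SIRF containers expose `fill()`/`as_array()` and the wrapped object exposes `forward()`/`backward()` (adjoint); the names and details below are illustrative, not this PR's actual code, and batch dimensions are ignored:

import torch

class _OperatorFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, operator, sirf_src, sirf_dest):
        # WARNING: overwrites the contents of sirf_src and sirf_dest
        sirf_src.fill(x.detach().cpu().numpy())
        y = operator.forward(sirf_src)
        ctx.operator, ctx.sirf_dest = operator, sirf_dest
        return torch.as_tensor(y.as_array(), device=x.device, dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # gradient w.r.t. x is the adjoint applied to the incoming gradient
        ctx.sirf_dest.fill(grad_output.detach().cpu().numpy())
        g = ctx.operator.backward(ctx.sirf_dest)
        grad_x = torch.as_tensor(g.as_array(), device=grad_output.device,
                                 dtype=grad_output.dtype)
        return grad_x, None, None, None  # no gradients for the SIRF arguments

class OperatorModule(torch.nn.Module):
    def __init__(self, sirf_operator, sirf_src, sirf_dest):
        """Wraps any SIRF object with forward/backward members.
        WARNING: operations will overwrite the content of `sirf_src` and `sirf_dest`."""
        super().__init__()
        self.operator, self.src, self.dest = sirf_operator, sirf_src, sirf_dest

    def forward(self, x):
        return _OperatorFunction.apply(x, self.operator, self.src, self.dest)

Any tensors needed again in `backward` (rather than SIRF handles stored on `ctx`) would go through `ctx.save_for_backward`.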
Added a README with some notes for a design discussion here.
I'd prefer not to have too much code in the design discussion (duplication, as well as distracting from the actual design). You could include pointers to the implementation in the README, but it's still small enough that this probably isn't worth the effort.
Just keep a skeleton in the README; the shorter the better...
How do you want people to react to your questions in the README? Via a review? (Not conventional, but could work.)
By the way, we should avoid running CI on this until it makes sense. Easiest is to put
I think the test_cases need to be reviewed again. Perhaps grad_check is unnecessary, since SIRF usually does tests for adjointness etc. We could suffice with simpler tests checking that things are copied to and from torch/SIRF correctly?
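Such a simpler test could look roughly like the sketch below; `torch_to_sirf_` is the helper visible in this diff, but its exact signature and the `sirf_to_torch` counterpart are assumptions made for illustration.

import numpy as np
import torch

def test_roundtrip_copy(sirf_image):
    # random torch data shaped like the SIRF container
    x = torch.rand(*sirf_image.as_array().shape)
    torch_to_sirf_(x, sirf_image)      # hypothetical: copy torch -> SIRF in place
    y = sirf_to_torch(sirf_image)      # hypothetical: copy SIRF -> torch
    np.testing.assert_allclose(y.cpu().numpy(), x.numpy())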
src/torch/tests/varnet_minimal.py (Outdated)
@gschramm this is a minimal reproduction of your varnet. I am still getting negative values in my projected input. I think this is expected, as the ReLU only ensures non-negativity in the forward pass.
Given the loss function, there can be negative elements in the gradient even if the ReLU is enforcing non-negativity in the forward pass.
@KrisThielemans this may mean that we require negative values in the input.
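A plain PyTorch toy example (independent of SIRF) illustrates the point: the ReLU keeps the forward output non-negative, yet the gradient reaching the input can still be negative.

import torch

x = torch.tensor([0.5, 2.0], requires_grad=True)
y = torch.relu(x)                                   # forward output is non-negative
loss = ((y - torch.tensor([1.0, 3.0])) ** 2).sum()  # simple squared-error loss
loss.backward()
print(x.grad)                                       # tensor([-1., -2.]): negative entries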
Related to UCL/STIR#1477
Yes, of course, the gradient of the objective function will generally have positive and negative elements.
Fixing the STIR issue will take me a while though. Help welcome :-)
In that case we can't backpropagate through the gradient of the objective... right? Should we remove the varnet example, or throw a warning/error?
I would add a comment saying that this needs the STIR issue to be resolved, but leave it in, as it's instructive.
Co-authored-by: Kris Thielemans <[email protected]>
Hi, I am happy with this. I think the README is now understandable for me, and the user should know what dimensionality to give tensors.
2 more items:
- update CHANGES.md
- add GHA workflow. What are the requirements here? (`$SIRF_PYTHON_EXECUTABLE -m pip install torch`?)
…ted as well as gradchecks tol decrease [ci skip]
else:
    raise TypeError(f"Unsupported SIRF object type: {type(sirf_src)}")

def torch_to_sirf_(
any reason this has an underscore suffix?
Suggested by @ckolbPTB to follow torch conventions. I cannot remember why; something to do with autograd. Probably a good idea to document it.
It's because there is an in-place operation, which can be dangerous in PyTorch (it breaks the computational graph). The trailing underscore indicates that, from what I understand!
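For context, the trailing underscore mirrors torch's naming convention for in-place operations such as `tensor.add_()`. A rough sketch of the idea, assuming SIRF's `fill()` and not necessarily matching this PR's exact implementation:

import torch

def torch_to_sirf_(torch_tensor: torch.Tensor, sirf_obj):
    # Overwrites the contents of sirf_obj in place; detach() because this raw
    # copy bypasses autograd (gradients are handled by the custom Function).
    sirf_obj.fill(torch_tensor.detach().cpu().numpy())
    return sirf_obj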
- drop `object` typehints
- enable torch cpu tests
- misc formatting
@KrisThielemans, @casperdcl am I correct in thinking that only CI is left here, i.e. we need a torch build that runs the torch tests, and then we can merge? Perhaps I can look at doing this at the hackathon? Also, there already seems to be some error with CI: https://github.com/SyneRBI/SIRF/actions/runs/14112476864/job/39534509274?pr=1305, which doesn't seem to be related to the changes here?
Yes. As CI runs without CUDA, we need to have a default for the device, as indicated in the comments.
Is this PR behind master? @evgueni-ovtchinnikov are you aware of these failing tests?
@KrisThielemans I noticed that it's good practice to add the `once_differentiable` decorator, so I have added that. Also, it seemed like the PR was behind main, so I have resynced.
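For reference, `once_differentiable` lives in `torch.autograd.function` and is applied to `backward`; it makes torch raise an error if someone tries to differentiate through that backward, instead of silently computing wrong higher-order gradients. A generic sketch with a placeholder computation (not this PR's code):

import torch
from torch.autograd.function import once_differentiable

class _Doubling(torch.autograd.Function):        # placeholder computation
    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    @once_differentiable                         # backward itself is not differentiable
    def backward(ctx, grad_output):
        return grad_output * 2.0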
@Imraj-Singh @casperdcl @evgueni-ovtchinnikov Where are we with this? I see lots of errors which don't occur in https://github.com/SyneRBI/SIRF/actions/runs/14908966112 for instance. No idea where they come from.
Changes in this pull request
The functionality:
- `torch.autograd.Function` for the objective function, acquisition model forward, and acquisition model backward. These are not really meant to be used by a user.
- `torch.nn.Module` for a module that interacts with the `torch.autograd.Function`. In the forward method we could change the number of dimensions in the array, swap in/out components of the forward operator, etc.
This functionality of `torch.nn.Module` can vary depending on the user's requirements, so perhaps we should just give a use case: using it for LPD for example, or PET reconstruction with torch optimisers (a sketch of the latter is given below)?
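For instance, the torch-optimiser use case could look roughly like this sketch, where `obj_fun_module` stands for a wrapped SIRF objective-function module and all names/hyperparameters are illustrative:

import torch

def reconstruct(obj_fun_module, initial_image_tensor, n_iter=100, lr=1e-2):
    x = initial_image_tensor.clone().requires_grad_(True)
    opt = torch.optim.Adam([x], lr=lr)
    for _ in range(n_iter):
        opt.zero_grad()
        loss = -obj_fun_module(x)   # maximise the SIRF objective = minimise its negative
        loss.backward()
        opt.step()
        with torch.no_grad():
            x.clamp_(min=0)         # keep the image non-negative
    return x.detach()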
Testing performed
Gradchecks for all the new functions. Note that objective function evaluations are quite unstable.
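For reference, `torch.autograd.gradcheck` allows the tolerances to be relaxed, which is typically needed with single-precision data; in the sketch below the wrapped callable `torch_obj_fun` and the tolerance values are placeholders.

import torch
from torch.autograd import gradcheck

x = torch.rand(4, dtype=torch.float32, requires_grad=True)
# gradcheck assumes double precision by default; with float32 SIRF data the
# tolerances usually need loosening -- the values here are illustrative only
ok = gradcheck(torch_obj_fun, (x,), eps=1e-3, atol=1e-3, rtol=1e-2,
               nondet_tol=1e-3, fast_mode=True)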
Related issues
Checklist before requesting a review
Contribution Notes
Please read and adhere to the contribution guidelines.
Please tick the following: