
Adding PyTorch functionality to SIRF #1305


Open · wants to merge 64 commits into master

Conversation

@Imraj-Singh Imraj-Singh commented Feb 15, 2025

Changes in this pull request

The functionality:

  • Functions for transferring data between torch and SIRF objects, in both directions.
  • torch.autograd.Function implementations for the objective function, the acquisition model forward, and the acquisition model backward. These are not really meant to be used directly by a user.
  • A torch.nn.Module that interacts with the torch.autograd.Function. In the forward method we could change the number of dimensions of the array, swap in/out components of the forward operator, etc.

The behaviour of the torch.nn.Module can vary depending on the user's requirements, so perhaps we should just give a use case: using it for LPD (learned primal-dual), for example, or PET reconstruction with torch optimisers?
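For concreteness, a minimal sketch of one such torch.autograd.Function (all names hypothetical, assuming an already set-up SIRF acquisition model and its usual fill/as_array API); it illustrates the structure rather than the actual implementation:

import torch

class _AcqModForward(torch.autograd.Function):
    """Sketch: forward applies the SIRF acquisition model A, backward its adjoint."""

    @staticmethod
    def forward(ctx, x, sirf_image, sirf_acq_data, sirf_acq_model):
        ctx.sirf_acq_data = sirf_acq_data
        ctx.sirf_acq_model = sirf_acq_model
        sirf_image.fill(x.detach().cpu().numpy())               # torch -> SIRF
        y = sirf_acq_model.forward(sirf_image)                  # y = A x
        return torch.as_tensor(y.as_array(), device=x.device)   # SIRF -> torch

    @staticmethod
    def backward(ctx, grad_output):
        ctx.sirf_acq_data.fill(grad_output.detach().cpu().numpy())
        g = ctx.sirf_acq_model.backward(ctx.sirf_acq_data)      # A^T grad
        # gradient only w.r.t. the tensor input; None for the SIRF arguments
        return torch.as_tensor(g.as_array(), device=grad_output.device), None, None, None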

Testing performed

Gradchecks for all the new functions; a sketch of such a check is below. Note that the objective-function evaluations are numerically quite unstable.
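For example, a check along these lines (a stand-in quadratic objective is used here, since the actual wrapper lives in this PR); double precision and a loosened eps mitigate the instability:

import torch

def wrapped_objective(x):  # stand-in for the SIRF objective-function wrapper
    return (x ** 2).sum()

x = torch.rand(4, 4, dtype=torch.float64, requires_grad=True)
# gradcheck compares analytic gradients against finite differences
assert torch.autograd.gradcheck(wrapped_objective, (x,), eps=1e-4, atol=1e-4)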

Related issues

Checklist before requesting a review

  • I have performed a self-review of my code
  • I have added docstrings/doxygen in line with the guidance in the developer guide
  • I have implemented unit tests that cover any new or modified functionality
  • The code builds and runs on my machine
  • CHANGES.md has been updated with any functionality change

Contribution Notes

Please read and adhere to the contribution guidelines.

Please tick the following:

  • The content of this Pull Request (the Contribution) is intentionally submitted for inclusion in SIRF (the Work) under the terms and conditions of the Apache-2.0 License.

@Imraj-Singh Imraj-Singh marked this pull request as draft February 15, 2025 13:03
@KrisThielemans KrisThielemans self-assigned this Feb 17, 2025
@KrisThielemans KrisThielemans linked an issue Feb 17, 2025 that may be closed by this pull request
@KrisThielemans (Member) left a comment

Some detailed comments, but overall I think we still need to decide on how to wrap AcquisitionModel, AcquisitionSensitivityModel, Resampler etc., i.e. anything that has forward and backward members. This should be written only once. You had https://github.com/SyneRBI/SIRF-Exercises/blob/71df538fe892ce621440544f14fa992c230fd120/notebooks/Deep_Learning_PET/sirf_torch.py#L38-L39, which seems completely generic, aside from the naming of variables. So, something like this:

class OperatorModule(torch.nn.Module):
    def __init__(self, sirf_operator, sirf_src, sirf_dest):
        """constructs wrapper. WARNING: operations will overwrite the content of `sirf_src` and `sirf_dest`"""
        ...

with usage

acq_model = pet.AcquisitionModelParallelProj()
# etc etc
torch_acq_model = sirf.torch.OperatorModule(acq_model, image, acq_data)

Ideally, we would also provide a way to call save_for_backward.
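A hedged sketch of how the pieces could fit together (names hypothetical; _AcqModForward stands for the autograd.Function wrapping forward/backward):

import torch

class OperatorModule(torch.nn.Module):
    """Sketch of the generic wrapper; overwrites the contents of sirf_src and sirf_dest."""

    def __init__(self, sirf_operator, sirf_src, sirf_dest):
        super().__init__()
        self.operator, self.src, self.dest = sirf_operator, sirf_src, sirf_dest

    def forward(self, x):
        # delegate to the autograd.Function so backward() applies the adjoint;
        # tensors needed by the backward pass can go through ctx.save_for_backward
        return _AcqModForward.apply(x, self.src, self.dest, self.operator)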

@Imraj-Singh (Contributor, Author)

Added a README with some notes for a design discussion here.

@KrisThielemans (Member) left a comment

I'd prefer not to have too much code in the design discussion (it duplicates the implementation as well as distracting from the actual design). You could include pointers to the implementation in the README, but it's still small enough that this probably isn't worth the effort.

Just keep a skeleton in the README, the shorter the better...

How do you want people to react to your questions in the README? Via a review? (not conventional, but could work)

@KrisThielemans (Member)

By the way, we should avoid running CI on this until it makes sense. The easiest is to put [ci skip] in the message of the last commit that you push.

@Imraj-Singh (Contributor, Author) left a comment

I think the test_cases need reviewing again. Perhaps grad_check is unnecessary, since SIRF usually tests for adjointness etc. Could we make do with simpler tests checking that things are copied to and from torch/SIRF correctly, as sketched below?
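For example, a round-trip test along these lines (a hypothetical sketch, assuming the usual SIRF as_array/fill/clone API) would cover the copying without gradcheck:

import numpy as np
import torch

def test_round_trip(sirf_image):
    arr = sirf_image.as_array()
    t = torch.as_tensor(arr)           # SIRF -> torch
    sirf_copy = sirf_image.clone()
    sirf_copy.fill(t.numpy())          # torch -> SIRF (in place)
    np.testing.assert_allclose(sirf_copy.as_array(), arr)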

@Imraj-Singh (Contributor, Author) commented Mar 7, 2025

@gschramm this is a minimal reproduction of your varnet. I am still getting negative values in my projected input. I think this is expected, as the ReLU only ensures non-negativity in the forward pass.

Given the loss function, there can be negative elements in the gradient even if the ReLU is enforcing non-negativity in the forward pass.

@KrisThielemans this may mean that we need to allow negative values in the input.
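A tiny self-contained illustration of the point: the ReLU output is non-negative, yet the gradients flowing back through it can still be negative:

import torch

x = torch.tensor([-1.0, 2.0], requires_grad=True)
y = torch.relu(x)                                   # forward output [0., 2.] is non-negative
loss = ((y - torch.tensor([3.0, 5.0])) ** 2).sum()
loss.backward()
print(x.grad)                                       # [0., -6.]: negative gradient elements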

@Imraj-Singh (Contributor, Author)

Related to UCL/STIR#1477

Member

Yes, of course; the gradient of the objective function will generally have positive and negative elements.

Fixing the STIR issue will take me a while though. Help welcome :-)

@Imraj-Singh (Contributor, Author)

In that case we can't backpropagate through the gradient of the objective... right? Should we remove the varnet example, or throw a warning/error?

Member

I would add a comment saying that this needs the STIR issue to be resolved, but leave it in, as it's instructive.

@danieldeidda (Collaborator)

Hi, I am happy with this. The README is now understandable to me, and the user should know what dimensionality to give the tensors.

@KrisThielemans (Member) left a comment

2 more items:

  • update CHANGES.md
  • add a GHA workflow. What are the requirements here? (`$SIRF_PYTHON_EXECUTABLE -m pip install torch`?)

else:
    raise TypeError(f"Unsupported SIRF object type: {type(sirf_src)}")

def torch_to_sirf_(
Member

any reason this has an underscore suffix?

Member

suggested by @ckolbPTB to follow torch conventions. I cannot remember why; something to do with autograd. Probably a good idea to document it.

@Imraj-Singh (Contributor, Author)

It's because there is an in-place operation, which can be dangerous in PyTorch (it can break the computational graph). A trailing underscore indicates an in-place operation, from what I understand! See the illustration below.
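A small illustration of the torch convention, with a comment on how (hypothetically) torch_to_sirf_ mirrors it:

import torch

t = torch.zeros(3)
t.add_(1.0)  # trailing underscore: modifies t in place

# by analogy, torch_to_sirf_(tensor, sirf_obj) overwrites the contents of
# sirf_obj (e.g. via sirf_obj.fill(tensor.detach().cpu().numpy())) and returns
# it, rather than allocating a new SIRF object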

- drop `object` typehints
- enable torch cpu tests
- misc formatting
@Imraj-Singh (Contributor, Author) commented Apr 6, 2025

@KrisThielemans, @casperdcl am I correct in thinking that only CI is left here (a torch build that runs the torch tests), and then we can merge? Perhaps I can look at doing this at the hackathon?

There also already seems to be some error with CI: https://github.com/SyneRBI/SIRF/actions/runs/14112476864/job/39534509274?pr=1305. This doesn't seem to be related to the changes here?

@KrisThielemans (Member)

Yes. As CI runs without CUDA, we need to have a default for the device, as indicated in the comments.
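For example, a minimal sketch of such a default (assuming the wrappers take a device argument):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")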

@KrisThielemans KrisThielemans marked this pull request as ready for review April 6, 2025 11:31
@KrisThielemans (Member)

Is this PR behind master? @evgueni-ovtchinnikov are you aware of these failing tests?

@Imraj-Singh (Contributor, Author)

@KrisThielemans I noticed that it's good practice to add the once_differentiable decorator, so I have added that (see the sketch below for its placement).

Also, the PR seemed to be behind master; I have resynced.
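For reference, a minimal sketch of where the decorator goes (names hypothetical):

import torch
from torch.autograd.function import once_differentiable

class _Wrapped(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    @once_differentiable  # error on double backward instead of silently wrong gradients
    def backward(ctx, grad_output):
        return grad_output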

@casperdcl casperdcl mentioned this pull request Apr 14, 2025
@KrisThielemans (Member)

@Imraj-Singh @casperdcl @evgueni-ovtchinnikov Where are we with this? I see lots of errors

 TypeError: float() argument must be a string or a real number, not '_NoValueType'

which don't occur in https://github.com/SyneRBI/SIRF/actions/runs/14908966112, for instance. No idea where they come from.

Development

Successfully merging this pull request may close these issues.

SIRF torch should be part of SIRF
7 participants