Adding PyTorch functionality to SIRF #1305

Open — wants to merge 64 commits into SyneRBI:master from Imraj-Singh:torch
Changes from all commits — 64 commits
789f19a
Perhaps the wrapper here?
Imraj-Singh Feb 15, 2025
e7659e9
Prototyping a solution to the problem of using a dataset with an acqu…
Imraj-Singh Feb 17, 2025
cd0c207
Adding a small test and cleaning things up, getting a segfault with t…
Imraj-Singh Feb 18, 2025
b97ddde
adding a readme, and starting to test with Gadgetron (to no avail)
Imraj-Singh Feb 18, 2025
9df22d5
Update README.md
Imraj-Singh Feb 18, 2025
e4ce261
Small changes to be tested [ci skip]
Imraj-Singh Feb 19, 2025
ca50798
Some more changes to mainly AcquisitionModelForward [ci skip]
Imraj-Singh Feb 19, 2025
0380fa5
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torc…
Imraj-Singh Feb 19, 2025
9866a33
first test optimising pet objectives in torch, and using the torch pn…
Imraj-Singh Feb 21, 2025
f81b974
objective function working with pet [ci skip]
Imraj-Singh Feb 21, 2025
01084cc
removing some redundancy
Imraj-Singh Feb 21, 2025
9b130b2
update name [ci skip]
Imraj-Singh Feb 21, 2025
b8cbb64
acq_mdl -> operator, removing kwargs, trailing _ for inplace
Imraj-Singh Feb 26, 2025
dc20472
tidying and making more generic
Imraj-Singh Feb 26, 2025
b0bd284
updated read me [ci skip]
Imraj-Singh Feb 26, 2025
409544f
Objective gradient scaffold [ci skip]
Imraj-Singh Feb 26, 2025
9c31b2d
adding better way of specifying the jacobian adjoint [ci skip]
Imraj-Singh Feb 27, 2025
33df7ef
Updating naming, added adjoint operator to sirf as discussed in meeti…
Imraj-Singh Feb 27, 2025
4f67188
updating cmake removing duplicate [ci skip]
Imraj-Singh Feb 27, 2025
fe75bfe
Circular import nonsense
Imraj-Singh Feb 27, 2025
7930ac6
adjoint method doesn't check if linear at the moment, adding that to …
Imraj-Singh Feb 28, 2025
7378a47
sirf torch error checking updates
Imraj-Singh Feb 28, 2025
8c4f110
starting some gradchecks, failing with sirf obj func
Imraj-Singh Feb 28, 2025
2d25953
nondet_tol needs to be high for obj func for some reason... [ci skip]
Imraj-Singh Feb 28, 2025
e7e2dc8
testing obj func grad
Imraj-Singh Mar 2, 2025
60bd4b5
refactoring to abstract common functionality
Imraj-Singh Mar 2, 2025
f72a561
added more gradchecks for hessian and also wrapped acquisition model …
Imraj-Singh Mar 2, 2025
9aca992
3D doesn't work as nondet tol too low, issues with PNLL numerical Jac…
Imraj-Singh Mar 2, 2025
2b9cea6
Slight change in naming, more abstracting
Imraj-Singh Mar 2, 2025
d097e17
Use cases nearly done [ci skip]
Imraj-Singh Mar 3, 2025
70d6e3f
Apply suggestions from code review
Imraj-Singh Mar 3, 2025
262e11f
added documentation and pet use cases
Imraj-Singh Mar 3, 2025
2c68b87
trying some pytests [ci skip]
Imraj-Singh Mar 3, 2025
70d147b
removing clamping [ci skip]
Imraj-Singh Mar 3, 2025
dbe05e2
Making the tols even more lax for obj_fun... [ci skip]
Imraj-Singh Mar 3, 2025
cdde96f
updated readme mainly [ci skip]
Imraj-Singh Mar 4, 2025
b0ab90b
Apply suggestions from Kris' code review
Imraj-Singh Mar 6, 2025
21cdc4f
added a todo
Imraj-Singh Mar 6, 2025
26bd325
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torch
Imraj-Singh Mar 6, 2025
576fbbc
readme changes [ci skip]
Imraj-Singh Mar 6, 2025
c7a359a
Update src/torch/SIRF_torch.py [ci skip]
Imraj-Singh Mar 6, 2025
3220514
removing test cases [ci skip]
Imraj-Singh Mar 6, 2025
31de795
more readme changes [ci skip]
Imraj-Singh Mar 7, 2025
909b1de
MR grad checks working, getting sinister segfault for obj func [ci skip]
Imraj-Singh Mar 7, 2025
142a338
changing the init to allow for sirf.blah [ci skip]
Imraj-Singh Mar 7, 2025
97e2439
calling for value of obj_func and removing negative [ci skip]
Imraj-Singh Mar 7, 2025
b2b603d
updating the README to explain dimensionality and remove negativity r…
Imraj-Singh Mar 7, 2025
69470b5
Merge branch 'SyneRBI:master' into torch
Imraj-Singh Mar 7, 2025
a788021
trying (probably failing) to make the imports better using that __all…
Imraj-Singh Mar 7, 2025
d1fa40b
updating the naming to be less verbose [ci skip]
Imraj-Singh Mar 7, 2025
0dd3770
moving gradcheck and use_cases so not to get circular imports... [ci …
Imraj-Singh Mar 7, 2025
1049818
minimal varnet [ci skip]
Imraj-Singh Mar 7, 2025
a6f2282
Apply suggestions from code review [ci skip]
Imraj-Singh Mar 9, 2025
bef9f19
take away nonsense from __init__ and update README [ci skip]
Imraj-Singh Mar 9, 2025
15179aa
Updating use cases [ci skip]
Imraj-Singh Mar 12, 2025
1893fa6
requires_grad=False for tensors returned by backward, changes.md upda…
Imraj-Singh Mar 13, 2025
362875b
add image to gd use_case [ci skip]
Imraj-Singh Mar 13, 2025
910fa2c
misc review fixups
casperdcl Mar 27, 2025
0a41e5b
Add once_differentiable decorator [ci skip]
Imraj-Singh Apr 6, 2025
7d27c80
Merge branch 'SyneRBI:master' into torch
Imraj-Singh Apr 6, 2025
0c91764
codacity changes [ci skip]
Imraj-Singh Apr 6, 2025
c0adc9c
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torch
Imraj-Singh Apr 6, 2025
a7348ce
codacy changes [ci skip]
Imraj-Singh Apr 6, 2025
89204a8
run ci things
Imraj-Singh Apr 6, 2025
11 changes: 11 additions & 0 deletions CHANGES.md
@@ -5,6 +5,17 @@
- `ScatterEstimation` has extra methods that allow setting masks for the tail-fitting
- `ImageData` has extra method to zoom image using information from a template image, `zoom_image_as_template`.
- Error raised in `AcquisitionSensitivityModel.[un]normalise` methods applied to a read-only object.
- Error raised if `AcquisitionModel.adjoint` ran when the model is not linear.
* SIRF-torch
- `torch/torch.py` has PyTorch wrappers for SIRF objective functions, objective function gradients and operators
- `torch/tests/gradchecks.py` has gradchecks for the wrappers with 2D/3D PET and 2D MRI
- `torch/tests/use_cases.py` has use cases for 2D PET using all the wrappers
- `torch/README.md` includes user directions for the wrappers
- `torch/CMakeLists.txt`: installation of `sirf.torch`
- `src/CMakeLists.txt`: installation of `sirf.torch`
* SIRF
- `cmake/sirf.__init__.py.in`: imports `sirf.SIRF` content into the `sirf` namespace for convenience
- `common/SIRF.py`: adds an `AdjointOperator` class


## v3.8.1
5 changes: 4 additions & 1 deletion cmake/sirf.__init__.py.in
@@ -1,4 +1,7 @@
__version_major__ = '@VERSION_MAJOR@'
__version_minor__ = '@VERSION_MINOR@'
__version_patch__ = '@VERSION_PATCH@'
__version__ = '@VERSION_MAJOR@.@VERSION_MINOR@.@VERSION_PATCH@'
__version__ = '@VERSION_MAJOR@.@VERSION_MINOR@.@VERSION_PATCH@'

# import sirf.SIRF content into the `sirf` namespace for convenience
from sirf.SIRF import *
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -160,3 +160,4 @@ else()
endif()

ADD_SUBDIRECTORY(common)
ADD_SUBDIRECTORY(torch)
44 changes: 31 additions & 13 deletions src/common/SIRF.py
@@ -1,7 +1,4 @@
'''
Object-Oriented wrap for the cSIRF-to-Python interface pysirf.py
'''

'''Object-Oriented wrap for the cSIRF-to-Python interface pysirf.py'''
## SyneRBI Synergistic Image Reconstruction Framework (SIRF)
## Copyright 2015 - 2020 Rutherford Appleton Laboratory STFC
## Copyright 2015 - 2020 University College London
@@ -43,6 +40,10 @@
else:
ABC = abc.ABCMeta('ABC', (), {})

# In future, would be good to explicitly list all objects to import when doing `from sirf.SIRF import *`.
# However, we will keep this for later to avoid mistakes in updating this variable.
# __all__ = ['DataContainer', 'ImageData', 'GeometricalInfo', 'AdjointOperator']


class DataContainer(ABC):
'''
@@ -63,7 +64,7 @@ def __add__(self, other):
'''
Overloads + for data containers.

Returns the sum of the container data with another container
Returns the sum of the container data with another container
data viewed as vectors.
other: DataContainer
'''
@@ -73,7 +74,7 @@ def __sub__(self, other):
'''
Overloads - for data containers.

Returns the difference of the container data with another container
Returns the difference of the container data with another container
data viewed as vectors.
other: DataContainer
'''
@@ -251,7 +252,7 @@ def squared_norm(self):

def dot(self, other):
'''
Returns the dot product of the container data with another container
Returns the dot product of the container data with another container
data viewed as vectors.
other: DataContainer
'''
@@ -417,10 +418,10 @@ def axpby(self, a, b, y, out=None, **kwargs):
'''
Linear combination for data containers.

Returns the linear combination of the self data with another container
Returns the linear combination of the self data with another container
data y viewed as vectors.
a: multiplier to self, can be a number or a DataContainer
b: multiplier to y, can be a number or a DataContainer
b: multiplier to y, can be a number or a DataContainer
y: DataContainer
out: DataContainer to store the result to.
'''
@@ -430,10 +431,10 @@ def sapyb(self, a, y, b, out=None, **kwargs):
'''
Linear combination for data containers: new interface.

Returns the linear combination of the self data with another container
Returns the linear combination of the self data with another container
data y viewed as vectors.
a: multiplier to self, can be a number or a DataContainer
b: multiplier to y, can be a number or a DataContainer
b: multiplier to y, can be a number or a DataContainer
y: DataContainer
out: DataContainer to store the result to, can be self or y.
'''
@@ -599,7 +600,7 @@ def dtype(self):
else:
dt = 'float%s' % bits
return numpy.dtype(dt)


class ImageData(DataContainer):
'''
@@ -716,7 +717,7 @@ def get_spacing(self):
arr = numpy.ndarray((3,), dtype = numpy.float32)
try_calling (pysirf.cSIRF_GeomInfo_get_spacing(self.handle, arr.ctypes.data))
return tuple(arr)

def get_size(self):
"""Size is the number of voxels in each dimension."""
arr = numpy.ndarray((3,), dtype = cpp_int_dtype())
@@ -734,3 +735,20 @@ def get_index_to_physical_point_matrix(self):
arr = numpy.ndarray((4,4), dtype = numpy.float32)
try_calling (pysirf.cSIRF_GeomInfo_get_index_to_physical_point_matrix(self.handle, arr.ctypes.data))
return arr

class AdjointOperator(object):
"""
Creates the adjoint of the linear operator `operator`.
"""
def __init__(self, operator):
self.operator = operator

def forward(self, x):
"""Calls the adjoint method of the original linear operator"""
# Note: calling `adjoint` will raise an error in SIRF if the operator is not linear.
return self.operator.adjoint(x)

def backward(self, x):
"""Calls the `direct` method of the original linear operator"""
# Note: calling `direct` will raise an error in SIRF if the operator is not linear.
return self.operator.direct(x)
22 changes: 22 additions & 0 deletions src/torch/CMakeLists.txt
@@ -0,0 +1,22 @@
#========================================================================
# Author: Evgueni Ovtchinnikov
# Copyright 2020 University College London
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#=========================================================================


if (BUILD_PYTHON)
INSTALL(FILES torch.py DESTINATION "${PYTHON_DEST}/sirf")
endif()
101 changes: 101 additions & 0 deletions src/torch/README.md
@@ -0,0 +1,101 @@
# SIRF-PyTorch Wrapper
This wrapper provides a bridge between the [SIRF](https://github.com/SyneRBI/SIRF) (Synergistic Image Reconstruction Framework) library and [PyTorch](https://github.com/pytorch/pytorch), enabling the use of SIRF's image reconstruction operators and objective functions within PyTorch's automatic differentiation (autodiff) framework.

## Usage and Use Cases

The `sirf.torch.Operator`, `sirf.torch.ObjectiveFunction`, and `sirf.torch.ObjectiveFunctionGradient` classes are designed to be used as standard PyTorch `nn.Module`s. You would initialise them with the appropriate SIRF objects and then use them in your forward pass like any other PyTorch layer.
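
A minimal sketch of this pattern (the exact constructor arguments of `sirf.torch.Operator`, the name `acq_model` and the tensor sizes are assumptions for illustration, not the definitive API):

```python
import torch
import sirf.torch

# acq_model: a SIRF acquisition model (hypothetical name), already set up on a template image
projector = sirf.torch.Operator(acq_model)            # behaves like a torch layer

x = torch.rand(1, 1, 1, 155, 155, requires_grad=True)  # [batch, channel, *image shape] (sizes illustrative)
y = projector(x)                                        # forward projection, differentiable
y.sum().backward()                                      # gradients flow back through the SIRF adjoint
print(x.grad.shape)
```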

`tests/use_cases.py` demonstrates `sirf.torch` integration in PyTorch with minimal 2D PET examples:

* **Learned Primal-Dual:** Implements a learned primal-dual network for PET image reconstruction, showcasing the use of `sirf.torch.Operator` for handling the forward and adjoint projection operations.
* **PET Variational Network (PETVarNet):** Demonstrates a variational network approach, combining convolutional blocks with gradient information from a SIRF objective function using `sirf.torch.ObjectiveFunctionGradient`.
* **ADAM Gradient Descent Comparison:** Compares two gradient descent implementations: one leveraging the `sirf.torch.Operator` for the acquisition model within the loss calculation, and another directly utilising the `sirf.torch.ObjectiveFunction` for a more traditional optimisation approach. This highlights the flexibility of the wrapper in different optimisation strategies.
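
As a rough sketch of the last use case (assuming an already configured SIRF objective function `obj_fun`, a template `ImageData` `image_template`, and that `sirf.torch.ObjectiveFunction` takes these two arguments — all assumptions for illustration):

```python
import torch
import sirf.torch

objective = sirf.torch.ObjectiveFunction(obj_fun, image_template)

x = torch.ones(1, 1, *image_template.shape, requires_grad=True)  # [batch, channel, *image shape]
optimiser = torch.optim.Adam([x], lr=1e-1)

for _ in range(100):
    optimiser.zero_grad()
    loss = -objective(x)   # maximise the SIRF objective by minimising its negative
    loss.backward()        # the backward pass calls the SIRF gradient under the hood
    optimiser.step()
```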

## Dimensions used by the wrapper

The wrappers prioritise SIRF's data formats, meaning that the torch arrays must have the shape:
* [batch, [channel,] *SIRF.DataContainer.shape], where the channel dimension is optional.

This requires the **user** to ensure that the dimensionality matches between layers.

### Example dimension manipulation
For example, a sinogram in SIRF has shape [tof bins, sinograms, views, tang pos]. For a single non-TOF sinogram this is [1, 1, views, tang pos]. The expected torch tensor shape for this wrapper is therefore [batch, [channel,] 1, 1, views, tang pos]. On the other hand, a 2D convolution requires [batch, [channel,] height, width].

```python
conv_1 = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)  # channel counts illustrative
conv_2 = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
adjoint_operator = sirf.torch.Operator(sirf.AdjointOperator(acquisition_model))
y = ...  # sinogram of dimension [batch, [channel,] views, tang pos]

y_filtered = conv_1(y)  # filtered sinogram of dimension [batch, [channel,] views, tang pos]
y_filtered = y_filtered.unsqueeze(-3).unsqueeze(-3)  # filtered sinogram of dimension [batch, [channel,] 1, 1, views, tang pos]
x_bp = adjoint_operator(y_filtered)  # back-projected image of dimension [batch, [channel,] 1, height, width]
x_bp = x_bp.squeeze(-3)  # back-projected image of dimension [batch, [channel,] height, width]
x_bp_filtered = conv_2(x_bp)  # filtered back-projected image of dimension [batch, [channel,] height, width]
```

## Forward and backward clarification

The terms forward and backward have different meanings depending on the context:
* Automatic differentiation: forward (tangent) mode autodiff computes the Jacobian-Vector-Product (JVP), propagating derivatives forward alongside the function evaluation. Backward (or reverse/adjoint) mode autodiff computes the Vector-Jacobian-Product (VJP), propagating derivative information in the reverse direction of the function's evaluation.
* Backward autodiff cont.: the forward pass evaluates the function, saving intermediate values; the backward pass uses the chain rule and those intermediate values to compute derivatives in the reverse direction via the VJP.
* `torch.autograd.Function`: the `forward` method (forward pass) is the function evaluation. The `backward` method (backward pass) computes the VJP. More specifically, `backward(*grad_output)` multiplies `grad_output`, which represents the gradient(s) of a subsequent function/operator (evaluated at the output of `forward`), by the adjoint of the Jacobian of the `forward` method, as required by the chain rule.

This SIRF-PyTorch wrapper is **only** for reverse-mode automatic differentiation via subclassing `torch.autograd.Function`.
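
For readers less familiar with `torch.autograd.Function`, here is a self-contained toy example (unrelated to SIRF) showing how `forward` implements the function evaluation and `backward` the VJP:

```python
import torch

class Square(torch.autograd.Function):
    """Toy example: y = x**2, illustrating the forward evaluation and the backward VJP."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # save intermediate values for the backward pass
        return x ** 2              # function evaluation

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # VJP: multiply the upstream gradient by the Jacobian adjoint (here diag(2x))
        return grad_output * 2 * x

x = torch.tensor([1.0, 2.0], requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)  # tensor([2., 4.])
```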

## Wrapper Design

The wrapper provides three main classes:

1. `sirf.torch.Operator`: Wraps a SIRF `Operator` (e.g., a projection operator). Applies the operator in the forward pass, and the adjoint of its Jacobian in the backward pass.
2. `sirf.torch.ObjectiveFunction`: Wraps a SIRF `ObjectiveFunction`, computing its value in the forward pass and multiplying the upstream gradient by the objective function's gradient in the backward pass.
3. `sirf.torch.ObjectiveFunctionGradient`: Wraps a SIRF `ObjectiveFunction`, computing the objective function's gradient in the forward pass and a Hessian-vector product in the backward pass. In the backward pass the Hessian is evaluated at the same point at which the objective function's gradient was evaluated.

These classes use custom `torch.autograd.Function` implementations (`_Operator`, `_ObjectiveFunction`, and `_ObjectiveFunctionGradient`) to define the forward and backward passes, handling the conversions between PyTorch tensors and SIRF objects.
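
The conversion step might look roughly like the following helpers (hypothetical names, assuming SIRF's `as_array`/`fill` interface; a sketch, not the actual implementation):

```python
import torch

def sirf_to_torch(sirf_src, device, requires_grad=False):
    # Copy a SIRF DataContainer's values into a new torch tensor.
    return torch.tensor(sirf_src.as_array(), device=device, requires_grad=requires_grad)

def torch_to_sirf_(torch_src, sirf_dest):
    # Fill an existing SIRF DataContainer (in place) from a torch tensor and return it.
    sirf_dest.fill(torch_src.detach().cpu().numpy())
    return sirf_dest
```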

### `_Operator` (Forward and Backward Passes)

* **Forward Pass:**
1. Converts the input PyTorch tensor to a SIRF object.
2. Applies the SIRF `Operator.forward()` method.
3. Converts the result back to a PyTorch tensor.
4. If the input tensor requires gradients, it saves relevant information (the output SIRF object and the operator) in the context (`ctx`) for use in the backward pass.

* **Backward Pass (VJP):**
1. Receives the "upstream gradient" (`grad_output`).
2. Converts `grad_output` to a SIRF object.
3. Applies the SIRF `Operator.backward()` method, which applies the **Jacobian adjoint** of the operator to the upstream gradient (the vector).
4. Converts the resulting SIRF object back to a PyTorch tensor and returns it.
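
A rough sketch of `_Operator` following those steps (illustrative only; it assumes the hypothetical `sirf_to_torch`/`torch_to_sirf_` helpers above, and omits batch handling and error checking):

```python
import torch

class _Operator(torch.autograd.Function):
    @staticmethod
    def forward(ctx, torch_x, sirf_x_template, sirf_operator):
        sirf_x = torch_to_sirf_(torch_x, sirf_x_template)           # tensor -> SIRF object
        sirf_y = sirf_operator.forward(sirf_x)                       # apply the SIRF operator
        if torch_x.requires_grad:
            ctx.sirf_y, ctx.sirf_operator = sirf_y, sirf_operator    # saved for the backward pass
        return sirf_to_torch(sirf_y, torch_x.device)

    @staticmethod
    def backward(ctx, grad_output):
        sirf_grad = torch_to_sirf_(grad_output, ctx.sirf_y.clone())  # upstream gradient -> SIRF object
        sirf_vjp = ctx.sirf_operator.backward(sirf_grad)             # Jacobian adjoint applied to the vector
        return sirf_to_torch(sirf_vjp, grad_output.device), None, None
```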

### `_ObjectiveFunction` (Forward and Backward Passes)

* **Forward Pass:**
1. Converts the input PyTorch tensor (representing an image for instance) to a SIRF object.
2. Calls the SIRF `ObjectiveFunction.__call__()` method.
3. Returns the objective function value as a PyTorch tensor.
4. Saves relevant information to the `ctx` if gradients are needed.

* **Backward Pass (VJP):**
1. Receives the upstream gradient (`grad_output`), which in this case is always a scalar.
2. Gets the gradient of the objective function using `sirf_obj_func.gradient()`, evaluated at the input to the forward pass.
3. Converts the SIRF gradient to a PyTorch tensor.
4. Returns this gradient multiplied by `grad_output`.
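
A corresponding rough sketch of `_ObjectiveFunction` (same caveats and hypothetical helpers as above):

```python
import torch

class _ObjectiveFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, torch_image, sirf_image_template, sirf_obj_func):
        sirf_image = torch_to_sirf_(torch_image, sirf_image_template)
        if torch_image.requires_grad:
            ctx.sirf_image, ctx.sirf_obj_func = sirf_image, sirf_obj_func
        value = sirf_obj_func(sirf_image)                          # scalar objective value
        return torch.tensor(value, device=torch_image.device)

    @staticmethod
    def backward(ctx, grad_output):
        sirf_grad = ctx.sirf_obj_func.gradient(ctx.sirf_image)    # gradient at the forward input
        torch_grad = sirf_to_torch(sirf_grad, grad_output.device)
        return grad_output * torch_grad, None, None                # chain rule: scale by the upstream scalar
```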


### `_ObjectiveFunctionGradient` (Forward and Backward Passes)

* **Forward Pass:**
1. Converts the input PyTorch tensor to a SIRF object.
2. Computes the *gradient* of the SIRF objective function using `sirf_obj_func.gradient()`, evaluated at the input.
3. Returns the gradient as a PyTorch tensor.

* **Backward Pass (VJP):**
1. Receives the upstream gradient (`grad_output`), which now represents a *vector* (not a scalar) with the same shape as the output of `forward`.
2. Converts `grad_output` to a SIRF object.
3. Multiplies the Hessian, evaluated at the input of `forward`, with the upstream gradient using `sirf_obj_func.multiply_with_Hessian()`.
4. Returns the resulting Hessian-vector product as a PyTorch tensor.
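
And a rough sketch of `_ObjectiveFunctionGradient` (again illustrative; the `multiply_with_Hessian` argument order and the helpers are assumptions):

```python
import torch

class _ObjectiveFunctionGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, torch_image, sirf_image_template, sirf_obj_func):
        sirf_image = torch_to_sirf_(torch_image, sirf_image_template)
        if torch_image.requires_grad:
            ctx.sirf_image, ctx.sirf_obj_func = sirf_image, sirf_obj_func
        grad = sirf_obj_func.gradient(sirf_image)                  # gradient, same shape as the image
        return sirf_to_torch(grad, torch_image.device)

    @staticmethod
    def backward(ctx, grad_output):
        sirf_v = torch_to_sirf_(grad_output, ctx.sirf_image.clone())
        hvp = ctx.sirf_obj_func.multiply_with_Hessian(ctx.sirf_image, sirf_v)  # Hessian at the input times the vector
        return sirf_to_torch(hvp, grad_output.device), None, None
```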

## TODO

* Extend the wrapper to support subsets
* Extend to objective functions that vary between batch items