Adding PyTorch functionality to SIRF #1305

Open. Wants to merge 64 commits into base: master.

Commits (64):
789f19a
Perhaps the wrapper here?
Imraj-Singh Feb 15, 2025
e7659e9
Prototyping a solution to the problem of using a dataset with an acqu…
Imraj-Singh Feb 17, 2025
cd0c207
Adding a small test and cleaning things up, getting a segfault with t…
Imraj-Singh Feb 18, 2025
b97ddde
adding a readme, and starting to test with Gadgetron (to no avail)
Imraj-Singh Feb 18, 2025
9df22d5
Update README.md
Imraj-Singh Feb 18, 2025
e4ce261
Small changes to be tested [ci skip]
Imraj-Singh Feb 19, 2025
ca50798
Some more changes to mainly AcquisitionModelForward [ci skip]
Imraj-Singh Feb 19, 2025
0380fa5
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torc…
Imraj-Singh Feb 19, 2025
9866a33
first test optimising pet objectives in torch, and using the torch pn…
Imraj-Singh Feb 21, 2025
f81b974
objective function working with pet [ci skip]
Imraj-Singh Feb 21, 2025
01084cc
removing some redundancy
Imraj-Singh Feb 21, 2025
9b130b2
update name [ci skip]
Imraj-Singh Feb 21, 2025
b8cbb64
acq_mdl -> operator, removing kwargs, trailing _ for inplace
Imraj-Singh Feb 26, 2025
dc20472
tidying and making more generic
Imraj-Singh Feb 26, 2025
b0bd284
updated read me [ci skip]
Imraj-Singh Feb 26, 2025
409544f
Objective gradient scaffold [ci skip]
Imraj-Singh Feb 26, 2025
9c31b2d
adding better way of specifying the jacobian adjoint [ci skip]
Imraj-Singh Feb 27, 2025
33df7ef
Updating naming, added adjoint operator to sirf as discussed in meeti…
Imraj-Singh Feb 27, 2025
4f67188
updating cmake removing duplicate [ci skip]
Imraj-Singh Feb 27, 2025
fe75bfe
Circular import nonsense
Imraj-Singh Feb 27, 2025
7930ac6
adjoint method doesn't check if linear at the moment, adding that to …
Imraj-Singh Feb 28, 2025
7378a47
sirf torch error checking updates
Imraj-Singh Feb 28, 2025
8c4f110
starting some gradchecks, failing with sirf obj func
Imraj-Singh Feb 28, 2025
2d25953
nondet_tol needs to be high for obj func for some reason... [ci skip]
Imraj-Singh Feb 28, 2025
e7e2dc8
testing obj func grad
Imraj-Singh Mar 2, 2025
60bd4b5
refactoring to abstract common functionality
Imraj-Singh Mar 2, 2025
f72a561
added more gradchecks for hessian and also wrapped acquisition model …
Imraj-Singh Mar 2, 2025
9aca992
3D doesn't work as nondet tol too low, issues with PNLL numerical Jac…
Imraj-Singh Mar 2, 2025
2b9cea6
Slight change in naming, more abstracting
Imraj-Singh Mar 2, 2025
d097e17
Use cases nearly done [ci skip]
Imraj-Singh Mar 3, 2025
70d6e3f
Apply suggestions from code review
Imraj-Singh Mar 3, 2025
262e11f
added documentation and pet use cases
Imraj-Singh Mar 3, 2025
2c68b87
trying some pytests [ci skip]
Imraj-Singh Mar 3, 2025
70d147b
removing clamping [ci skip]
Imraj-Singh Mar 3, 2025
dbe05e2
Making the tols even more lax for obj_fun... [ci skip]
Imraj-Singh Mar 3, 2025
cdde96f
updated readme mainly [ci skip]
Imraj-Singh Mar 4, 2025
b0ab90b
Apply suggestions from Kris' code review
Imraj-Singh Mar 6, 2025
21cdc4f
added a todo
Imraj-Singh Mar 6, 2025
26bd325
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torch
Imraj-Singh Mar 6, 2025
576fbbc
readme changes [ci skip]
Imraj-Singh Mar 6, 2025
c7a359a
Update src/torch/SIRF_torch.py [ci skip]
Imraj-Singh Mar 6, 2025
3220514
removing test cases [ci skip]
Imraj-Singh Mar 6, 2025
31de795
more readme changes [ci skip]
Imraj-Singh Mar 7, 2025
909b1de
MR grad checks working, getting sinister segfault for obj func [ci skip]
Imraj-Singh Mar 7, 2025
142a338
changing the init to allow for sirf.blah [ci skip]
Imraj-Singh Mar 7, 2025
97e2439
calling for value of obj_func and removing negative [ci skip]
Imraj-Singh Mar 7, 2025
b2b603d
updating the README to explain dimensionality and remove negativity r…
Imraj-Singh Mar 7, 2025
69470b5
Merge branch 'SyneRBI:master' into torch
Imraj-Singh Mar 7, 2025
a788021
trying (probably failing) to make the imports better using that __all…
Imraj-Singh Mar 7, 2025
d1fa40b
updating the naming to be less verbose [ci skip]
Imraj-Singh Mar 7, 2025
0dd3770
moving gradcheck and use_cases so not to get circular imports... [ci …
Imraj-Singh Mar 7, 2025
1049818
minimal varnet [ci skip]
Imraj-Singh Mar 7, 2025
a6f2282
Apply suggestions from code review [ci skip]
Imraj-Singh Mar 9, 2025
bef9f19
take away nonsense from __init__ and update README [ci skip]
Imraj-Singh Mar 9, 2025
15179aa
Updating use cases [ci skip]
Imraj-Singh Mar 12, 2025
1893fa6
requires_grad=False for tensors returned by backward, changes.md upda…
Imraj-Singh Mar 13, 2025
362875b
add image to gd use_case [ci skip]
Imraj-Singh Mar 13, 2025
910fa2c
misc review fixups
casperdcl Mar 27, 2025
0a41e5b
Add once_differentiable decorator [ci skip]
Imraj-Singh Apr 6, 2025
7d27c80
Merge branch 'SyneRBI:master' into torch
Imraj-Singh Apr 6, 2025
0c91764
codacity changes [ci skip]
Imraj-Singh Apr 6, 2025
c0adc9c
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torch
Imraj-Singh Apr 6, 2025
a7348ce
codacy changes [ci skip]
Imraj-Singh Apr 6, 2025
89204a8
run ci things
Imraj-Singh Apr 6, 2025
135 changes: 135 additions & 0 deletions src/torch/tests/varnet_minimal.py
@Imraj-Singh (Contributor, Author) commented on Mar 7, 2025:

@gschramm this is a minimal reproduction of your varnet. I am still getting negative values in my projected input. I think this is expected, as the ReLU only ensures non-negativity in the forward pass.

Given the loss function, there can be negative elements in the gradient even if the ReLU is enforcing non-negativity in the forward pass.

@KrisThielemans this may mean that we need to support negative values in the input.
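A minimal standalone sketch (not part of this diff) of that point: the ReLU keeps the forward output non-negative, but the gradient reaching the input can still have either sign, depending on the loss.

import torch

x = torch.tensor([0.5, 2.0], requires_grad=True)
target = torch.tensor([1.0, 1.0])
out = torch.relu(x)  # forward pass output is non-negative
loss = torch.nn.functional.mse_loss(out, target)
loss.backward()
print(x.grad)  # tensor([-0.5000, 1.0000]): mixed signs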

@Imraj-Singh (Contributor, Author):

Related to UCL/STIR#1477

Member:

Yes, of course: the gradient of the objective function will generally have both positive and negative elements.

Fixing the STIR issue will take me a while though. Help welcome :-)
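For reference, with system matrix $A$, background term $b$ and measured data $y$ (the standard model, matching the script below), the Poisson log-likelihood $L(x) = \sum_i \big(y_i \log [Ax+b]_i - [Ax+b]_i\big)$ has gradient

$$\nabla L(x) = A^\top \frac{y}{Ax+b} - A^\top 1,$$

i.e. a non-negative back-projection minus the sensitivity image, so individual voxels of the gradient can take either sign.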

@Imraj-Singh (Contributor, Author):

In that case we can't backpropagate through the gradient of the objective... right? Should we remove the varnet example, or throw a warning/error?

Member:

I would add a comment saying that this needs the STIR issue to be resolved, but leave it in, as it's instructive.
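For context, backpropagating through $g(x) = \nabla L(x)$ requires a Hessian-vector product: for an upstream gradient $v$, the backward pass must evaluate

$$\nabla_x \big(v^\top \nabla L(x)\big) = \nabla^2 L(x)\, v = -A^\top\!\Big(\frac{y}{(Ax+b)^2} \odot Av\Big),$$

which is presumably the computation affected by the linked STIR issue.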

@@ -0,0 +1,135 @@

# imports: the sirf.torch wrappers, SIRF's STIR interface, plotting and torch
import matplotlib.pyplot as plt
import torch

import sirf
import sirf.STIR as pet
from sirf.torch import Operator, ObjectiveFunction, ObjectiveFunctionGradient, sirf_to_torch
from sirf.Utilities import examples_data_path

pet.set_verbosity(1)
pet.AcquisitionData.set_storage_scheme("memory")
# silence STIR info/warning/error console output
msg = sirf.STIR.MessageRedirector(info=None, warn=None, errr=None)

pet_data_path = examples_data_path('PET')
pet_2d_raw_data_file = pet.existing_filepath(pet_data_path, 'thorax_single_slice/template_sinogram.hs')
pet_2d_acq_data = pet.AcquisitionData(pet_2d_raw_data_file)
pet_2d_init_image_file = pet.existing_filepath(pet_data_path, 'thorax_single_slice/emission.hv')
pet_2d_image_data = pet.ImageData(pet_2d_init_image_file)
pet_2d_acq_model = pet.AcquisitionModelUsingParallelproj()
data_processor = pet.TruncateToCylinderProcessor()
pet_2d_acq_model.set_image_data_processor(data_processor)
pet_2d_acq_model.set_up(pet_2d_acq_data, pet_2d_image_data)
# sensitivity image s = A^T 1 and its reciprocal (truncated to the cylinder),
# used later as the EM preconditioner
sens_img = pet_2d_acq_model.backward(pet_2d_acq_data.get_uniform_copy(1.0))
inv_sens_img = sens_img.power(-1)
data_processor.apply(inv_sens_img)
# simulate data from the emission image with a constant background of 10.5,
# then rebuild the model with that background term
pet_2d_acq_data = pet_2d_acq_model.forward(pet_2d_image_data) + 10.5
pet_2d_acq_model.set_background_term(pet_2d_acq_data.get_uniform_copy(10.5))
pet_2d_acq_model.set_up(pet_2d_acq_data, pet_2d_image_data)

# alias the 2D objects, build the Poisson log-likelihood objective and wrap
# its gradient for use from torch
acq_data = pet_2d_acq_data
acq_model = pet_2d_acq_model
image_data = pet_2d_image_data
obj_fun = pet.make_Poisson_loglikelihood(acq_data)
obj_fun.set_acquisition_model(acq_model)
obj_fun.set_up(image_data)
objfuncgrad = ObjectiveFunctionGradient(obj_fun, image_data)

dev = "cuda" if torch.cuda.is_available() else "cpu"

cnn = torch.nn.Sequential(
    torch.nn.Conv2d(1, 5, 5, padding="same", bias=False),
    torch.nn.Conv2d(5, 5, 5, padding="same", bias=False),
    torch.nn.PReLU(device=dev),
    torch.nn.Conv2d(5, 5, 5, padding="same", bias=False),
    torch.nn.Conv2d(5, 5, 5, padding="same", bias=False),
    torch.nn.PReLU(device=dev),
    torch.nn.Conv2d(5, 1, 1, padding="same", bias=False),
).to(dev)


class UnrolledOSEMVarNet(torch.nn.Module):
    def __init__(
        self,
        objective_function_gradient: sirf.torch.ObjectiveFunctionGradient,
        inv_sens_img: torch.Tensor,
        convnet: torch.nn.Module,
        device: str,
    ) -> None:
        """Unrolled OSEM Variational Network with 2 blocks

        Parameters
        ----------
        objective_function_gradient : sirf.torch.ObjectiveFunctionGradient
            gradient of the Poisson logL objective function
            that we use for the OSEM-style updates
        inv_sens_img : torch.Tensor
            reciprocal of the sensitivity image, used to precondition
            the objective-function gradient
        convnet : torch.nn.Module
            a (convolutional) neural network that maps a minibatch tensor
            of shape [1,1,spatial_dimensions] onto a minibatch tensor of the same shape
        device : str
            device used for the calculations
        """
        super().__init__()

        # gradient of the Poisson logL objective, used for the OSEM-style updates
        self.objective_function_gradient = objective_function_gradient
        self._inv_sens_img = inv_sens_img

        self._convnet = convnet
        self._relu = torch.nn.ReLU()

        # trainable parameters for the fusion of the OSEM update and the CNN
        # output in the two blocks; we start with a weight of 10 for the fusion,
        # a good starting value depends on the scale of the input image
        self._fusion_weight0 = torch.nn.Parameter(
            10 * torch.ones(1, device=device, dtype=torch.float32)
        )
        self._fusion_weight1 = torch.nn.Parameter(
            10 * torch.ones(1, device=device, dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # each block fuses the CNN output with the EM-preconditioned gradient
        # step x * s^-1 * grad L(x); the ReLU keeps the block output non-negative
        x1 = self._relu(
            self._fusion_weight0 * self._convnet(x)
            + self.objective_function_gradient(x) * x * self._inv_sens_img
        )
        x2 = self._relu(
            self._fusion_weight1 * self._convnet(x1)
            + self.objective_function_gradient(x1) * x1 * self._inv_sens_img
        )

        return x2

inv_sens_img = torch.tensor(inv_sens_img.as_array()).unsqueeze(0).to(dev)
varnet = UnrolledOSEMVarNet(objfuncgrad, inv_sens_img, cnn, dev)
varnet.to(dev)

# use seed for reproducibility
torch.manual_seed(42)

torch_image = torch.tensor(image_data.as_array()).unsqueeze(0).to(dev)
# add gaussian noise to the image
torch_input = torch_image + 10 * torch.randn_like(torch_image)
# clip the input to be non-negative
torch_input = torch.clamp(torch_input, 0).detach()
# plot the noise-free and noisy images
plt.imshow(torch_image.detach().cpu().numpy()[0, 0])
plt.colorbar()
plt.savefig("pet_image.png")
plt.close()
plt.imshow(torch_input.detach().cpu().numpy()[0, 0])
plt.colorbar()
plt.savefig("pet_noisy_image.png")
plt.close()

# set up the optimizer and train with a supervised MSE loss against the
# noise-free image
optimizer = torch.optim.Adam(varnet.parameters(), lr=1e-4)
relu = torch.nn.ReLU()
loss = torch.nn.MSELoss()
for i in range(100):
    optimizer.zero_grad()
    loss_val = loss(relu(varnet(torch_input)), torch_image)
    loss_val.backward()
    optimizer.step()
    print("Iteration: ", i, "Loss: ", loss_val.item())

out = relu(varnet(torch_input))
plt.imshow(out.detach().cpu().numpy()[0, 0])
plt.savefig("pet_varnet.png")