Adding PyTorch functionality to SIRF #1305

Open
wants to merge 64 commits into base: master
Changes from 3 commits
Commits
64 commits
789f19a
Perhaps the wrapper here?
Imraj-Singh Feb 15, 2025
e7659e9
Prototyping a solution to the problem of using a dataset with an acqu…
Imraj-Singh Feb 17, 2025
cd0c207
Adding a small test and cleaning things up, getting a segfault with t…
Imraj-Singh Feb 18, 2025
b97ddde
adding a readme, and starting to test with Gadgetron (to no avail)
Imraj-Singh Feb 18, 2025
9df22d5
Update README.md
Imraj-Singh Feb 18, 2025
e4ce261
Small changes to be tested [ci skip]
Imraj-Singh Feb 19, 2025
ca50798
Some more changes to mainly AcquisitionModelForward [ci skip]
Imraj-Singh Feb 19, 2025
0380fa5
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torc…
Imraj-Singh Feb 19, 2025
9866a33
first test optimising pet objectives in torch, and using the torch pn…
Imraj-Singh Feb 21, 2025
f81b974
objective function working with pet [ci skip]
Imraj-Singh Feb 21, 2025
01084cc
removing some redundancy
Imraj-Singh Feb 21, 2025
9b130b2
update name [ci skip]
Imraj-Singh Feb 21, 2025
b8cbb64
acq_mdl -> operator, removing kwargs, trailing _ for inplace
Imraj-Singh Feb 26, 2025
dc20472
tidying and making more generic
Imraj-Singh Feb 26, 2025
b0bd284
updated read me [ci skip]
Imraj-Singh Feb 26, 2025
409544f
Objective gradient scaffold [ci skip]
Imraj-Singh Feb 26, 2025
9c31b2d
adding better way of specifying the jacobian adjoint [ci skip]
Imraj-Singh Feb 27, 2025
33df7ef
Updating naming, added adjoint operator to sirf as discussed in meeti…
Imraj-Singh Feb 27, 2025
4f67188
updating cmake removing duplicate [ci skip]
Imraj-Singh Feb 27, 2025
fe75bfe
Circular import nonsense
Imraj-Singh Feb 27, 2025
7930ac6
adjoint method doesn't check if linear at the moment, adding that to …
Imraj-Singh Feb 28, 2025
7378a47
sirf torch error checking updates
Imraj-Singh Feb 28, 2025
8c4f110
starting some gradchecks, failing with sirf obj func
Imraj-Singh Feb 28, 2025
2d25953
nondet_tol needs to be high for obj func for some reason... [ci skip]
Imraj-Singh Feb 28, 2025
e7e2dc8
testing obj func grad
Imraj-Singh Mar 2, 2025
60bd4b5
refactoring to abstract common functionality
Imraj-Singh Mar 2, 2025
f72a561
added more gradchecks for hessian and also wrapped acquisition model …
Imraj-Singh Mar 2, 2025
9aca992
3D doesn't work as nondet tol too low, issues with PNLL numerical Jac…
Imraj-Singh Mar 2, 2025
2b9cea6
Slight change in naming, more abstracting
Imraj-Singh Mar 2, 2025
d097e17
Use cases nearly done [ci skip]
Imraj-Singh Mar 3, 2025
70d6e3f
Apply suggestions from code review
Imraj-Singh Mar 3, 2025
262e11f
added documentation and pet use cases
Imraj-Singh Mar 3, 2025
2c68b87
trying some pytests [ci skip]
Imraj-Singh Mar 3, 2025
70d147b
removing clamping [ci skip]
Imraj-Singh Mar 3, 2025
dbe05e2
Making the tols even more lax for obj_fun... [ci skip]
Imraj-Singh Mar 3, 2025
cdde96f
updated readme mainly [ci skip]
Imraj-Singh Mar 4, 2025
b0ab90b
Apply suggestions from Kris' code review
Imraj-Singh Mar 6, 2025
21cdc4f
added a todo
Imraj-Singh Mar 6, 2025
26bd325
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torch
Imraj-Singh Mar 6, 2025
576fbbc
readme changes [ci skip]
Imraj-Singh Mar 6, 2025
c7a359a
Update src/torch/SIRF_torch.py [ci skip]
Imraj-Singh Mar 6, 2025
3220514
removing test cases [ci skip]
Imraj-Singh Mar 6, 2025
31de795
more readme changes [ci skip]
Imraj-Singh Mar 7, 2025
909b1de
MR grad checks working, getting sinister segfault for obj func [ci skip]
Imraj-Singh Mar 7, 2025
142a338
changing the init to allow for sirf.blah [ci skip]
Imraj-Singh Mar 7, 2025
97e2439
calling for value of obj_func and removing negative [ci skip]
Imraj-Singh Mar 7, 2025
b2b603d
updating the README to explain dimensionality and remove negativity r…
Imraj-Singh Mar 7, 2025
69470b5
Merge branch 'SyneRBI:master' into torch
Imraj-Singh Mar 7, 2025
a788021
trying (probably failing) to make the imports better using that __all…
Imraj-Singh Mar 7, 2025
d1fa40b
updating the naming to be less verbose [ci skip]
Imraj-Singh Mar 7, 2025
0dd3770
moving gradcheck and use_cases so not to get circular imports... [ci …
Imraj-Singh Mar 7, 2025
1049818
minimal varnet [ci skip]
Imraj-Singh Mar 7, 2025
a6f2282
Apply suggestions from code review [ci skip]
Imraj-Singh Mar 9, 2025
bef9f19
take away nonsense from __init__ and update README [ci skip]
Imraj-Singh Mar 9, 2025
15179aa
Updating use cases [ci skip]
Imraj-Singh Mar 12, 2025
1893fa6
requires_grad=False for tensors returned by backward, changes.md upda…
Imraj-Singh Mar 13, 2025
362875b
add image to gd use_case [ci skip]
Imraj-Singh Mar 13, 2025
910fa2c
misc review fixups
casperdcl Mar 27, 2025
0a41e5b
Add once_differentiable decorator [ci skip]
Imraj-Singh Apr 6, 2025
7d27c80
Merge branch 'SyneRBI:master' into torch
Imraj-Singh Apr 6, 2025
0c91764
codacity changes [ci skip]
Imraj-Singh Apr 6, 2025
c0adc9c
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torch
Imraj-Singh Apr 6, 2025
a7348ce
codacy changes [ci skip]
Imraj-Singh Apr 6, 2025
89204a8
run ci things
Imraj-Singh Apr 6, 2025
112 changes: 75 additions & 37 deletions src/torch/SIRF_torch.py
@@ -8,29 +8,31 @@
# https://github.com/educating-dip/pet_deep_image_prior/blob/main/src/deep_image_prior/torch_wrapper.py


def sirf_to_torch(sirf_src, torch_dest, requires_grad=False):
def sirf_to_torch(sirf_src, device, requires_grad=False):
if requires_grad:
return torch.tensor(sirf_src.as_array(), requires_grad=True).to(torch_dest.device)
# use torch.tensor to infer data type
return torch.tensor(sirf_src.as_array(), requires_grad=True).to(device)
else:
return torch.tensor(sirf_src.as_array()).to(torch_dest.device)
return torch.tensor(sirf_src.as_array()).to(device)

def torch_to_sirf(torch_src, sirf_dest):
return sirf_dest.fill(torch_src.detach().cpu().numpy())
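
# A minimal round-trip sketch (illustrative only, with an assumed container
# `x_sirf`, e.g. a sirf.STIR.ImageData):
#   t = sirf_to_torch(x_sirf, device="cpu")                 # SIRF -> torch
#   x_back = torch_to_sirf(t, x_sirf.get_uniform_copy(0))   # torch -> SIRF
# torch_to_sirf fills the destination container in place and returns it, so
# x_back holds the same voxel values as x_sirf (up to dtype conversion).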


class _ObjectiveFunctionModule(torch.autograd.Function):
class _ObjectiveFunction(torch.autograd.Function):
@staticmethod
def forward(ctx,
torch_image,
sirf_image_template,
sirf_obj_func
):

device = torch_image.device
sirf_image_template = torch_to_sirf(torch_image, sirf_image_template)
value_np = sirf_obj_func.get_value(sirf_image_template)
if torch_image.requires_grad:
ctx.save_for_backward(torch_image, sirf_image_template, sirf_obj_func)
return torch.tensor(value_np, requires_grad=True).to(torch_image.device)
# the device and the SIRF objects are not tensors, so keep them on ctx directly
ctx.device = device
ctx.sirf_image_template = sirf_image_template
ctx.sirf_obj_func = sirf_obj_func
return torch.tensor(value_np, requires_grad=True).to(device)
else:
return torch.tensor(value_np).to(torch_image.device)

@@ -39,70 +41,75 @@ def backward(ctx,
grad_output
):

torch_image, sirf_image_template, sirf_obj_func = ctx.saved_tensors
device = ctx.device
sirf_image_template = ctx.sirf_image_template
sirf_obj_func = ctx.sirf_obj_func
tmp_grad = sirf_obj_func.get_gradient(sirf_image_template)
grad = sirf_to_torch(tmp_grad, torch_image, requires_grad=True)
grad = sirf_to_torch(tmp_grad, device, requires_grad=True)
return grad_output*grad, None, None, None
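
# Sketch of the intended use of this wrapper (assumed names `initial_image`
# and `obj_fun`, in the spirit of the "optimising pet objectives in torch"
# commits, not code from this diff):
#   x = torch.tensor(initial_image.as_array(), requires_grad=True)
#   opt = torch.optim.Adam([x], lr=1e-2)
#   for _ in range(10):
#       opt.zero_grad()
#       # negated because torch optimisers minimise, while the SIRF
#       # Poisson log-likelihood is maximised
#       loss = -_ObjectiveFunction.apply(x, initial_image.get_uniform_copy(0), obj_fun)
#       loss.backward()
#       opt.step()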

class _AcquisitionModelForward(torch.autograd.Function):
@staticmethod
def forward(ctx,
torch_image,
torch_measurements_template,
sirf_image_template,
sirf_acq_mdl
):

device = torch_image.device
sirf_image_template = torch_to_sirf(torch_image, sirf_image_template)
sirf_forward_projected = sirf_acq_mdl.forward(sirf_image_template)
if torch_image.requires_grad:
ctx.torch_image = torch_image
ctx.device = device
ctx.sirf_forward_projected = sirf_forward_projected
ctx.sirf_acq_mdl = sirf_acq_mdl
return sirf_to_torch(sirf_forward_projected, torch_measurements_template, requires_grad=True)
return sirf_to_torch(sirf_forward_projected, device, requires_grad=True)
else:
return sirf_to_torch(sirf_forward_projected, torch_measurements_template)
return sirf_to_torch(sirf_forward_projected, device)

@staticmethod
def backward(ctx,
grad_output
):
sirf_image = ctx.sirf_acq_mdl.backward(torch_to_sirf(grad_output, ctx.sirf_forward_projected))
grad = sirf_to_torch(sirf_image, ctx.torch_image, requires_grad=True)
return grad, None, None, None, None
grad = sirf_to_torch(sirf_image, ctx.device, requires_grad=True)
return grad, None, None, None
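
# A hypothetical numerical check of the wrapper above, along the lines of the
# gradcheck commits in this PR (`small_image_template` and `small_acq_mdl` are
# assumed, modest-sized SIRF objects; SIRF works in single precision, hence the
# lax tolerances):
#   x = torch.tensor(small_image_template.as_array(), requires_grad=True)
#   torch.autograd.gradcheck(
#       lambda t: _AcquisitionModelForward.apply(t, small_image_template,
#                                                small_acq_mdl),
#       (x,), eps=1e-3, atol=1e-4, rtol=1e-2, nondet_tol=1e-4)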



class _AcquisitionModelBackward(torch.autograd.Function):
@staticmethod
def forward(ctx,
torch_measurements,
torch_image_template,
sirf_measurements_template,
sirf_acq_mdl
):

device = torch_measurements.device
sirf_measurements_template = torch_to_sirf(torch_measurements, sirf_measurements_template)
sirf_backward_projected = sirf_acq_mdl.backward(sirf_measurements_template)
if torch_image_template.requires_grad:
ctx.torch_measurements = torch_measurements
if torch_measurements.requires_grad:
ctx.device = device
ctx.sirf_backward_projected = sirf_backward_projected
ctx.sirf_acq_mdl = sirf_acq_mdl
return sirf_to_torch(sirf_backward_projected, torch_image_template, requires_grad=True)
return sirf_to_torch(sirf_backward_projected, device, requires_grad=True)
else:
return sirf_to_torch(sirf_backward_projected, torch_image_template)
return sirf_to_torch(sirf_backward_projected, device)

@staticmethod
def backward(ctx,
grad_output
):

sirf_measurements = ctx.sirf_acq_mdl.forward(torch_to_sirf(grad_output, ctx.sirf_backward_projected))
grad = sirf_to_torch(sirf_measurements, ctx.torch_measurements, requires_grad=True)
grad = sirf_to_torch(sirf_measurements, ctx.device, requires_grad=True)
return grad, None, None, None, None

class AcquisitionModelForward(torch.nn.Module):
def __init__(self, acq_mdl, sirf_image_template, sirf_measurements_template, device = "cpu", ):
def __init__(self,
acq_mdl,
sirf_image_template,
sirf_measurements_template,
device = "cpu",
):
super(AcquisitionModelForward, self).__init__()
# get the shape of image and measurements
self.acq_mdl = acq_mdl
@@ -111,21 +118,49 @@ def __init__(self, acq_mdl, sirf_image_template, sirf_measurements_template, dev
self.sirf_image_template = sirf_image_template
self.sirf_measurements_template = sirf_measurements_template

self.torch_measurements_template = torch.tensor(sirf_measurements_template.as_array(), requires_grad=False).to(device)*0

def forward(self, torch_image):
# view as torch sometimes doesn't like singleton dimensions
torch_image = torch_image.view(self.sirf_image_shape)
return _AcquisitionModelForward.apply(torch_image, self.torch_measurements_template, self.sirf_image_template, self.acq_mdl).squeeze()


def forward(self, image, **kwargs):
# PyTorch image (2D) is size [batch, channel, height, width] or [batch, height, width]
# PyTorch volume (3D) is size [batch, channel, depth, height, width] or [batch, depth, height, width]
# check keys of kwargs
if not set(kwargs.keys()).issubset({'attenuation', 'background', 'sensitivity','norm','csm'}):
raise ValueError("Invalid keyword arguments, only 'attenuation', 'background', 'sensitivity','norm','csm' are allowed. Assumes lists of same length.")

n_batch = image.shape[0]
if self.sirf_image_shape[0] == 1:
# if 2D
if len(image.shape) == 3:
# if 2D and no channel then add singleton
# add singleton for SIRF
image = image.unsqueeze(1)
else:
# if 3D
if len(image.shape) == 4:
# if 3D and no channel then add singleton
image = image.unsqueeze(1)
n_channel = image.shape[1]

# This looks horrible, but PyTorch should be able to trace.
batch_images = []
for batch in range(n_batch):
channel_images = []
for channel in range(n_channel):
torch_image = image[batch, channel]
torch_image = torch_image.view(self.sirf_image_shape)
# if there are kwargs then raise error
if kwargs:
raise NotImplementedError("Keyword arguments are not implemented yet.")
channel_images.append(_AcquisitionModelForward.apply(torch_image, self.sirf_image_template, self.acq_mdl))
batch_images.append(torch.stack(channel_images, dim=0))
# [batch, channel, *forward_projected.shape]
return torch.stack(batch_images, dim=0)
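
# Illustrative usage sketch (assumed names, not part of this diff): with a SIRF
# acquisition model `acq_mdl` already set up on templates `image_t`/`meas_t`,
#   projector = AcquisitionModelForward(acq_mdl, image_t, meas_t, device="cpu")
#   x = torch.tensor(image_t.as_array()).unsqueeze(0).requires_grad_(True)
#   y = projector(x)          # [batch, channel, *measurement shape]
#   y.sum().backward()        # gradient flows back through acq_mdl.backward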

if __name__ == '__main__':
import os
import numpy
# Import the PET reconstruction engine
import sirf.STIR as pet
# Set the verbosity
#pet.set_verbosity(1)
pet.set_verbosity(1)
# Store temporary sinograms in RAM
pet.AcquisitionData.set_storage_scheme("memory")
import sirf
@@ -155,7 +190,7 @@ def forward(self, torch_image):


print("Comparing the forward projected")
torch_measurements = _AcquisitionModelForward.apply(torch_image, torch_measurements_template, sirf_image_template, sirf_acq_mdl)
torch_measurements = _AcquisitionModelForward.apply(torch_image, sirf_image_template, sirf_acq_mdl)
print("Sum of torch: ", torch_measurements.detach().cpu().numpy().sum(), "Sum of sirf: ", sirf_measurements.sum(), "Sum of Differences: ", numpy.abs(torch_measurements.detach().cpu().numpy() - sirf_measurements.as_array()).sum())
print("Comparing the backward of forward projected")
# TO TEST THAT WE ARE BACKWARDING CORRECTLY RETAIN GRAD AND SUM THEN BACKWARD
@@ -173,7 +208,7 @@ def forward(self, torch_image):
bp_sirf_measurements = acq_mdl.backward(sirf_measurements)

print("Comparing the backward projected")
torch_image = _AcquisitionModelBackward.apply(torch_measurements, torch_image_template, sirf_measurements_template, sirf_acq_mdl)
torch_image = _AcquisitionModelBackward.apply(torch_measurements, sirf_measurements_template, sirf_acq_mdl)
print("Sum of torch: ", torch_image.detach().cpu().numpy().sum(), "Sum of sirf: ", bp_sirf_measurements.sum(), \
"Sum of Differences: ", numpy.abs(torch_image.detach().cpu().numpy() - bp_sirf_measurements.as_array()).sum())
torch_measurements.retain_grad()
@@ -184,11 +219,14 @@ def forward(self, torch_image):
numpy.abs(torch_measurements.grad.detach().cpu().numpy() - comparison.as_array()).sum())


""" # form objective function
# form objective function
print("Objective Function Test")
obj_fun = pet.make_Poisson_loglikelihood(sirf_measurements)
obj_fun.set_acquisition_model(acq_mdl)
print("Made poisson loglikelihood")
obj_fun.set_acquisition_model(sirf_acq_mdl)
print("Set acquisition model")
obj_fun.set_up(sirf_image_template)
print("Set up")

torch_image = torch.tensor(image.as_array(), requires_grad=True).to(device)
sirf_image_template = image.get_uniform_copy(0)
@@ -200,8 +238,8 @@ def forward(self, torch_image):
# grad sum(f(x)) = f'(x)
comparison = obj_fun.gradient(sirf_image_template)
print("Sum of torch: ", torch_image.grad.sum(), "Sum of sirf: ", comparison.sum(), "Sum of Differences: ", (torch_image.grad.detach().cpu().numpy() - comparison.as_array()).sum())
"""
from sirf.Gadgetron import *

""" from sirf.Gadgetron import *
print("2D MRI TEST")
data_path = examples_data_path('MR')
AcquisitionData.set_storage_scheme('memory')
@@ -227,7 +265,7 @@ def forward(self, torch_image):

am = AcquisitionModel(processed_data, complex_images)
am.set_coil_sensitivity_maps(csms)
print(am)
print(am) """

# Objective function
# Acquisition model