Adding PyTorch functionality to SIRF #1305

Open
Imraj-Singh wants to merge 64 commits into SyneRBI:master from Imraj-Singh:torch
Changes below are shown from 1 commit.

Commits (64)
789f19a
Perhaps the wrapper here?
Imraj-Singh Feb 15, 2025
e7659e9
Prototyping a solution to the problem of using a dataset with an acqu…
Imraj-Singh Feb 17, 2025
cd0c207
Adding a small test and cleaning things up, getting a segfault with t…
Imraj-Singh Feb 18, 2025
b97ddde
adding a readme, and starting to test with Gadgetron (to no avail)
Imraj-Singh Feb 18, 2025
9df22d5
Update README.md
Imraj-Singh Feb 18, 2025
e4ce261
Small changes to be tested [ci skip]
Imraj-Singh Feb 19, 2025
ca50798
Some more changes to mainly AcquisitionModelForward [ci skip]
Imraj-Singh Feb 19, 2025
0380fa5
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torc…
Imraj-Singh Feb 19, 2025
9866a33
first test optimising pet objectives in torch, and using the torch pn…
Imraj-Singh Feb 21, 2025
f81b974
objective function working with pet [ci skip]
Imraj-Singh Feb 21, 2025
01084cc
removing some redundancy
Imraj-Singh Feb 21, 2025
9b130b2
update name [ci skip]
Imraj-Singh Feb 21, 2025
b8cbb64
acq_mdl -> operator, removing kwargs, trailing _ for inplace
Imraj-Singh Feb 26, 2025
dc20472
tidying and making more generic
Imraj-Singh Feb 26, 2025
b0bd284
updated read me [ci skip]
Imraj-Singh Feb 26, 2025
409544f
Objective gradient scaffold [ci skip]
Imraj-Singh Feb 26, 2025
9c31b2d
adding better way of specifying the jacobian adjoint [ci skip]
Imraj-Singh Feb 27, 2025
33df7ef
Updating naming, added adjoint operator to sirf as discussed in meeti…
Imraj-Singh Feb 27, 2025
4f67188
updating cmake removing duplicate [ci skip]
Imraj-Singh Feb 27, 2025
fe75bfe
Circular import nonsense
Imraj-Singh Feb 27, 2025
7930ac6
adjoint method doesn't check if linear at the moment, adding that to …
Imraj-Singh Feb 28, 2025
7378a47
sirf torch error checking updates
Imraj-Singh Feb 28, 2025
8c4f110
starting some gradchecks, failing with sirf obj func
Imraj-Singh Feb 28, 2025
2d25953
nondet_tol needs to be high for obj func for some reason... [ci skip]
Imraj-Singh Feb 28, 2025
e7e2dc8
testing obj func grad
Imraj-Singh Mar 2, 2025
60bd4b5
refactoring to abstract common functionality
Imraj-Singh Mar 2, 2025
f72a561
added more gradchecks for hessian and also wrapped acquisition model …
Imraj-Singh Mar 2, 2025
9aca992
3D doesn't work as nondet tol too low, issues with PNLL numerical Jac…
Imraj-Singh Mar 2, 2025
2b9cea6
Slight change in naming, more abstracting
Imraj-Singh Mar 2, 2025
d097e17
Use cases nearly done [ci skip]
Imraj-Singh Mar 3, 2025
70d6e3f
Apply suggestions from code review
Imraj-Singh Mar 3, 2025
262e11f
added documentation and pet use cases
Imraj-Singh Mar 3, 2025
2c68b87
trying some pytests [ci skip]
Imraj-Singh Mar 3, 2025
70d147b
removing clamping [ci skip]
Imraj-Singh Mar 3, 2025
dbe05e2
Making the tols even more lax for obj_fun... [ci skip]
Imraj-Singh Mar 3, 2025
cdde96f
updated readme mainly [ci skip]
Imraj-Singh Mar 4, 2025
b0ab90b
Apply suggestions from Kris' code review
Imraj-Singh Mar 6, 2025
21cdc4f
added a todo
Imraj-Singh Mar 6, 2025
26bd325
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torch
Imraj-Singh Mar 6, 2025
576fbbc
readme changes [ci skip]
Imraj-Singh Mar 6, 2025
c7a359a
Update src/torch/SIRF_torch.py [ci skip]
Imraj-Singh Mar 6, 2025
3220514
removing test cases [ci skip]
Imraj-Singh Mar 6, 2025
31de795
more readme changes [ci skip]
Imraj-Singh Mar 7, 2025
909b1de
MR grad checks working, getting sinister segfault for obj func [ci skip]
Imraj-Singh Mar 7, 2025
142a338
changing the init to allow for sirf.blah [ci skip]
Imraj-Singh Mar 7, 2025
97e2439
calling for value of obj_func and removing negative [ci skip]
Imraj-Singh Mar 7, 2025
b2b603d
updating the README to explain dimensionality and remove negativity r…
Imraj-Singh Mar 7, 2025
69470b5
Merge branch 'SyneRBI:master' into torch
Imraj-Singh Mar 7, 2025
a788021
trying (probably failing) to make the imports better using that __all…
Imraj-Singh Mar 7, 2025
d1fa40b
updating the naming to be less verbose [ci skip]
Imraj-Singh Mar 7, 2025
0dd3770
moving gradcheck and use_cases so not to get circular imports... [ci …
Imraj-Singh Mar 7, 2025
1049818
minimal varnet [ci skip]
Imraj-Singh Mar 7, 2025
a6f2282
Apply suggestions from code review [ci skip]
Imraj-Singh Mar 9, 2025
bef9f19
take away nonsense from __init__ and update README [ci skip]
Imraj-Singh Mar 9, 2025
15179aa
Updating use cases [ci skip]
Imraj-Singh Mar 12, 2025
1893fa6
requires_grad=False for tensors returned by backward, changes.md upda…
Imraj-Singh Mar 13, 2025
362875b
add image to gd use_case [ci skip]
Imraj-Singh Mar 13, 2025
910fa2c
misc review fixups
casperdcl Mar 27, 2025
0a41e5b
Add once_differentiable decorator [ci skip]
Imraj-Singh Apr 6, 2025
7d27c80
Merge branch 'SyneRBI:master' into torch
Imraj-Singh Apr 6, 2025
0c91764
codacity changes [ci skip]
Imraj-Singh Apr 6, 2025
c0adc9c
Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torch
Imraj-Singh Apr 6, 2025
a7348ce
codacy changes [ci skip]
Imraj-Singh Apr 6, 2025
89204a8
run ci things
Imraj-Singh Apr 6, 2025
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -160,3 +160,4 @@ else()
endif()

ADD_SUBDIRECTORY(common)
+ADD_SUBDIRECTORY(torch)
2 changes: 1 addition & 1 deletion src/common/CMakeLists.txt
@@ -73,7 +73,7 @@ if (BUILD_PYTHON)
INSTALL(TARGETS ${SWIG_MODULE_pysirf_REAL_NAME} DESTINATION "${PYTHON_DEST}/sirf")
INSTALL(FILES "${CMAKE_CURRENT_BINARY_DIR}/pysirf.py" DESTINATION "${PYTHON_DEST}/sirf")
#file(GLOB PythonFiles "${CMAKE_CURRENT_LIST_DIR}/*.py")
-set(PythonFiles select_module.py show_image.py SIRF.py Utilities.py SIRF_torch.py)
+set(PythonFiles select_module.py show_image.py SIRF.py Utilities.py)
INSTALL(FILES ${PythonFiles} DESTINATION "${PYTHON_DEST}/sirf")
endif()

6 changes: 0 additions & 6 deletions src/common/SIRF_torch.py

This file was deleted.

23 changes: 23 additions & 0 deletions src/torch/CMakeLists.txt
@@ -0,0 +1,23 @@
#========================================================================
# Author: Evgueni Ovtchinnikov
# Copyright 2020 University College London
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#=========================================================================


if (BUILD_PYTHON)
set(PythonFiles SIRF_torch.py)
INSTALL(FILES ${PythonFiles} DESTINATION "${PYTHON_DEST}/sirf")
endif()
123 changes: 123 additions & 0 deletions src/torch/SIRF_torch.py
@@ -0,0 +1,123 @@
try:
    import torch
except ImportError as exc:
    raise ImportError('Failed to import torch. Please install PyTorch first.') from exc

# based on
# https://github.com/educating-dip/pet_deep_image_prior/blob/main/src/deep_image_prior/torch_wrapper.py


class _objectiveFunctionModule3D(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, image_template, sirf_obj):
        ctx.device = x.device
        ctx.sirf_obj = sirf_obj
        ctx.image_template = image_template
        # strip the batch dimension and move the data into a SIRF image
        x_np = x.detach().cpu().numpy().squeeze()
        ctx.x = ctx.image_template.fill(x_np)
        value_np = ctx.sirf_obj.get_value(ctx.x)
        return torch.tensor(value_np).to(ctx.device)

    @staticmethod
    def backward(ctx, in_grad):
        # chain rule: scale the SIRF gradient by the incoming scalar gradient
        grads_np = ctx.sirf_obj.get_gradient(ctx.x).as_array()
        grads = torch.from_numpy(grads_np).to(ctx.device) * in_grad
        # one gradient per forward input: x, image_template, sirf_obj
        return grads.unsqueeze(dim=0), None, None
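
# Illustrative usage sketch (assumptions: `obj_fun` is a configured SIRF
# objective function, e.g. from sirf.STIR, and `image` a matching ImageData;
# neither is defined in this file):
#
#     x = torch.from_numpy(image.as_array()).unsqueeze(0).requires_grad_()
#     value = _objectiveFunctionModule3D.apply(x, image, obj_fun)
#     value.backward()  # x.grad now holds the SIRF objective gradient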

class _AcquisitionModelForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, image_template, data_template, sirf_acq_model):
        ctx.sirf_acq_model = sirf_acq_model
        ctx.image_template = image_template
        ctx.data_template = data_template
        # move the torch tensor into a SIRF image and project it
        x_np = x.detach().cpu().numpy()
        x_sirf = ctx.image_template.fill(x_np[None])
        proj_data_np = ctx.sirf_acq_model.forward(x_sirf).as_array()
        return torch.from_numpy(proj_data_np).to(x.device)

    @staticmethod
    def backward(ctx, data):
        # adjoint of the acquisition model applied to the incoming gradient
        data_np = data.detach().cpu().numpy()
        data_sirf = ctx.data_template.fill(data_np)
        grads_np = ctx.sirf_acq_model.backward(data_sirf).as_array()
        grads = torch.from_numpy(grads_np).to(data.device)
        # one gradient per forward input: x, image_template, data_template,
        # sirf_acq_model
        return grads, None, None, None
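
# A minimal nn.Module wrapper sketch (illustrative; `acq_model`,
# `image_template` and `data_template` are assumed to be configured SIRF
# objects), showing how the Function above could be exposed to a network:
#
#     class AcquisitionModelModule(torch.nn.Module):
#         def __init__(self, acq_model, image_template, data_template):
#             super().__init__()
#             self.acq_model = acq_model
#             self.image_template = image_template
#             self.data_template = data_template
#
#         def forward(self, x):
#             return _AcquisitionModelForward.apply(
#                 x, self.image_template, self.data_template, self.acq_model)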

class AcquisitionFactory:
    """I am not too sure this is a good idea...
    The idea here is the following:
    - We have a SIRF object that is the acquisition model.
    - This acquisition model has the same geometry, but may have additional
      components that vary between samples of the dataset.
    We need to be able to choose the correct acquisition model wrapper based
    on the data. There is also the assumption that we have the same data
    components for every sample in the dataset, but this is a start."""
    def __init__(self, acq_model, sirf_data_in_template, sirf_data_out_template, device, *args):
        # find a way of checking if the input or output is an image,
        # then throw a warning and choose whether the wrapper should be
        # forward or backward
        if len(args) > 0:
            # determine the modality of the data
            self.modality = args[0]
        if len(args) > 1:
            # determine the additional components of the data
            self.second_arg = args[1]
        # ... and so on
        # then begin to assign the correct acquisition model wrapper



class DatasetAcquisitionModels(torch.nn.Module):
    """Class that uses the acquisition model wrapper to separate the geometric
    forward model from components that are data-dependent. It is meant for
    data of dimensions [batch, channel, *sirf_template.shape], where the
    SIRF template is either the image or the measurement template."""
    def __init__(self, acq_model, sirf_data_in_template, sirf_data_out_template, device):
        super().__init__()
        # PET - multiplicative and additive factors
        # MR - coil maps
        # if acq_model is a SIRF object from STIR:
        #     if acq_model has scatter: scatter = True, else scatter = False
        #     if acq_model has attn: attn = True, else attn = False
        #     etc. etc.
        #     self.acq_model = AcquisitionFactory(acq_model, sirf_image_template,
        #         sirf_measurement_template, device, 'PET', scatter, attn)
        # elif acq_model is a SIRF object from Gadgetron:
        #     if acq_model has coil_sens: coil_sens = True, else coil_sens = False
        #     etc. etc.
        #     self.acq_model = AcquisitionFactory(acq_model, sirf_image_template,
        #         sirf_measurement_template, device, 'MR', coil_sens)
        # else:
        #     raise ValueError('acq_model must be a SIRF object from STIR or Gadgetron')
        self.acq_model = acq_model
        self.sirf_data_in_template = sirf_data_in_template
        self.sirf_data_out_template = sirf_data_out_template
    # forward takes a dynamic number of additional per-sample arguments
    def forward(self, data_in, *args):
        # check data is of the form [batch, channel, *sirf_data_in_template.shape]
        if data_in.shape[2:] != self.sirf_data_in_template.shape:
            raise ValueError(
                'Data must be of the form [batch, channel, *sirf_data_in_template.shape]')
        data_out = torch.zeros(data_in.shape[0], data_in.shape[1],
                               *self.sirf_data_out_template.shape)
        for i in range(data_in.shape[0]):
            for j in range(data_in.shape[1]):
                # per-sample extras (e.g. attenuation factors, coil maps)
                additional_data = [arg[i, j] for arg in args]
                data_out[i, j] = self.acq_model(data_in[i, j], *additional_data)
        return data_out
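
# Batched usage sketch (illustrative): `wrapped_acq_model` is assumed to be a
# callable mapping one [*image_template.shape] tensor to one
# [*sinogram_template.shape] tensor (e.g. a wrapper around
# _AcquisitionModelForward), with `image_template` and `sinogram_template`
# configured SIRF templates:
#
#     model = DatasetAcquisitionModels(wrapped_acq_model, image_template,
#                                      sinogram_template, device='cpu')
#     x = torch.zeros(4, 1, *image_template.shape)  # batch of 4, 1 channel
#     y = model(x)  # shape [4, 1, *sinogram_template.shape]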



# TODO:
# - Objective function
# - Acquisition model
# - With torch functionals and comparison
# - Jacobian vector for STIR
# - similar thing for Gadgetron
# - adjointness test
# - gradcheck, check traceability
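
# Sketch of the adjointness test listed above (illustrative; `acq_model`,
# `image_template` and `data_template` are assumed to be configured SIRF
# objects). For a linear operator A, <A x, y> should equal <x, A^T y>:
#
#     import numpy
#     x = image_template.clone()
#     x.fill(numpy.random.rand(*image_template.shape).astype(numpy.float32))
#     y = data_template.clone()
#     y.fill(numpy.random.rand(*data_template.shape).astype(numpy.float32))
#     lhs = numpy.vdot(acq_model.forward(x).as_array(), y.as_array())
#     rhs = numpy.vdot(x.as_array(), acq_model.backward(y).as_array())
#     assert abs(lhs - rhs) <= 1e-3 * abs(lhs)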