Adding PyTorch functionality to SIRF #1305
Open: Imraj-Singh wants to merge 64 commits into SyneRBI:master from Imraj-Singh:torch
64 commits:
789f19a (Imraj-Singh): Perhaps the wrapper here?
e7659e9 (Imraj-Singh): Prototyping a solution to the problem of using a dataset with an acqu…
cd0c207 (Imraj-Singh): Adding a small test and cleaning things up, getting a segfault with t…
b97ddde (Imraj-Singh): adding a readme, and starting to test with Gadgetron (to no avail)
9df22d5 (Imraj-Singh): Update README.md
e4ce261 (Imraj-Singh): Small changes to be tested [ci skip]
ca50798 (Imraj-Singh): Some more changes to mainly AcquisitionModelForward [ci skip]
0380fa5 (Imraj-Singh): Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torc…
9866a33 (Imraj-Singh): first test optimising pet objectives in torch, and using the torch pn…
f81b974 (Imraj-Singh): objective function working with pet [ci skip]
01084cc (Imraj-Singh): removing some redundancy
9b130b2 (Imraj-Singh): update name [ci skip]
b8cbb64 (Imraj-Singh): acq_mdl -> operator, removing kwargs, trailing _ for inplace
dc20472 (Imraj-Singh): tidying and making more generic
b0bd284 (Imraj-Singh): updated read me [ci skip]
409544f (Imraj-Singh): Objective gradient scaffold [ci skip]
9c31b2d (Imraj-Singh): adding better way of specifying the jacobian adjoint [ci skip]
33df7ef (Imraj-Singh): Updating naming, added adjoint operator to sirf as discussed in meeti…
4f67188 (Imraj-Singh): updating cmake removing duplicate [ci skip]
fe75bfe (Imraj-Singh): Circular import nonsense
7930ac6 (Imraj-Singh): adjoint method doesn't check if linear at the moment, adding that to …
7378a47 (Imraj-Singh): sirf torch error checking updates
8c4f110 (Imraj-Singh): starting some gradchecks, failing with sirf obj func
2d25953 (Imraj-Singh): nondet_tol needs to be high for obj func for some reason... [ci skip]
e7e2dc8 (Imraj-Singh): testing obj func grad
60bd4b5 (Imraj-Singh): refactoring to abstract common functionality
f72a561 (Imraj-Singh): added more gradchecks for hessian and also wrapped acquisition model …
9aca992 (Imraj-Singh): 3D doesn't work as nondet tol too low, issues with PNLL numerical Jac…
2b9cea6 (Imraj-Singh): Slight change in naming, more abstracting
d097e17 (Imraj-Singh): Use cases nearly done [ci skip]
70d6e3f (Imraj-Singh): Apply suggestions from code review
262e11f (Imraj-Singh): added documentation and pet use cases
2c68b87 (Imraj-Singh): trying some pytests [ci skip]
70d147b (Imraj-Singh): removing clamping [ci skip]
dbe05e2 (Imraj-Singh): Making the tols even more lax for obj_fun... [ci skip]
cdde96f (Imraj-Singh): updated readme mainly [ci skip]
b0ab90b (Imraj-Singh): Apply suggestions from Kris' code review
21cdc4f (Imraj-Singh): added a todo
26bd325 (Imraj-Singh): Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torch
576fbbc (Imraj-Singh): readme changes [ci skip]
c7a359a (Imraj-Singh): Update src/torch/SIRF_torch.py [ci skip]
3220514 (Imraj-Singh): removing test cases [ci skip]
31de795 (Imraj-Singh): more readme changes [ci skip]
909b1de (Imraj-Singh): MR grad checks working, getting sinister segfault for obj func [ci skip]
142a338 (Imraj-Singh): changing the init to allow for sirf.blah [ci skip]
97e2439 (Imraj-Singh): calling for value of obj_func and removing negative [ci skip]
b2b603d (Imraj-Singh): updating the README to explain dimensionality and remove negativity r…
69470b5 (Imraj-Singh): Merge branch 'SyneRBI:master' into torch
a788021 (Imraj-Singh): trying (probably failing) to make the imports better using that __all…
d1fa40b (Imraj-Singh): updating the naming to be less verbose [ci skip]
0dd3770 (Imraj-Singh): moving gradcheck and use_cases so not to get circular imports... [ci …
1049818 (Imraj-Singh): minimal varnet [ci skip]
a6f2282 (Imraj-Singh): Apply suggestions from code review [ci skip]
bef9f19 (Imraj-Singh): take away nonsense from __init__ and update README [ci skip]
15179aa (Imraj-Singh): Updating use cases [ci skip]
1893fa6 (Imraj-Singh): requires_grad=False for tensors returned by backward, changes.md upda…
362875b (Imraj-Singh): add image to gd use_case [ci skip]
910fa2c (casperdcl): misc review fixups
0a41e5b (Imraj-Singh): Add once_differentiable decorator [ci skip]
7d27c80 (Imraj-Singh): Merge branch 'SyneRBI:master' into torch
0c91764 (Imraj-Singh): codacity changes [ci skip]
c0adc9c (Imraj-Singh): Merge branch 'torch' of https://github.com/Imraj-Singh/SIRF into torch
a7348ce (Imraj-Singh): codacy changes [ci skip]
89204a8 (Imraj-Singh): run ci things
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@

from sirf.torch import Operator, ObjectiveFunction, ObjectiveFunctionGradient, sirf_to_torch
import sirf.STIR as pet
pet.set_verbosity(1)
pet.AcquisitionData.set_storage_scheme("memory")
import sirf
msg = sirf.STIR.MessageRedirector(info=None, warn=None, errr=None)
from sirf.Utilities import examples_data_path
import matplotlib.pyplot as plt
import torch

# set up a 2D PET acquisition model on the thorax_single_slice example data
pet_data_path = examples_data_path('PET')
pet_2d_raw_data_file = pet.existing_filepath(pet_data_path, 'thorax_single_slice/template_sinogram.hs')
pet_2d_acq_data = pet.AcquisitionData(pet_2d_raw_data_file)
pet_2d_init_image_file = pet.existing_filepath(pet_data_path, 'thorax_single_slice/emission.hv')
pet_2d_image_data = pet.ImageData(pet_2d_init_image_file)
pet_2d_acq_model = pet.AcquisitionModelUsingParallelproj()
data_processor = pet.TruncateToCylinderProcessor()
pet_2d_acq_model.set_image_data_processor(data_processor)
pet_2d_acq_model.set_up(pet_2d_acq_data, pet_2d_image_data)
# inverse sensitivity image: 1 / backprojection of ones
sens_img = pet_2d_acq_model.backward(pet_2d_acq_data.get_uniform_copy(1.0))
inv_sens_img = sens_img.power(-1)
data_processor.apply(inv_sens_img)
# simulate noiseless data with a flat background of 10.5
pet_2d_acq_data = pet_2d_acq_model.forward(pet_2d_image_data) + 10.5
pet_2d_acq_model.set_background_term(pet_2d_acq_data.get_uniform_copy(10.5))
pet_2d_acq_model.set_up(pet_2d_acq_data, pet_2d_image_data)

# wrap the Poisson log-likelihood gradient as a torch-compatible module
acq_data = pet_2d_acq_data
acq_model = pet_2d_acq_model
image_data = pet_2d_image_data
obj_fun = pet.make_Poisson_loglikelihood(acq_data)
obj_fun.set_acquisition_model(acq_model)
obj_fun.set_up(image_data)
objfuncgrad = ObjectiveFunctionGradient(obj_fun, image_data)

dev = "cuda" if torch.cuda.is_available() else "cpu"

# small CNN used as the learned component of the variational network
cnn = torch.nn.Sequential(
    torch.nn.Conv2d(1, 5, 5, padding="same", bias=False),
    torch.nn.Conv2d(5, 5, 5, padding="same", bias=False),
    torch.nn.PReLU(device=dev),
    torch.nn.Conv2d(5, 5, 5, padding="same", bias=False),
    torch.nn.Conv2d(5, 5, 5, padding="same", bias=False),
    torch.nn.PReLU(device=dev),
    torch.nn.Conv2d(5, 1, 1, padding="same", bias=False),
).to(dev)


class UnrolledOSEMVarNet(torch.nn.Module):
    def __init__(
        self,
        objective_function_gradient: sirf.torch.ObjectiveFunctionGradient,
        inv_sens_img: torch.Tensor,
        convnet: torch.nn.Module,
        device: str,
    ) -> None:
        """Unrolled OSEM Variational Network with 2 blocks

        Parameters
        ----------
        objective_function_gradient : sirf.torch.ObjectiveFunctionGradient
            gradient of the Poisson logL objective function
            that we use for the OSEM-style updates
        inv_sens_img : torch.Tensor
            inverse sensitivity image used to scale the
            objective function gradient in each update
        convnet : torch.nn.Module
            a (convolutional) neural network that maps a minibatch tensor
            of shape [1,1,spatial_dimensions] onto a minibatch tensor of the same shape
        device : str
            device used for the calculations
        """
        super().__init__()

        # gradient of the Poisson logL objective, used for the OSEM-style update
        self.objective_function_gradient = objective_function_gradient
        self._inv_sens_img = inv_sens_img

        self._convnet = convnet
        self._relu = torch.nn.ReLU()

        # trainable parameters for the fusion of the OSEM update and the CNN output in the two blocks
        # we start with a weight of 10 for the fusion
        # a good starting value depends on the scale of the input image
        self._fusion_weight0 = torch.nn.Parameter(
            10 * torch.ones(1, device=device, dtype=torch.float32)
        )
        self._fusion_weight1 = torch.nn.Parameter(
            10 * torch.ones(1, device=device, dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self._relu(
            self._fusion_weight0 * self._convnet(x) + self.objective_function_gradient(x) * x * self._inv_sens_img
        )
        x2 = self._relu(
            self._fusion_weight1 * self._convnet(x1) + self.objective_function_gradient(x1) * x1 * self._inv_sens_img
        )
        return x2


inv_sens_img = torch.tensor(inv_sens_img.as_array()).unsqueeze(0).to(dev)
varnet = UnrolledOSEMVarNet(objfuncgrad, inv_sens_img, cnn, dev)
varnet.to(dev)

# use seed for reproducibility
torch.manual_seed(42)

torch_image = torch.tensor(image_data.as_array()).unsqueeze(0).to(dev)
# add Gaussian noise to the image
torch_input = torch_image + 10 * torch.randn_like(torch_image)
# clip the input to be non-negative
torch_input = torch.clamp(torch_input, 0).detach()
# plot the noise-free and the noisy image
plt.imshow(torch_image.detach().cpu().numpy()[0, 0])
plt.colorbar()
plt.savefig("pet_image.png")
plt.close()
plt.imshow(torch_input.detach().cpu().numpy()[0, 0])
plt.colorbar()
plt.savefig("pet_noisy_image.png")
plt.close()

# set up the optimizer and train by supervised MSE against the noise-free image
optimizer = torch.optim.Adam(varnet.parameters(), lr=1e-4)
relu = torch.nn.ReLU()
loss = torch.nn.MSELoss()
for i in range(100):
    optimizer.zero_grad()
    loss_val = loss(relu(varnet(torch_input)), torch_image)
    loss_val.backward()
    optimizer.step()
    print("Iteration: ", i, "Loss: ", loss_val.item())

out = relu(varnet(torch_input))
plt.imshow(out.detach().cpu().numpy()[0, 0])
plt.savefig("pet_varnet.png")
@gschramm this is a minimal reproduction of your varnet. I am still getting negative values in my projected input. I think this is expected, as the ReLU only ensures non-negativity in the forward pass.
Given the loss function, there can be negative elements in the gradient even if the ReLU is enforcing non-negativity in the forward pass.
@KrisThielemans this may mean that we require negative values in the input.
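A minimal standalone sketch of this point (not part of the PR): the ReLU keeps the forward output non-negative, but the gradient backpropagated to its input can still contain negative entries.

```python
import torch

# forward pass: ReLU output is non-negative by construction
x = torch.tensor([0.5, 2.0, 3.0], requires_grad=True)
y = torch.relu(x)
target = torch.ones(3)

# backward pass: the MSE gradient 2*(y - target)/n flows through the ReLU
# unchanged wherever x > 0, so it can be negative
loss = torch.nn.functional.mse_loss(y, target)
loss.backward()

print(y)       # tensor([0.5000, 2.0000, 3.0000], ...), all non-negative
print(x.grad)  # tensor([-0.3333, 0.6667, 1.3333]), negative entry where y < target
```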
Related to UCL/STIR#1477
Yes, of course: the gradient of the objective function will generally have both positive and negative elements.
Fixing the STIR issue will take me a while though. Help welcome :-)
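For reference, this follows from the standard Poisson log-likelihood algebra (notation mine, not from the PR): with forward model $A$, background $b$ and data $y$,

$$\nabla L(x) = A^\top\!\left(\frac{y}{Ax + b}\right) - A^\top \mathbf{1},$$

so wherever the backprojected measured-to-expected ratio falls below the sensitivity $A^\top \mathbf{1}$, the gradient entry is negative.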
In that case we can't backpropagate through the gradient of the objective... right? Should we remove the varnet example, or throw a warning/error?
I would add a comment saying that this needs the STIR issue to be resolved, but leave it in, as it's instructive.