Adding Adam optimiser #1460
Merged
Changes from all commits (7, all by MichaelClerx):
- af4ba7f Adding Adam optimiser
- fd6afc1 Added Adam example
- 0695d82 Merge branch 'master' into 1105-adam
- fc1e402 Tweaks to Adam after review by DavAug
- 4f6c445 Fix to Adam test
- f8ebc24 Improved initialisation for irpropmin
- ac1e520 Tweak
New documentation file (@@ -0,0 +1,8 @@):

*********************************
Adam (adaptive moment estimation)
*********************************

.. currentmodule:: pints

.. autoclass:: Adam
New file (@@ -0,0 +1,153 @@):

#
# Adam optimiser.
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import pints

import numpy as np


class Adam(pints.Optimiser):
    """
    Adam optimiser (adaptive moment estimation), as described in [1]_.

    This method is a variation on gradient descent that maintains two
    "moments", allowing it to overshoot and go against the gradient for a
    short time. This property can make it more robust against noisy
    gradients. Full pseudo-code is given in [1]_ (Algorithm 1).

    This implementation uses a fixed step size, set as ``min(sigma0)``. Note
    that the adaptivity in this method comes from the changing moments, not
    the step size.

    References
    ----------
    .. [1] Adam: A method for stochastic optimization.
           Kingma and Ba, 2017, arXiv (version v9).
           https://doi.org/10.48550/arXiv.1412.6980
    """

    def __init__(self, x0, sigma0=0.1, boundaries=None):
        super().__init__(x0, sigma0, boundaries)

        # Set optimiser state
        self._running = False
        self._ready_for_tell = False

        # Best solution found
        self._x_best = self._x0
        self._f_best = np.inf

        # Current point, score, and gradient
        self._current = self._x0
        self._current_f = np.inf
        self._current_df = None

        # Proposed next point (read-only, so can be passed to user)
        self._proposed = self._x0
        self._proposed.setflags(write=False)

        # Moment vectors
        self._m = np.zeros(self._x0.shape)
        self._v = np.zeros(self._x0.shape)

        # Exponential decay rates for the moment estimates
        self._b1 = 0.9    # 0 < b1 <= 1
        self._b2 = 0.999  # 0 < b2 <= 1

        # Step size
        self._alpha = np.min(self._sigma0)

        # Small number added to avoid divide-by-zero
        self._eps = 1e-8

        # Powers of decay rates
        self._b1t = 1
        self._b2t = 1

    def ask(self):
        """ See :meth:`Optimiser.ask()`. """

        # Running, and ready for tell now
        self._ready_for_tell = True
        self._running = True

        # Return proposed points (just the one)
        return [self._proposed]

    def f_best(self):
        """ See :meth:`Optimiser.f_best()`. """
        return self._f_best

    def f_guessed(self):
        """ See :meth:`Optimiser.f_guessed()`. """
        return self._current_f

    def name(self):
        """ See :meth:`Optimiser.name()`. """
        return 'Adam'

    def needs_sensitivities(self):
        """ See :meth:`Optimiser.needs_sensitivities()`. """
        return True

    def n_hyper_parameters(self):
        """ See :meth:`pints.TunableMethod.n_hyper_parameters()`. """
        return 0

    def running(self):
        """ See :meth:`Optimiser.running()`. """
        return self._running

    def tell(self, reply):
        """ See :meth:`Optimiser.tell()`. """

        # Check ask-tell pattern
        if not self._ready_for_tell:
            raise Exception('ask() not called before tell()')
        self._ready_for_tell = False

        # Unpack reply
        fx, dfx = reply[0]

        # Update current point
        self._current = self._proposed
        self._current_f = fx
        self._current_df = dfx

        # Update b1^t and b2^t
        self._b1t *= self._b1
        self._b2t *= self._b2

        # "Update biased first moment estimate"
        self._m = self._b1 * self._m + (1 - self._b1) * dfx

        # "Update biased second raw moment estimate"
        self._v = self._b2 * self._v + (1 - self._b2) * dfx**2

        # "Compute bias-corrected first moment estimate"
        m = self._m / (1 - self._b1t)

        # "Compute bias-corrected second raw moment estimate"
        v = self._v / (1 - self._b2t)

        # Take step
        self._proposed = (
            self._current - self._alpha * m / (np.sqrt(v) + self._eps))

        # Mark the new proposal as read-only, as with x0 in the constructor
        self._proposed.setflags(write=False)

        # Update x_best and f_best
        if self._f_best > fx:
            self._f_best = fx
            self._x_best = self._current

    def x_best(self):
        """ See :meth:`Optimiser.x_best()`. """
        return self._x_best

    def x_guessed(self):
        """ See :meth:`Optimiser.x_guessed()`. """
        return self._current
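The update rule implemented in ``tell()`` above can also be run standalone. The sketch below re-implements the same moment updates and bias corrections in plain NumPy; the function name and hyperparameter defaults are illustrative, not part of the PINTS API:

```python
import numpy as np


def adam_minimise(grad, x0, alpha=0.02, b1=0.9, b2=0.999, eps=1e-8,
                  n_iters=2000):
    """Minimise a function via Adam, given a callable for its gradient."""
    x = np.asarray(x0, dtype=float)
    m = np.zeros_like(x)  # first moment (running mean of gradients)
    v = np.zeros_like(x)  # second raw moment (running mean of squared grads)
    b1t = b2t = 1.0       # running powers b1^t, b2^t for bias correction
    for _ in range(n_iters):
        g = grad(x)
        b1t *= b1
        b2t *= b2
        m = b1 * m + (1 - b1) * g        # biased first moment estimate
        v = b2 * v + (1 - b2) * g ** 2   # biased second raw moment estimate
        m_hat = m / (1 - b1t)            # bias-corrected first moment
        v_hat = v / (1 - b2t)            # bias-corrected second raw moment
        x = x - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return x


# Minimise f(x) = x[0]**2 + x[1]**2, whose gradient is 2 * x
x_min = adam_minimise(lambda x: 2 * x, [0.1, 0.1])
```

Without the bias corrections, ``m_hat`` and ``v_hat`` would be strongly biased towards zero during early iterations, since both moment vectors start at zero.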
New test file (@@ -0,0 +1,118 @@):

#!/usr/bin/env python3
#
# Tests the API of the Adam optimiser.
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import unittest
import numpy as np

import pints
import pints.toy

from shared import StreamCapture


debug = False
method = pints.Adam


class TestAdam(unittest.TestCase):
    """
    Tests the API of the Adam optimiser.
    """
    def setUp(self):
        """ Called before every test. """
        np.random.seed(1)

    def problem(self):
        """ Returns a test problem, starting point, and sigma. """
        r = pints.toy.ParabolicError()
        x = [0.1, 0.1]
        s = 0.1
        return r, x, s

    def test_simple(self):
        # Runs an optimisation
        r, x, s = self.problem()

        opt = pints.OptimisationController(r, x, sigma0=s, method=method)
        opt.set_log_to_screen(debug)
        found_parameters, found_solution = opt.run()

        # True solution is (0, 0) with error 0
        self.assertLess(found_solution, 1e-9)
        self.assertLess(abs(found_parameters[0]), 1e-8)
        self.assertLess(abs(found_parameters[1]), 1e-8)

    def test_ask_tell(self):
        # Tests ask-and-tell related error handling.
        r, x, s = self.problem()
        opt = method(x)

        # Stop called when not running
        self.assertFalse(opt.running())
        self.assertFalse(opt.stop())

        # Best position and score called before run
        self.assertEqual(list(opt.x_best()), list(x))
        self.assertEqual(list(opt.x_guessed()), list(x))
        self.assertEqual(opt.f_best(), float('inf'))
        self.assertEqual(opt.f_guessed(), float('inf'))

        # Tell before ask
        self.assertRaisesRegex(
            Exception, r'ask\(\) not called before tell\(\)', opt.tell, 5)

        # Ask
        opt.ask()

        # Now we should be running
        self.assertTrue(opt.running())

    def test_hyper_parameter_interface(self):
        # Tests the hyper parameter interface for this optimiser.
        opt = method([0])
        self.assertEqual(opt.n_hyper_parameters(), 0)

    def test_logging(self):

        # Test with logpdf
        r, x, s = self.problem()
        opt = pints.OptimisationController(r, x, s, method=method)
        opt.set_log_to_screen(True)
        opt.set_max_unchanged_iterations(None)
        opt.set_max_iterations(2)
        with StreamCapture() as c:
            opt.run()
        lines = c.text().splitlines()
        self.assertEqual(lines[0], 'Minimising error measure')
        self.assertEqual(lines[1], 'Using Adam')
        self.assertEqual(lines[2], 'Running in sequential mode.')
        self.assertEqual(
            lines[3],
            'Iter. Eval. Best Current Time m:s')
        self.assertEqual(
            lines[4][:-3],
            '0 1 0.02 0.02 0:0')
        self.assertEqual(
            lines[5][:-3],
            '1 2 5e-17 5e-17 0:0')

    def test_name(self):
        # Test the name() method.
        opt = method(np.array([0]))
        self.assertEqual(opt.name(), 'Adam')
        self.assertTrue(opt.needs_sensitivities())


if __name__ == '__main__':
    print('Add -v for more debug output')
    import sys
    if '-v' in sys.argv:
        debug = True
    unittest.main()