Commit 21c5707

Send the data to the optimizer without a copy; works for dense layers
1 parent 38896cc commit 21c5707

File tree

2 files changed: +81 -81 lines changed


src/nf/nf_network_submodule.f90

Lines changed: 6 additions & 6 deletions
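
The hunk below is the core of the change: `get_params_ptr` and `get_gradients_ptr` hand the optimizer pointers into the dense layer's own weight and bias storage, so `minimize` updates the parameters in place rather than round-tripping through the now-removed `get_params` / `set_params` path and a temporary flat array. A minimal, self-contained sketch of that idea in plain Fortran (toy code, not part of the library; the comments name the library routines it imitates):

program no_copy_update
  ! Toy illustration only: `weights` stands in for a dense layer's weight array,
  ! and `w` for the pointer that a routine like get_params_ptr would return.
  implicit none
  real, target  :: weights(3,2)
  real, pointer :: w(:,:)
  real :: dw(3,2)
  real, parameter :: learning_rate = 0.1

  weights = 1.0
  dw = 2.0

  w => weights                 ! pointer association: no data is copied
  w = w - learning_rate * dw   ! an SGD-style in-place step through the pointer

  print *, weights(1,1)        ! 0.8: the layer's own storage was updated
end program no_copy_update
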
@@ -694,22 +694,22 @@ module subroutine update(self, optimizer, batch_size)
       end do
 #endif
 
-    !params = self % get_params()
-    !call self % optimizer % minimize(params, self % get_gradients() / batch_size_)
-    !call self % set_params(params)
-
     do n = 2, size(self % layers)
       select type(this_layer => self % layers(n) % p)
         type is(dense_layer)
           call this_layer % get_params_ptr(weights, biases)
           call this_layer % get_gradients_ptr(dw, db)
           call self % optimizer % minimize(weights, dw / batch_size_)
           call self % optimizer % minimize(biases, db / batch_size_)
-          !call this_layer % set_params(weights, biases)
+        type is(locally_connected1d_layer)
+          !TODO
+        type is(conv1d_layer)
+          !TODO
+        type is(conv2d_layer)
+          !TODO
       end select
     end do
 
-
     ! Flush network gradients to zero.
     do n = 2, size(self % layers)
       select type(this_layer => self % layers(n) % p)

src/nf/nf_optimizers.f90

Lines changed: 75 additions & 75 deletions
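
In the update loop of nf_network_submodule.f90 the same `minimize` call receives the rank-2 `weights` pointer and the rank-1 `biases` pointer, which is why nf_optimizers.f90 below keeps separate `*_1d` and `*_2d` implementations for each optimizer. A self-contained toy of the usual way such rank-specific routines sit behind one generic binding (a sketch under that assumption; the library's actual type declarations are not part of this diff):

module toy_optimizer
  ! Toy only: `toy_sgd` is a stand-in, not the library's sgd type.
  implicit none
  private
  public :: toy_sgd

  type :: toy_sgd
    real :: learning_rate = 0.01
  contains
    procedure :: minimize_1d
    procedure :: minimize_2d
    generic :: minimize => minimize_1d, minimize_2d  ! rank selects the specific
  end type toy_sgd

contains

  pure subroutine minimize_1d(self, param, gradient)
    class(toy_sgd), intent(in) :: self
    real, intent(inout) :: param(:)
    real, intent(in) :: gradient(:)
    param = param - self % learning_rate * gradient
  end subroutine minimize_1d

  pure subroutine minimize_2d(self, param, gradient)
    class(toy_sgd), intent(in) :: self
    real, intent(inout) :: param(:,:)
    real, intent(in) :: gradient(:,:)
    param = param - self % learning_rate * gradient
  end subroutine minimize_2d

end module toy_optimizer

With a binding like this, `call opt % minimize(weights, dw)` resolves to the rank-2 routine and `call opt % minimize(biases, db)` to the rank-1 one, matching how the network update loop is written.
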
@@ -155,6 +155,32 @@ pure subroutine minimize_sgd_1d(self, param, gradient)
   end subroutine minimize_sgd_1d
 
 
+  pure subroutine minimize_sgd_2d(self, param, gradient)
+    !! Concrete implementation of a stochastic gradient descent optimizer
+    !! update rule for 2D arrays.
+    class(sgd), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    if (self % momentum > 0) then
+      ! Apply momentum update
+      self % velocity = self % momentum * self % velocity &
+        - self % learning_rate * reshape(gradient, [size(gradient)])
+      if (self % nesterov) then
+        ! Apply Nesterov update
+        param = param + reshape(self % momentum * self % velocity &
+          - self % learning_rate * reshape(gradient, [size(gradient)]), shape(param))
+      else
+        param = param + reshape(self % velocity, shape(param))
+      end if
+    else
+      ! Apply regular update
+      param = param - self % learning_rate * gradient
+    end if
+
+  end subroutine minimize_sgd_2d
+
+
   impure elemental subroutine init_rmsprop(self, num_params)
     class(rmsprop), intent(inout) :: self
     integer, intent(in) :: num_params
@@ -182,6 +208,23 @@ pure subroutine minimize_rmsprop_1d(self, param, gradient)
   end subroutine minimize_rmsprop_1d
 
 
+  pure subroutine minimize_rmsprop_2d(self, param, gradient)
+    !! Concrete implementation of a RMSProp optimizer update rule for 2D arrays.
+    class(rmsprop), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    ! Compute the RMS of the gradient using the RMSProp rule
+    self % rms_gradient = self % decay_rate * self % rms_gradient &
+      + (1 - self % decay_rate) * reshape(gradient, [size(gradient)])**2
+
+    ! Update the network parameters based on the new RMS of the gradient
+    param = param - self % learning_rate &
+      / sqrt(reshape(self % rms_gradient, shape(param)) + self % epsilon) * gradient
+
+  end subroutine minimize_rmsprop_2d
+
+
   impure elemental subroutine init_adam(self, num_params)
     class(adam), intent(inout) :: self
     integer, intent(in) :: num_params
@@ -224,6 +267,37 @@ pure subroutine minimize_adam_1d(self, param, gradient)
   end subroutine minimize_adam_1d
 
 
+  pure subroutine minimize_adam_2d(self, param, gradient)
+    !! Concrete implementation of an Adam optimizer update rule for 2D arrays.
+    class(adam), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    self % t = self % t + 1
+
+    ! If weight_decay_l2 > 0, use L2 regularization;
+    ! otherwise, default to regular Adam.
+    associate(g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]))
+      self % m = self % beta1 * self % m + (1 - self % beta1) * g
+      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
+    end associate
+
+    ! Compute bias-corrected first and second moment estimates.
+    associate( &
+      m_hat => self % m / (1 - self % beta1**self % t), &
+      v_hat => self % v / (1 - self % beta2**self % t) &
+    )
+
+      ! Update parameters.
+      param = param &
+        - self % learning_rate * reshape(m_hat / (sqrt(v_hat) + self % epsilon), shape(param)) &
+        - self % learning_rate * self % weight_decay_decoupled * param
+
+    end associate
+
+  end subroutine minimize_adam_2d
+
+
   impure elemental subroutine init_adagrad(self, num_params)
     class(adagrad), intent(inout) :: self
     integer, intent(in) :: num_params
@@ -262,80 +336,6 @@ pure subroutine minimize_adagrad_1d(self, param, gradient)
   end subroutine minimize_adagrad_1d
 
 
-  pure subroutine minimize_sgd_2d(self, param, gradient)
-    !! Concrete implementation of a stochastic gradient descent optimizer
-    !! update rule for 2D arrays.
-    class(sgd), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    if (self % momentum > 0) then
-      ! Apply momentum update
-      self % velocity = self % momentum * self % velocity &
-        - self % learning_rate * reshape(gradient, [size(gradient)])
-      if (self % nesterov) then
-        ! Apply Nesterov update
-        param = param + reshape(self % momentum * self % velocity &
-          - self % learning_rate * reshape(gradient, [size(gradient)]), shape(param))
-      else
-        param = param + reshape(self % velocity, shape(param))
-      end if
-    else
-      ! Apply regular update
-      param = param - self % learning_rate * gradient
-    end if
-
-  end subroutine minimize_sgd_2d
-
-
-  pure subroutine minimize_rmsprop_2d(self, param, gradient)
-    !! Concrete implementation of a RMSProp optimizer update rule for 2D arrays.
-    class(rmsprop), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    ! Compute the RMS of the gradient using the RMSProp rule
-    self % rms_gradient = self % decay_rate * self % rms_gradient &
-      + (1 - self % decay_rate) * reshape(gradient, [size(gradient)])**2
-
-    ! Update the network parameters based on the new RMS of the gradient
-    param = param - self % learning_rate &
-      / sqrt(reshape(self % rms_gradient, shape(param)) + self % epsilon) * gradient
-
-  end subroutine minimize_rmsprop_2d
-
-
-  pure subroutine minimize_adam_2d(self, param, gradient)
-    !! Concrete implementation of an Adam optimizer update rule for 2D arrays.
-    class(adam), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    self % t = self % t + 1
-
-    ! If weight_decay_l2 > 0, use L2 regularization;
-    ! otherwise, default to regular Adam.
-    associate(g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]))
-      self % m = self % beta1 * self % m + (1 - self % beta1) * g
-      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
-    end associate
-
-    ! Compute bias-corrected first and second moment estimates.
-    associate( &
-      m_hat => self % m / (1 - self % beta1**self % t), &
-      v_hat => self % v / (1 - self % beta2**self % t) &
-    )
-
-      ! Update parameters.
-      param = param &
-        - self % learning_rate * reshape(m_hat / (sqrt(v_hat) + self % epsilon), shape(param)) &
-        - self % learning_rate * self % weight_decay_decoupled * param
-
-    end associate
-
-  end subroutine minimize_adam_2d
-
-
   pure subroutine minimize_adagrad_2d(self, param, gradient)
     !! Concrete implementation of an Adagrad optimizer update rule for 2D arrays.
     class(adagrad), intent(inout) :: self
@@ -363,4 +363,4 @@ pure subroutine minimize_adagrad_2d(self, param, gradient)
 
   end subroutine minimize_adagrad_2d
 
-end module nf_optimizers
+end module nf_optimizers
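
For readers decoding the reshape-heavy lines in the relocated `minimize_adam_2d`, the update it implements is the usual Adam rule with optional L2 and decoupled weight decay, restated here in math notation (learning_rate = eta, weight_decay_l2 = lambda_2, weight_decay_decoupled = lambda_d):

\[
\begin{aligned}
g_t &= \nabla_\theta + \lambda_2 \, \theta \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\
\theta &\leftarrow \theta - \eta \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \eta \, \lambda_d \, \theta
\end{aligned}
\]

The `reshape` calls only flatten the rank-2 `param` and `gradient` to match the optimizer's rank-1 state vectors `m` and `v`, and reshape the result back; they do not change the arithmetic.
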
