@@ -234,19 +234,19 @@ def SSE(model):
234
234
expr = sum ((y - y_hat ) ** 2 for y , y_hat in model .experiment_outputs .items ())
235
235
return expr
236
236
237
- def regularize_term (model , FIM , theta_ref ):
237
+ def regularize_term (model , prior_FIM , theta_ref ):
238
238
"""
239
239
Regularization term for the objective function, which is used to penalize deviation from a
240
240
reference theta
241
- (theta - theta_ref).transpose() * FIM * (theta - theta_ref)
241
+ (theta - theta_ref).transpose() * prior_FIM * (theta - theta_ref)
242
242
243
243
theta_ref: Reference parameter value, element of matrix
244
244
FIM: Fisher Information Matrix, matrix
245
245
theta: Parameter value, matrix
246
246
247
247
Added to SSE objective function
248
248
"""
249
- expr = ((theta - theta_ref ).transpose () * FIM * (theta - theta_ref ) for theta in model .unknown_parameters .items ())
249
+ expr = ((theta - theta_ref ).transpose () * prior_FIM * (theta - theta_ref ) for theta in model .unknown_parameters .items ())
250
250
return expr
251
251
252
252
@@ -285,7 +285,7 @@ def __init__(
285
285
self ,
286
286
experiment_list ,
287
287
obj_function = None ,
288
- FIM = None ,
288
+ prior_FIM = None ,
289
289
theta_ref = None ,
290
290
tee = False ,
291
291
diagnostic_mode = False ,
@@ -444,10 +444,18 @@ def _create_parmest_model(self, experiment_number):
444
444
# TODO, this needs to be turned into an enum class of options that still support
445
445
# custom functions
446
446
if self .obj_function == 'SSE' :
447
- second_stage_rule = SSE
447
+
448
448
if self .FIM and self .theta_ref is not None :
449
449
# Regularize the objective function
450
- second_stage_rule = SSE + regularize_term
450
+ second_stage_rule = SSE + regularize_term (prior_FIM = self .prior_FIM , theta_ref = self .theta_ref )
451
+ elif self .FIM :
452
+ theta_ref = model .unknown_parameters .values ()
453
+ second_stage_rule = SSE + regularize_term (prior_FIM = self .prior_FIM , theta_ref = self .theta_ref )
454
+
455
+ else :
456
+ # Sum of squared errors
457
+ second_stage_rule = SSE
458
+
451
459
452
460
453
461
else :
0 commit comments