Feature/upgrade optuna 4.2.1 #3046

Open · wants to merge 2 commits into main
@@ -1,3 +1,5 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

 __version__ = "1.4.0.dev0"
+
+from .plugin import OptunaSweeperSearchPathPlugin
@@ -35,11 +35,8 @@
    BaseDistribution,
    CategoricalChoiceType,
    CategoricalDistribution,
-    DiscreteUniformDistribution,
-    IntLogUniformDistribution,
-    IntUniformDistribution,
-    LogUniformDistribution,
-    UniformDistribution,
+    FloatDistribution,
+    IntDistribution,
)
from optuna.trial import Trial

@@ -61,18 +58,16 @@ def create_optuna_distribution_from_config(
    if param.type == DistributionType.int:
        assert param.low is not None
        assert param.high is not None
-        if param.log:
-            return IntLogUniformDistribution(int(param.low), int(param.high))
        step = int(param.step) if param.step is not None else 1
-        return IntUniformDistribution(int(param.low), int(param.high), step=step)
+        return IntDistribution(
+            low=int(param.low), high=int(param.high), step=step, log=param.log
+        )
    if param.type == DistributionType.float:
        assert param.low is not None
        assert param.high is not None
-        if param.log:
-            return LogUniformDistribution(param.low, param.high)
-        if param.step is not None:
-            return DiscreteUniformDistribution(param.low, param.high, param.step)
-        return UniformDistribution(param.low, param.high)
+        return FloatDistribution(
+            low=param.low, high=param.high, step=param.step, log=param.log
+        )
    raise NotImplementedError(f"{param.type} is not supported by Optuna sweeper.")


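For reference (not part of this diff): the unified constructors that the rewritten branches above rely on, with their defaults spelled out. A minimal sketch assuming an Optuna 4.x install; the values are illustrative only.

# Sketch only: IntDistribution/FloatDistribution fold the old per-case classes
# into two constructors.
from optuna.distributions import FloatDistribution, IntDistribution

IntDistribution(low=0, high=10)                    # step=1, log=False by default (was IntUniformDistribution)
IntDistribution(low=1, high=100, log=True)         # was IntLogUniformDistribution
FloatDistribution(low=0.0, high=1.0)               # was UniformDistribution
FloatDistribution(low=0.0, high=10.0, step=2.0)    # was DiscreteUniformDistribution
FloatDistribution(low=1.0, high=100.0, log=True)   # was LogUniformDistribution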
@@ -107,23 +102,22 @@ def create_optuna_distribution_from_override(override: Override) -> Any:
            or isinstance(value.stop, float)
            or isinstance(value.step, float)
        ):
-            return DiscreteUniformDistribution(value.start, value.stop, value.step)
-        return IntUniformDistribution(
-            int(value.start), int(value.stop), step=int(value.step)
+            return FloatDistribution(low=value.start, high=value.stop, step=value.step)
+        return IntDistribution(
+            low=int(value.start), high=int(value.stop), step=int(value.step)
        )

    if override.is_interval_sweep():
        assert isinstance(value, IntervalSweep)
        assert value.start is not None
        assert value.end is not None
-        if "log" in value.tags:
-            if isinstance(value.start, int) and isinstance(value.end, int):
-                return IntLogUniformDistribution(int(value.start), int(value.end))
-            return LogUniformDistribution(value.start, value.end)
-        else:
-            if isinstance(value.start, int) and isinstance(value.end, int):
-                return IntUniformDistribution(value.start, value.end)
-            return UniformDistribution(value.start, value.end)
+        if isinstance(value.start, int) and isinstance(value.end, int):
+            return IntDistribution(
+                low=value.start, high=value.end, log="log" in value.tags
+            )
+        return FloatDistribution(
+            low=value.start, high=value.end, log="log" in value.tags
+        )

    raise NotImplementedError(f"{override} is not supported by Optuna sweeper.")

@@ -237,7 +231,30 @@ def _configure_trials(
        for trial in trials:
            for param_name, distribution in search_space_distributions.items():
                assert type(param_name) is str
-                trial._suggest(param_name, distribution)
+                # Replace _suggest with public API methods
+                if isinstance(distribution, CategoricalDistribution):
+                    trial.suggest_categorical(param_name, distribution.choices)
+                elif isinstance(distribution, IntDistribution):
+                    trial.suggest_int(
+                        param_name,
+                        distribution.low,
+                        distribution.high,
+                        step=distribution.step,
+                        log=distribution.log,
+                    )
+                elif isinstance(distribution, FloatDistribution):
+                    trial.suggest_float(
+                        param_name,
+                        distribution.low,
+                        distribution.high,
+                        step=distribution.step,
+                        log=distribution.log,
+                    )
+                else:
+                    raise NotImplementedError(
+                        f"Distribution {distribution} not supported"
+                    )

            for param_name, value in fixed_params.items():
                trial.set_user_attr(param_name, value)
@@ -266,15 +283,20 @@ def _parse_sweeper_params_config(self) -> List[str]:
    def _to_grid_sampler_choices(self, distribution: BaseDistribution) -> Any:
        if isinstance(distribution, CategoricalDistribution):
            return distribution.choices
-        elif isinstance(distribution, IntUniformDistribution):
+        elif isinstance(distribution, IntDistribution):
            assert (
                distribution.step is not None
-            ), "`step` of IntUniformDistribution must be a positive integer."
-            n_items = (distribution.high - distribution.low) // distribution.step
+            ), "`step` of IntDistribution must be a positive integer."
+            n_items = (distribution.high - distribution.low) // distribution.step + 1
            return [distribution.low + i * distribution.step for i in range(n_items)]
+        elif (
+            isinstance(distribution, FloatDistribution)
+            and distribution.step is not None
+        ):
+            n_items = (
+                int((distribution.high - distribution.low) / distribution.step) + 1
+            )
+            return [distribution.low + i * distribution.step for i in range(n_items)]
-        elif isinstance(distribution, DiscreteUniformDistribution):
-            n_items = int((distribution.high - distribution.low) // distribution.q)
-            return [distribution.low + i * distribution.q for i in range(n_items)]
        else:
            raise ValueError("GridSampler only supports discrete distributions.")

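A quick arithmetic check (not part of the diff) of the off-by-one fix in the grid enumeration above; the bounds and step are illustrative.

low, high, step = 0, 10, 2
n_items = (high - low) // step + 1  # 6 with the "+ 1" fix, 5 without it
assert [low + i * step for i in range(n_items)] == [0, 2, 4, 6, 8, 10]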
@@ -50,7 +50,6 @@ class TPESamplerConfig(SamplerConfig):
    n_startup_trials: int = 10
    n_ei_candidates: int = 24
    multivariate: bool = False
-    warn_independent_sampling: bool = True


@dataclass
@@ -75,12 +74,12 @@ class CmaEsSamplerConfig(SamplerConfig):
    x0: Optional[Dict[str, Any]] = None
    sigma0: Optional[float] = None
    independent_sampler: Optional[Any] = None
-    warn_independent_sampling: bool = True
    consider_pruned_trials: bool = False
    restart_strategy: Optional[Any] = None
    inc_popsize: int = 2
    use_separable_cma: bool = False
    source_trials: Optional[Any] = None
+    with_margin: bool = False


@dataclass
@@ -114,6 +113,7 @@ class MOTPESamplerConfig(SamplerConfig):
    consider_endpoints: bool = False
    n_startup_trials: int = 10
    n_ehvi_candidates: int = 24
+    constraints_func: Optional[Any] = None


@dataclass
@@ -0,0 +1,7 @@
+from hydra.plugins.search_path_plugin import SearchPathPlugin
+from hydra.core.config_search_path import ConfigSearchPath
+
+
+class OptunaSweeperSearchPathPlugin(SearchPathPlugin):
+    def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
+        search_path.append("optuna_sweeper", "pkg://hydra_optuna_sweeper.conf")
1 change: 1 addition & 0 deletions plugins/hydra_optuna_sweeper/news/optuna_421.feature
@@ -0,0 +1 @@
+Upgraded hydra_optuna_sweeper to use Optuna 4.2.1, replacing deprecated distribution classes with modern equivalents and updating the API usage.
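Illustrative only (not from the PR): the kind of study wiring the updated API usage boils down to on Optuna 4.x, with a GridSampler consuming enumerated choices like those produced by _to_grid_sampler_choices. The parameter name and values are placeholders.

import optuna

search_space = {"x": [0, 2, 4, 6, 8, 10]}  # e.g. the output of _to_grid_sampler_choices
study = optuna.create_study(sampler=optuna.samplers.GridSampler(search_space))
# Each trial draws one grid point through the public suggest API.
study.optimize(
    lambda t: float(t.suggest_int("x", 0, 10, step=2)),
    n_trials=len(search_space["x"]),
)
print(study.best_params)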
9 changes: 7 additions & 2 deletions plugins/hydra_optuna_sweeper/setup.py
@@ -27,8 +27,13 @@
    ],
    install_requires=[
        "hydra-core>=1.1.0.dev7",
-        "optuna>=2.10.0,<3.0.0",
-        "sqlalchemy~=1.3.0",  # TODO: Unpin when upgrading to optuna v3.0
+        "optuna>=4.2.1",
+        "sqlalchemy>=2.0.0",  # Updated for optuna v4.2.1 compatibility
    ],
    include_package_data=True,
+    entry_points={
+        "hydra_plugins": [
+            "optuna_sweeper = hydra_plugins.hydra_optuna_sweeper",
+        ],
+    },
)
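A hypothetical guard (not in this PR) that downstream code could add to surface the new minimum requirement early; the use of the packaging library is an assumption.

from packaging.version import Version  # assumes `packaging` is installed

import optuna

if Version(optuna.__version__) < Version("4.2.1"):
    raise RuntimeError(
        f"hydra-optuna-sweeper now requires optuna>=4.2.1, found {optuna.__version__}"
    )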
63 changes: 38 additions & 25 deletions plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py
@@ -19,11 +19,8 @@
from optuna.distributions import (
    BaseDistribution,
    CategoricalDistribution,
-    DiscreteUniformDistribution,
-    IntLogUniformDistribution,
-    IntUniformDistribution,
-    LogUniformDistribution,
-    UniformDistribution,
+    FloatDistribution,
+    IntDistribution,
)
from optuna.samplers import RandomSampler
from pytest import mark, warns
@@ -43,13 +40,29 @@ def test_discovery() -> None:


def check_distribution(expected: BaseDistribution, actual: BaseDistribution) -> None:
-    if not isinstance(expected, CategoricalDistribution):
-        assert expected == actual
+    if isinstance(expected, CategoricalDistribution):
+        assert isinstance(actual, CategoricalDistribution)
+        # shuffle() will randomize the order of items in choices.
+        assert set(expected.choices) == set(actual.choices)
+        return

-    assert isinstance(actual, CategoricalDistribution)
-    # shuffle() will randomize the order of items in choices.
-    assert set(expected.choices) == set(actual.choices)
+    if isinstance(expected, IntDistribution):
+        assert isinstance(actual, IntDistribution)
+        assert expected.low == actual.low
+        assert expected.high == actual.high
+        assert expected.step == actual.step
+        assert expected.log == actual.log
+        return
+
+    if isinstance(expected, FloatDistribution):
+        assert isinstance(actual, FloatDistribution)
+        assert expected.low == actual.low
+        assert expected.high == actual.high
+        assert expected.log == actual.log
+        return
+
+    assert expected == actual


@mark.parametrize(
@@ -59,24 +72,24 @@
        (
            {"type": "categorical", "choices": [1, 2, 3]},
            CategoricalDistribution([1, 2, 3]),
        ),
-        ({"type": "int", "low": 0, "high": 10}, IntUniformDistribution(0, 10)),
+        ({"type": "int", "low": 0, "high": 10}, IntDistribution(low=0, high=10, step=1)),
        (
            {"type": "int", "low": 0, "high": 10, "step": 2},
-            IntUniformDistribution(0, 10, step=2),
+            IntDistribution(low=0, high=10, step=2),
        ),
-        ({"type": "int", "low": 0, "high": 5}, IntUniformDistribution(0, 5)),
+        ({"type": "int", "low": 0, "high": 5}, IntDistribution(low=0, high=5, step=1)),
        (
            {"type": "int", "low": 1, "high": 100, "log": True},
-            IntLogUniformDistribution(1, 100),
+            IntDistribution(low=1, high=100, log=True),
        ),
-        ({"type": "float", "low": 0, "high": 1}, UniformDistribution(0, 1)),
+        ({"type": "float", "low": 0, "high": 1}, FloatDistribution(low=0, high=1)),
        (
            {"type": "float", "low": 0, "high": 10, "step": 2},
-            DiscreteUniformDistribution(0, 10, 2),
+            FloatDistribution(low=0, high=10, step=2),
        ),
        (
            {"type": "float", "low": 1, "high": 100, "log": True},
-            LogUniformDistribution(1, 100),
+            FloatDistribution(low=1, high=100, log=True),
        ),
    ],
)
@@ -92,12 +105,12 @@ def test_create_optuna_distribution_from_config(input: Any, expected: Any) -> None:
        ("key=choice(true, false)", CategoricalDistribution([True, False])),
        ("key=choice('hello', 'world')", CategoricalDistribution(["hello", "world"])),
        ("key=shuffle(range(1,3))", CategoricalDistribution((1, 2))),
-        ("key=range(1,3)", IntUniformDistribution(1, 3)),
-        ("key=interval(1, 5)", UniformDistribution(1, 5)),
-        ("key=int(interval(1, 5))", IntUniformDistribution(1, 5)),
-        ("key=tag(log, interval(1, 5))", LogUniformDistribution(1, 5)),
-        ("key=tag(log, int(interval(1, 5)))", IntLogUniformDistribution(1, 5)),
-        ("key=range(0.5, 5.5, step=1)", DiscreteUniformDistribution(0.5, 5.5, 1)),
+        ("key=range(1,3)", IntDistribution(low=1, high=3, step=1)),
+        ("key=interval(1, 5)", FloatDistribution(low=1, high=5)),
+        ("key=int(interval(1, 5))", IntDistribution(low=1, high=5, step=1)),
+        ("key=tag(log, interval(1, 5))", FloatDistribution(low=1, high=5, log=True)),
+        ("key=tag(log, int(interval(1, 5)))", IntDistribution(low=1, high=5, log=True)),
+        ("key=range(0.5, 5.5, step=1)", FloatDistribution(low=0.5, high=5.5, step=1)),
    ],
)
def test_create_optuna_distribution_from_override(input: Any, expected: Any) -> None:
@@ -121,7 +134,7 @@ def test_create_optuna_distribution_from_override(input: Any, expected: Any) -> None:
        (
            {
                "key1": CategoricalDistribution([1, 2]),
-                "key3": IntUniformDistribution(1, 3),
+                "key3": IntDistribution(low=1, high=3, step=1),
            },
            {"key2": "5"},
        ),