
AlphaZero not working on custom env: solving algebraic equations #329

@Khev

Hi there,

CONTEXT

Many thanks for this great repo. I'm interested in adding an "equation-solving" env, in which an RL agent must learn to solve equations like

ax + b = 0
a/x + b = 0
(ax+b)/(cx+d) + e = 0

The agent acts by choosing (operation, term) pairs, where operation ∈ {add, sub, mul, div} and term ∈ {a, b, ax, ...} is one of the terms in the equation. So solving ax + b = 0 requires two actions: (sub, b), then (div, a). I have solved most of these equations with PPO in roughly 10^3 to 10^4 steps.
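For concreteness, here is a minimal sketch of such an env, assuming a Gymnasium-style interface and sympy for the symbolic manipulation. The class name SingleEqnEnv, the reward shaping, and the observation encoding are illustrative placeholders, not the actual implementation discussed in this issue.

import gymnasium as gym
import numpy as np
import sympy as sp


class SingleEqnEnv(gym.Env):
    """Isolate x in a*x + b = 0 by applying (operation, term) actions to both sides."""

    def __init__(self, max_steps: int = 10):
        self.x, self.a, self.b = sp.symbols('x a b')
        self.operations = [
            lambda side, t: side + t,  # add
            lambda side, t: side - t,  # sub
            lambda side, t: side * t,  # mul
            lambda side, t: side / t,  # div
        ]
        self.terms = [self.a, self.b, self.a * self.x]
        # Flat discrete action = operation index * number of terms + term index.
        self.action_space = gym.spaces.Discrete(len(self.operations) * len(self.terms))
        self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32)
        self.max_steps = max_steps

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.lhs, self.rhs = self.a * self.x + self.b, sp.Integer(0)
        self.steps = 0
        return self._obs(), {}

    def step(self, action):
        op = self.operations[action // len(self.terms)]
        term = self.terms[action % len(self.terms)]
        # Apply the chosen operation with the chosen term to both sides of the equation.
        self.lhs = sp.simplify(op(self.lhs, term))
        self.rhs = sp.simplify(op(self.rhs, term))
        self.steps += 1
        solved = self.lhs == self.x            # solved once the lhs is reduced to bare x
        reward = 1.0 if solved else -0.01      # small per-step penalty (assumed shaping)
        return self._obs(), reward, solved, self.steps >= self.max_steps, {}

    def _obs(self):
        # Placeholder featurisation; the real env encodes the equation as a length-41 vector.
        return np.zeros(8, dtype=np.float32)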

ISSUE

I am trying to use your implementation of AlphaZero to solve these envs, but it's struggling with even basic equations like

x + b = 0
ax + b = 0

The learning curves are below. You can see it finds the right solution fairly often, but it fails to learn from these experiences: the policy never converges. I'm wondering if there's something wrong with my setup. My config files are below.

Any insight would be most welcome.

One other question: would you be interested in adding this custom env to the repo? It's a challenging benchmark for RL-based symbolic mathematics and would be a good contribution to the community.

[Figures: learning curves for x + b = 0 and ax + b = 0]

FILES

from easydict import EasyDict

# ==============================================================
# Frequently changed config specified by the user (lightweight settings)
# ==============================================================
collector_env_num = 4         # Number of parallel environments for data collection
n_episode = 4                 # Number of episodes per training iteration
evaluator_env_num = 1         # Number of evaluator environments
num_simulations = 50          # MCTS simulations per move (try increasing if needed)
update_per_collect = 100      # Number of gradient updates per data collection cycle
batch_size = 32               # Mini-batch size for training
max_env_step = int(1e5)       # Maximum total environment steps for a quick run
model_path = None
mcts_ctree = False

# ==============================================================
# Configurations for singleEqn_env (lightweight version)
# ==============================================================
singleEqn_alphazero_config = dict(
    exp_name='data_alphazero/singleEqn/a/x+b/',
    env=dict(
        battle_mode='play_with_bot_mode',
        battle_mode_in_simulation_env='self_play_mode',  # For simulation during MCTS
        channel_last=False,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=evaluator_env_num,
        manager=dict(shared_memory=False),
        agent_vs_human=False,
        prob_random_agent=0,
        prob_expert_agent=0,
        prob_random_action_in_bot=0,
        scale=True,
        render_mode=None,
        replay_path=None,
        alphazero_mcts_ctree=mcts_ctree,
    ),
    policy=dict(
        mcts_ctree=mcts_ctree,
        simulation_env_id='singleEqn_env',  # Must match the registered name of your environment
        model=dict(
            type='AlphaZeroMLPModel',
            import_names=['zoo.custom_envs.equation_solver.my_alphazero_mlp_model'],
            observation_shape=(41,),        # Flat vector of length 41
            action_space_size=50,             # Matches your environment's action_dim
            hidden_sizes=[64, 64],          # MLP hidden layer sizes
        ),
        cuda=True,
        env_type='not_board_games',
        action_type='varied_action_space',
        update_per_collect=update_per_collect,
        batch_size=batch_size,
        optim_type='Adam',
        lr_piecewise_constant_decay=False,
        # learning_rate=0.003,
        learning_rate=3e-4,
        grad_clip_value=0.5,
        value_weight=1.0,
        entropy_weight=0.0,
        n_episode=n_episode,
        eval_freq=int(2e3),
        mcts=dict(num_simulations=num_simulations),
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        other=dict(
            replay_buffer=dict(
                type='advanced',              # Use advanced (or prioritized) replay buffer
                replay_buffer_size=10000,       # Set a smaller buffer for lightweight runs
                sample_min_limit_ratio=0.25,    # Allow sampling even if only 25% of a batch is available.
                alpha=0.6,
                beta=0.4,
                anneal_step=100000,
                enable_track_used_data=False,
                deepcopy=False,
                save_episode=False,
            )
        ),
    ),
)
singleEqn_alphazero_config = EasyDict(singleEqn_alphazero_config)
main_config = singleEqn_alphazero_config

singleEqn_alphazero_create_config = dict(
    env=dict(
        type='singleEqn_env',
        import_names=['zoo.custom_envs.equation_solver.env_single_eqn'],  # Adjust this path if needed
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='MyAlphaZeroPolicy',  # Your custom policy subclass
        import_names=['zoo.custom_envs.equation_solver.my_alphazero_policy'],
    ),
    collector=dict(
        type='episode_alphazero',
        import_names=['lzero.worker.alphazero_collector'],
    ),
    evaluator=dict(
        type='alphazero',
        import_names=['lzero.worker.alphazero_evaluator'],
    )
)
singleEqn_alphazero_create_config = EasyDict(singleEqn_alphazero_create_config)
create_config = singleEqn_alphazero_create_config

if __name__ == '__main__':
    from lzero.entry import train_alphazero
    # Merge the environment configuration into the policy config.
    main_config.policy.env = main_config.env
    train_alphazero([main_config, create_config], seed=0, model_path=model_path, max_env_step=max_env_step)
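For reference, the 'singleEqn_env' string in simulation_env_id and in the create-config env type is expected to match the name the custom env class registers under. Below is a minimal sketch of that registration, assuming LightZero's usual DI-engine ENV_REGISTRY convention; the class name is illustrative and the body is elided.

from ding.utils import ENV_REGISTRY


@ENV_REGISTRY.register('singleEqn_env')  # must match simulation_env_id and the create_config env type
class SingleEqnEnv:
    """Placeholder: the real env lives in zoo.custom_envs.equation_solver.env_single_eqn."""
    ...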
