feature(rjy): add HAPPO algorithm #717

Merged · 28 commits · Jan 11, 2024
Changes from 10 commits

Commits (28)
af96dc8  model(rjy): add vac model for HAPPO  (Aug 30, 2023)
a596c27  test(rjy): polish havac and add test  (Sep 15, 2023)
d282845  polish(rjy): fix conflict  (Sep 15, 2023)
43bb9e8  polish(rjy): add hidden_state for ac  (Sep 21, 2023)
b04ea13  feature(rjy): change the havac to multiagent model  (Oct 10, 2023)
f5648d0  feature(rjy): add happo forward_learn  (Oct 10, 2023)
42f4027  Merge branch 'main' into rjy-happo-model  (Oct 11, 2023)
42faae6  feature(rjy): modify the happo_data  (Oct 20, 2023)
3319a55  test(rjy): add happo data test  (Oct 20, 2023)
e3fdb80  feature(rjy): add HAPPO policy  (Oct 26, 2023)
8d4791d  feature(rjy): try to fit mu-mujoco  (Oct 30, 2023)
850f831  polish(rjy): Change code to adapt to mujoco  (Oct 31, 2023)
8e281dc  fix(rjy): fix the distribution in ppo update  (Oct 31, 2023)
f828553  fix(rjy): fix the happo+mujoco  (Nov 3, 2023)
70da407  config(rjy): add walker+happo config  (Nov 9, 2023)
23d1ddb  polish(rjy): separate actors and critics  (Dec 27, 2023)
ca3daff  polish(rjy): polish according to comments  (Dec 27, 2023)
e7277b8  polish(rjy): fix the pipeline  (Dec 28, 2023)
910a8f4  Merge branch 'main' into rjy-happo-model  (Dec 29, 2023)
b03390b  polish(rjy): fix the style  (Dec 29, 2023)
d5ace8e  polish(rjy): polish according to comments  (Dec 29, 2023)
78bffa7  polish(rjy): fix style  (Dec 29, 2023)
84028d8  polish(rjy): fix style  (Dec 29, 2023)
b6e7239  polish(rjy): fix style  (Dec 29, 2023)
8fc9517  polish(rjy): seperate the happo model  (Jan 5, 2024)
48dcd94  fix(rjy): fix happo model style  (Jan 5, 2024)
a1bf76f  polish(rjy): polish happo policy comments  (Jan 10, 2024)
e7e9662  polish(rjy): polish happo comments  (Jan 11, 2024)
1 change: 1 addition & 0 deletions ding/entry/__init__.py
@@ -26,3 +26,4 @@
from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream, serial_pipeline_dreamer
from .serial_entry_bco import serial_pipeline_bco
from .serial_entry_pc import serial_pipeline_pc
from .serial_entry_happo import serial_entry_happo
116 changes: 116 additions & 0 deletions ding/entry/serial_entry_happo.py
@@ -0,0 +1,116 @@
from typing import Union, Optional, List, Any, Tuple
import os
import torch
from ditk import logging
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed


def serial_entry_happo(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
Serial pipeline entry on-policy RL.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    # Create main components: env, policy
    if env_setting is None:
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

    # Create worker components: learner, collector, evaluator, replay buffer, commander.
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = create_serial_collector(
        cfg.policy.collect.collector,
        env=collector_env,
        policy=policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    commander = BaseSerialCommander(
        cfg.policy.other.commander, learner, collector, evaluator, None, policy.command_mode
    )

    # ==========
    # Main loop
    # ==========
    # Learner's before_run hook.
    learner.call_hook('before_run')

    while True:
        collect_kwargs = commander.step()
        # Evaluate policy performance
        if evaluator.should_eval(learner.train_iter):
            stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        # Collect data by default config n_sample/n_episode
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

        # Learn policy from collected data
        learner.train(new_data, collector.envstep)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    # Learner's after_run hook.
    learner.call_hook('after_run')
    import time
    import pickle
    import numpy as np
    with open(os.path.join(cfg.exp_name, 'result.pkl'), 'wb') as f:
        eval_value_raw = eval_info['eval_episode_return']
        final_data = {
            'stop': stop,
            'env_step': collector.envstep,
            'train_iter': learner.train_iter,
            'eval_value': np.mean(eval_value_raw),
            'eval_value_raw': eval_value_raw,
            'finish_time': time.ctime(),
        }
        pickle.dump(final_data, f)
    return policy
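
For reference, a minimal usage sketch of the new entry point (not part of this diff). The config path and module name below are placeholders, not files added by this PR; a real HAPPO config, such as the walker config added in commit 70da407, would be substituted.

# Hedged usage sketch: launch a HAPPO experiment through the entry added above.
# The config path/module below is hypothetical; point it at an actual HAPPO config.
from ding.entry import serial_entry_happo

if __name__ == "__main__":
    # Option 1: pass a config file path; it is parsed by read_config inside the entry.
    serial_entry_happo('path/to/happo_config.py', seed=0, max_env_step=int(3e6))

    # Option 2: pass the (user_config, create_cfg) pair directly, as the docstring describes.
    # from path.to.happo_config import main_config, create_config  # hypothetical module
    # serial_entry_happo([main_config, create_config], seed=0, max_env_step=int(3e6))

Either form ends up in compile_config inside the entry, so both produce the same compiled experiment config.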

1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -26,3 +26,4 @@
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .bcq import BCQ
from .edac import EDAC
from .havac import HAVAC