diff --git a/src/ReinforcementLearningCore/src/policies/agent/base.jl b/src/ReinforcementLearningCore/src/policies/agent/base.jl
index 7654de4d6..4af02db61 100644
--- a/src/ReinforcementLearningCore/src/policies/agent/base.jl
+++ b/src/ReinforcementLearningCore/src/policies/agent/base.jl
@@ -51,10 +51,6 @@ function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajecto
 
 @functor Agent (policy,)
 
-function Base.push!(agent::Agent, ::PreActStage, env::AbstractEnv)
-    push!(agent, state(env))
-end
-
 # !!! TODO: In async scenarios, parameters of the policy may still be updating
 # (partially), which will result in incorrect actions. This should be addressed
 # in Oolong.jl with a wrapper
@@ -64,11 +60,8 @@ function RLBase.plan!(agent::Agent{P,T,C}, env::AbstractEnv) where {P,T,C}
     action
 end
 
-# Multiagent Version
-function RLBase.plan!(agent::Agent{P,T,C}, env::E, p::Symbol) where {P,T,C,E<:AbstractEnv}
-    action = RLBase.plan!(agent.policy, env, p)
-    push!(agent.trajectory, agent.cache, action)
-    action
+function Base.push!(agent::Agent, ::PreActStage, env::AbstractEnv)
+    push!(agent, state(env))
 end
 
 function Base.push!(agent::Agent{P,T,C}, ::PostActStage, env::E) where {P,T,C,E<:AbstractEnv}
@@ -79,11 +72,26 @@ function Base.push!(agent::Agent, ::PostExperimentStage, env::E) where {E<:Abstr
     RLBase.reset!(agent.cache)
 end
 
-function Base.push!(agent::Agent, ::PostExperimentStage, env::E, player::Symbol) where {E<:AbstractEnv}
-    RLBase.reset!(agent.cache)
-end
-
 function Base.push!(agent::Agent{P,T,C}, state::S) where {P,T,C,S}
     push!(agent.cache, state)
 end
+
+# Multiagent Version
+function RLBase.plan!(agent::Agent{P,T,C}, env::E, p::Symbol) where {P,T,C,E<:AbstractEnv}
+    action = RLBase.plan!(agent.policy, env, p)
+    push!(agent.trajectory, agent.cache, action)
+    action
+end
+
+# For simultaneous DynamicStyle environments, we have to define player-aware push! operations
+function Base.push!(agent::Agent, ::PreActStage, env::AbstractEnv, player::Symbol)
+    push!(agent, state(env, player))
+end
+
+function Base.push!(agent::Agent{P,T,C}, ::PostActStage, env::E, player::Symbol) where {P,T,C,E<:AbstractEnv}
+    push!(agent.cache, reward(env, player), is_terminated(env))
+end
+
+function Base.push!(agent::Agent, ::PostExperimentStage, env::E, player::Symbol) where {E<:AbstractEnv}
+    RLBase.reset!(agent.cache)
+end
diff --git a/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl b/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
index 1db43a18c..bbf756e03 100644
--- a/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
+++ b/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
@@ -133,10 +133,12 @@ function Base.run(
         if check_stop(stop_condition, policy, env)
             is_stop = true
-            @timeit_debug timer "push!(policy) PreActStage" push!(multiagent_policy, PreActStage(), env)
-            @timeit_debug timer "optimise! PreActStage" optimise!(multiagent_policy, PreActStage())
-            @timeit_debug timer "push!(hook) PreActStage" push!(multiagent_hook, PreActStage(), policy, env)
-            @timeit_debug timer "plan!" RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
+            if !is_terminated(env)
+                @timeit_debug timer "push!(policy) PreActStage" push!(multiagent_policy, PreActStage(), env)
+                @timeit_debug timer "optimise! PreActStage" optimise!(multiagent_policy, PreActStage())
+                @timeit_debug timer "push!(hook) PreActStage" push!(multiagent_hook, PreActStage(), policy, env)
+                @timeit_debug timer "plan!" RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
+            end
             break
         end
@@ -228,7 +230,7 @@ function Base.push!(composed_hook::ComposedHook{T},
 end
 
 function RLBase.plan!(multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv}
-    return (RLBase.plan!(multiagent[player], env, player) for player in players(env))
+    return NamedTuple(player => RLBase.plan!(multiagent[player], env, player) for player in players(env))
 end
 
 function RLBase.optimise!(multiagent::MultiAgentPolicy, stage::S) where {S<:AbstractStage}
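
Note on the player-aware push! methods added to base.jl: they mirror the existing single-agent stage hooks but take an extra player::Symbol, so in a simultaneous step each agent caches its own observation and its own reward. A minimal dispatch sketch with stand-in types (ToyAgent, ToyEnv, and their state/reward methods are invented for illustration and are not RLCore API):

# Stand-in types; only the dispatch pattern matches the diff above.
struct PreActStage end
struct PostActStage end
struct ToyEnv end

state(::ToyEnv, player::Symbol) = player === :a ? 1.0 : 2.0   # per-player observation
reward(::ToyEnv, player::Symbol) = player === :a ? 1.0 : -1.0 # per-player reward
is_terminated(::ToyEnv) = false

struct ToyAgent
    cache::Vector{Any}
end

# Player-aware stage hooks, shaped like the new methods in base.jl:
Base.push!(agent::ToyAgent, ::PreActStage, env::ToyEnv, player::Symbol) =
    push!(agent.cache, state(env, player))
Base.push!(agent::ToyAgent, ::PostActStage, env::ToyEnv, player::Symbol) =
    push!(agent.cache, (reward(env, player), is_terminated(env)))

env = ToyEnv()
agents = Dict(:a => ToyAgent([]), :b => ToyAgent([]))
for p in (:a, :b)                                # simultaneous step: every player observes...
    push!(agents[p], PreActStage(), env, p)
    push!(agents[p], PostActStage(), env, p)     # ...and records its own (reward, done) pair
end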
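
Note on the plan! change in multi_agent.jl: the old method returned a lazy Generator, so the per-player plan! calls only ran if the caller iterated the result; the NamedTuple constructor (the same NamedTuple(key => value for ...) form used in the diff) evaluates every call eagerly and makes the actions addressable by player name. A toy comparison, with fake_plan standing in for the real RLBase.plan!(multiagent[player], env, player):

players = (:Cross, :Nought)
fake_plan(player) = player === :Cross ? 1 : 2    # stand-in for the per-player plan! call

lazy  = (fake_plan(p) for p in players)                  # old: a Generator; nothing runs yet
eager = NamedTuple(p => fake_plan(p) for p in players)   # new: all actions computed now

eager.Cross                          # 1 -- lookup by player name
eager === (Cross = 1, Nought = 2)    # true
collect(lazy)                        # only here does the generator call fake_plan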