harl.algorithms.critics package

Submodules

harl.algorithms.critics.continuous_q_critic module

Continuous Q Critic.

class harl.algorithms.critics.continuous_q_critic.ContinuousQCritic(args, share_obs_space, act_space, num_agents, state_type, device=device(type='cpu'))[source]

Bases: object

Continuous Q Critic. A critic that learns a Q-function, taking observations and actions as input and outputting Q-values. The name ContinuousQCritic refers to this structure: such a critic is most commonly used with continuous action spaces, though the same architecture can in principle handle discrete actions as well. For now, only continuous action spaces are supported; discrete action-space support is planned.

get_values(share_obs, actions)[source]

Get the Q values.

lr_decay(step, steps)[source]

Decay the actor and critic learning rates.

Parameters:

step: (int) current training step.
steps: (int) total number of training steps.
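The decay schedule itself is not specified here; the following is a minimal sketch of a standard linear decay, using a hypothetical stand-alone helper (the function and parameter names are illustrative, not the library's API):

    import torch

    def linear_lr_decay(optimizer, step, steps, initial_lr):
        """Linearly anneal the learning rate from initial_lr toward 0 over `steps` steps."""
        lr = initial_lr - initial_lr * (step / float(steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

    # Example: halfway through training, the learning rate is halved.
    critic = torch.nn.Linear(8, 1)
    optimizer = torch.optim.Adam(critic.parameters(), lr=5e-4)
    linear_lr_decay(optimizer, step=5000, steps=10000, initial_lr=5e-4)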

restore(model_dir)[source]

Restore the model.

save(save_dir)[source]

Save the model.

soft_update()[source]

Soft update the target network.
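A soft (Polyak) update blends the online parameters into the target parameters with a small coefficient rather than copying them outright. A minimal sketch, where the coefficient name polyak_tau and the function signature are illustrative assumptions:

    import copy
    import torch

    def soft_update(target_net, source_net, polyak_tau=0.005):
        """Move each target parameter a small step toward the corresponding online parameter."""
        with torch.no_grad():
            for target_param, param in zip(target_net.parameters(), source_net.parameters()):
                target_param.data.copy_(
                    polyak_tau * param.data + (1.0 - polyak_tau) * target_param.data
                )

    # Example: maintain a slowly moving target copy of a critic network.
    critic = torch.nn.Linear(8, 1)
    target_critic = copy.deepcopy(critic)
    soft_update(target_critic, critic)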

train(share_obs, actions, reward, done, term, next_share_obs, next_actions, gamma)[source]

Train the critic.

Parameters:

share_obs: (np.ndarray) shape is (batch_size, dim)
actions: (np.ndarray) shape is (n_agents, batch_size, dim)
reward: (np.ndarray) shape is (batch_size, 1)
done: (np.ndarray) shape is (batch_size, 1)
term: (np.ndarray) shape is (batch_size, 1)
next_share_obs: (np.ndarray) shape is (batch_size, dim)
next_actions: (np.ndarray) shape is (n_agents, batch_size, dim)
gamma: (np.ndarray) shape is (batch_size, 1)
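The update follows the standard one-step TD recipe for an action-value critic: bootstrap from a target network on the next observation-action pair, mask the bootstrap where the episode terminated, and regress the online Q-network toward that target. A minimal sketch under those assumptions (the critic and target_critic modules, the concatenated input layout, and the use of term rather than done for the bootstrap mask are illustrative, not the exact implementation):

    import torch
    import torch.nn.functional as F

    def q_critic_loss(critic, target_critic, share_obs, joint_actions, reward, term,
                      next_share_obs, next_joint_actions, gamma):
        """One-step TD loss for a single Q critic; all tensors are (batch_size, dim)."""
        with torch.no_grad():
            next_q = target_critic(torch.cat([next_share_obs, next_joint_actions], dim=-1))
            # Bootstrap only on transitions that did not terminate.
            q_target = reward + gamma * (1.0 - term) * next_q
        q_pred = critic(torch.cat([share_obs, joint_actions], dim=-1))
        return F.mse_loss(q_pred, q_target)

    # Toy usage: share_obs dim 10, joint-action dim 4, batch of 32.
    critic = torch.nn.Linear(14, 1)
    target_critic = torch.nn.Linear(14, 1)
    loss = q_critic_loss(critic, target_critic,
                         torch.randn(32, 10), torch.randn(32, 4),
                         torch.randn(32, 1), torch.zeros(32, 1),
                         torch.randn(32, 10), torch.randn(32, 4),
                         torch.full((32, 1), 0.99))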

turn_off_grad()[source]

Turn off the gradient for the critic.

turn_on_grad()[source]

Turn on the gradient for the critic.

harl.algorithms.critics.discrete_q_critic module

Discrete Q Critic.

class harl.algorithms.critics.discrete_q_critic.DiscreteQCritic(args, share_obs_space, act_space, num_agents, state_type, device=device(type='cpu'))[source]

Bases: object

Discrete Q Critic. Critic that learns a Q-function. The action space is discrete.

get_joint_idx(actions, agent_id)[source]

Get the available joint action indices (joint_idx) for an agent: all other agents keep their current actions, while this agent can choose freely.

Parameters:

actions: (list) individual actions.
agent_id: (int) agent id.

Returns:

(torch.Tensor) shape is (batch_size, self.action_dims[agent_id])

Return type:

joint_idx

get_values(share_obs, actions)[source]

Get values for given observations and actions.

indiv_to_joint(orig_actions)[source]

Convert individual actions to a joint action.

Parameters:

orig_actions: (list) individual actions.

Returns:

(int) joint action.

Return type:

joint_action

For example, if agents’ action_dims are [4, 3], then: joint action 0 <-> indiv actions [0, 0], joint action 1 <-> indiv actions [0, 1], ..., joint action 5 <-> indiv actions [1, 2], ..., joint action 11 <-> indiv actions [3, 2].

joint_to_indiv(orig_action)[source]

Convert a joint action to individual actions.

Parameters:

orig_action: (int) joint action.

Returns:

(list) individual actions.

Return type:

actions

For example, if agents’ action_dims are [4, 3], then: joint action 0 <-> indiv actions [0, 0], joint action 1 <-> indiv actions [1, 0], ..., joint action 5 <-> indiv actions [1, 1], ..., joint action 11 <-> indiv actions [3, 2].
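Both conversions are mixed-radix encodings over the per-agent action dimensions; the exact ordering (which agent's index varies fastest) depends on the implementation. A stand-alone sketch following the convention of the joint_to_indiv example above, with the two functions written as exact inverses:

    def joint_to_indiv(joint_action, action_dims):
        """Decode a joint action index into per-agent actions (agent 0 varies fastest)."""
        indiv = []
        for dim in action_dims:
            indiv.append(joint_action % dim)
            joint_action //= dim
        return indiv

    def indiv_to_joint(indiv_actions, action_dims):
        """Encode per-agent actions into a single joint action index (inverse of the above)."""
        joint, base = 0, 1
        for action, dim in zip(indiv_actions, action_dims):
            joint += action * base
            base *= dim
        return joint

    # With action_dims = [4, 3]: joint 5 <-> [1, 1], joint 11 <-> [3, 2].
    assert joint_to_indiv(11, [4, 3]) == [3, 2]
    assert indiv_to_joint([1, 1], [4, 3]) == 5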

lr_decay(step, steps)[source]

Decay the actor and critic learning rates.

Parameters:

step: (int) current training step.
steps: (int) total number of training steps.

process_action_spaces(action_spaces)[source]

Process action spaces.

restore(model_dir)[source]

Restore model parameters.

save(save_dir)[source]

Save model parameters.

soft_update()[source]

Soft update the target network.

train(share_obs, actions, reward, done, term, next_share_obs, next_actions, gamma)[source]

Update the critic.

Parameters:

share_obs: (np.ndarray) shape is (batch_size, dim)
actions: (np.ndarray) shape is (n_agents, batch_size, dim)
reward: (np.ndarray) shape is (batch_size, 1)
done: (np.ndarray) shape is (batch_size, 1)
term: (np.ndarray) shape is (batch_size, 1)
next_share_obs: (np.ndarray) shape is (batch_size, dim)
next_actions: (np.ndarray) shape is (n_agents, batch_size, dim)
gamma: (np.ndarray) shape is (batch_size, 1)

train_values(share_obs, actions)[source]

Train the critic.

Parameters:

share_obs: shape is (batch_size, dim)
actions: shape is (n_agents, batch_size, dim)

turn_off_grad()[source]

Turn off gradient for critic.

turn_on_grad()[source]

Turn on gradient for critic.

harl.algorithms.critics.soft_twin_continuous_q_critic module

Soft Twin Continuous Q Critic.

class harl.algorithms.critics.soft_twin_continuous_q_critic.SoftTwinContinuousQCritic(args, share_obs_space, act_space, num_agents, state_type, device=device(type='cpu'))[source]

Bases: TwinContinuousQCritic

Soft Twin Continuous Q Critic. A critic that learns two soft Q-functions, taking observations and actions as input and outputting Q-values. The action space can be either continuous or discrete. The name SoftTwinContinuousQCritic refers to this structure: it is most commonly used with continuous action spaces, but it can also handle discrete ones.

get_values(share_obs, actions)[source]

Get the soft Q values for the given observations and actions.

train(share_obs, actions, reward, done, valid_transition, term, next_share_obs, next_actions, next_logp_actions, gamma, value_normalizer=None)[source]

Train the critic.

Parameters:

share_obs: EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
actions: (n_agents, batch_size, dim)
reward: EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
done: EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
valid_transition: (n_agents, batch_size, 1)
term: EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
next_share_obs: EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
next_actions: (n_agents, batch_size, dim)
next_logp_actions: (n_agents, batch_size, 1)
gamma: EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
value_normalizer: (ValueNorm) normalize the rewards, denormalize critic outputs.
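The soft twin update bootstraps from the minimum of the two target networks and subtracts the entropy term alpha * log-prob of the sampled next joint action; both online critics regress toward the same target. A minimal sketch for the EP case with plain MSE (module names are illustrative; valid_transition masking, value normalization, and the FP layout are omitted):

    import torch
    import torch.nn.functional as F

    def soft_twin_q_loss(critic1, critic2, target1, target2,
                         share_obs, joint_actions, reward, term,
                         next_share_obs, next_joint_actions, next_joint_logp,
                         gamma, alpha):
        """Entropy-regularized twin-Q loss with a shared clipped (min) target."""
        with torch.no_grad():
            next_in = torch.cat([next_share_obs, next_joint_actions], dim=-1)
            next_q = torch.min(target1(next_in), target2(next_in))
            # Soft value: subtract the entropy temperature times the next log-probability.
            q_target = reward + gamma * (1.0 - term) * (next_q - alpha * next_joint_logp)
        cur_in = torch.cat([share_obs, joint_actions], dim=-1)
        return F.mse_loss(critic1(cur_in), q_target) + F.mse_loss(critic2(cur_in), q_target)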

update_alpha(logp_actions, target_entropy)[source]

Auto-tune the temperature parameter alpha.
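Temperature auto-tuning typically follows the SAC objective: minimize -log_alpha * (log-prob + target_entropy), treating the log-probabilities as constants. A minimal stand-alone sketch (the optimizer, learning rate, and target_entropy value are illustrative):

    import torch

    # Learnable log-temperature; target_entropy is often set to -|A| for continuous actions.
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)
    target_entropy = -4.0

    def update_alpha(logp_actions):
        """One gradient step on the temperature objective; returns the new alpha."""
        alpha_loss = -(log_alpha * (logp_actions + target_entropy).detach()).mean()
        alpha_optimizer.zero_grad()
        alpha_loss.backward()
        alpha_optimizer.step()
        return log_alpha.exp().item()

    alpha = update_alpha(torch.randn(32, 1))  # toy batch of log-probabilities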

harl.algorithms.critics.twin_continuous_q_critic module

Twin Continuous Q Critic.

class harl.algorithms.critics.twin_continuous_q_critic.TwinContinuousQCritic(args, share_obs_space, act_space, num_agents, state_type, device=device(type='cpu'))[source]

Bases: object

Twin Continuous Q Critic. A critic that learns two Q-functions, taking observations and actions as input and outputting Q-values. The name TwinContinuousQCritic refers to this structure: it is most commonly used with continuous action spaces, though the same architecture can in principle handle discrete actions. For now, only continuous action spaces are supported; discrete action-space support is planned.

get_values(share_obs, actions)[source]

Get the Q values for the given observations and actions.

lr_decay(step, steps)[source]

Decay the actor and critic learning rates.

Parameters:

step: (int) current training step.
steps: (int) total number of training steps.

restore(model_dir)[source]

Restore the model parameters.

save(save_dir)[source]

Save the model parameters.

soft_update()[source]

Soft update the target networks.

train(share_obs, actions, reward, done, term, next_share_obs, next_actions, gamma)[source]

Train the critic.

Parameters:

share_obs: (np.ndarray) shape is (batch_size, dim)
actions: (np.ndarray) shape is (n_agents, batch_size, dim)
reward: (np.ndarray) shape is (batch_size, 1)
done: (np.ndarray) shape is (batch_size, 1)
term: (np.ndarray) shape is (batch_size, 1)
next_share_obs: (np.ndarray) shape is (batch_size, dim)
next_actions: (np.ndarray) shape is (n_agents, batch_size, dim)
gamma: (np.ndarray) shape is (batch_size, 1)
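Relative to ContinuousQCritic, the only change to the target is the pessimistic minimum over the two target networks (clipped double Q); both online critics are then regressed toward this shared target. A sketch of just the target computation, with the same illustrative naming as the earlier sketches:

    import torch

    def twin_td_target(target_critic1, target_critic2, reward, term,
                       next_share_obs, next_joint_actions, gamma):
        """Clipped double-Q target: bootstrap from the smaller of the two target critics."""
        with torch.no_grad():
            next_in = torch.cat([next_share_obs, next_joint_actions], dim=-1)
            next_q = torch.min(target_critic1(next_in), target_critic2(next_in))
            return reward + gamma * (1.0 - term) * next_q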

turn_off_grad()[source]

Turn off the gradient for the critic network.

turn_on_grad()[source]

Turn on the gradient for the critic network.

harl.algorithms.critics.v_critic module

V Critic.

class harl.algorithms.critics.v_critic.VCritic(args, cent_obs_space, device=device(type='cpu'))[source]

Bases: object

V Critic. Critic that learns a V-function.

cal_value_loss(values, value_preds_batch, return_batch, value_normalizer=None)[source]

Calculate the value function loss.

Parameters:

values: (torch.Tensor) value function predictions.
value_preds_batch: (torch.Tensor) “old” value predictions from the data batch (used for the value clip loss).
return_batch: (torch.Tensor) reward-to-go returns.
value_normalizer: (ValueNorm) normalize the rewards, denormalize critic outputs.

Returns:

(torch.Tensor) value function loss.

Return type:

value_loss
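The value loss typically follows the PPO-style clipped objective: the new prediction is clipped to stay within a trust region of the old prediction, both errors are measured against the returns, and the larger of the two is penalized. A minimal sketch with plain MSE and no value normalization (clip_param is an illustrative hyperparameter; the actual method may additionally support Huber loss and ValueNorm denormalization):

    import torch

    def clipped_value_loss(values, value_preds_batch, return_batch, clip_param=0.2):
        """PPO-style clipped value loss (plain MSE variant, no ValueNorm)."""
        # Clip the new prediction so it cannot move too far from the old one.
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
            -clip_param, clip_param
        )
        error_clipped = return_batch - value_pred_clipped
        error_original = return_batch - values
        # Penalize the element-wise larger of the two squared errors, then average.
        return torch.max(error_clipped ** 2, error_original ** 2).mean()

    # Toy usage with a batch of 32 value predictions.
    loss = clipped_value_loss(torch.randn(32, 1), torch.randn(32, 1), torch.randn(32, 1))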

get_values(cent_obs, rnn_states_critic, masks)[source]

Get value function predictions.

Parameters:

cent_obs: (np.ndarray) centralized input to the critic.
rnn_states_critic: (np.ndarray) if the critic is an RNN, RNN states for the critic.
masks: (np.ndarray) denotes points at which RNN states should be reset.

Returns:

values: (torch.Tensor) value function predictions.
rnn_states_critic: (torch.Tensor) updated critic network RNN states.

lr_decay(episode, episodes)[source]

Decay the actor and critic learning rates.

Parameters:

episode: (int) current training episode.
episodes: (int) total number of training episodes.

prep_rollout()[source]

Prepare for rollout.

prep_training()[source]

Prepare for training.

train(critic_buffer, value_normalizer=None)[source]

Perform a training update using minibatch gradient descent.

Parameters:

critic_buffer: (OnPolicyCriticBufferEP or OnPolicyCriticBufferFP) buffer containing training data related to the critic.
value_normalizer: (ValueNorm) normalize the rewards, denormalize critic outputs.

Returns:

(dict) contains information regarding training update (e.g. loss, grad norms, etc).

Return type:

train_info

update(sample, value_normalizer=None)[source]

Update the critic network.

Parameters:

sample: (Tuple) contains the data batch with which to update the networks.
value_normalizer: (ValueNorm) normalize the rewards, denormalize critic outputs.

Returns:

value_loss: (torch.Tensor) value function loss.
critic_grad_norm: (torch.Tensor) gradient norm from the critic update.

Module contents

Critic registry.