harl.models.value_function_models package

Submodules

harl.models.value_function_models.continuous_q_net module

class harl.models.value_function_models.continuous_q_net.ContinuousQNet(args, cent_obs_space, act_spaces, device=device(type='cpu'))

Bases: Module

Q network for continuous and discrete action spaces. It outputs the Q value given global states and actions. The name ContinuousQNet refers to its structure: it takes observations and actions as input and outputs Q values. This structure is most commonly used for continuous action spaces, but the network can also be used with discrete action spaces.
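A minimal NumPy sketch of how a network with this input structure can accept both kinds of action: the state features and each agent's action are concatenated into one input vector, with discrete actions one-hot encoded. The one-hot convention and the helper names one_hot and build_q_input are assumptions for illustration, not part of the HARL API.

```python
import numpy as np


def one_hot(index, n):
    """Encode a discrete action index as a length-n one-hot vector (assumed encoding)."""
    vec = np.zeros(n, dtype=np.float32)
    vec[index] = 1.0
    return vec


def build_q_input(cent_obs_feature, agent_actions):
    """Concatenate state features with every agent's action (hypothetical helper).

    agent_actions: list whose entries are either float vectors (continuous
    actions) or (index, n) tuples (discrete actions, one-hot encoded to length n).
    """
    parts = [np.asarray(cent_obs_feature, dtype=np.float32)]
    for act in agent_actions:
        if isinstance(act, tuple):
            # Discrete action: expand the index into a one-hot vector.
            index, n = act
            parts.append(one_hot(index, n))
        else:
            # Continuous action: use the action vector as-is.
            parts.append(np.asarray(act, dtype=np.float32))
    return np.concatenate(parts)


# Two state features, one 2-D continuous action, one discrete action (index 2 of 4).
x = build_q_input([0.5, -0.5], [[0.1, 0.2], (2, 4)])
print(x.shape)  # (8,)
```

Because both action types end up as fixed-length float vectors, the same Q network can score either one.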

forward(cent_obs, actions)

Compute the Q value for the given central observations and actions.

Note

Although the forward pass is defined in this method, call the ContinuousQNet instance itself rather than forward(): calling the instance runs any registered hooks, while calling forward() directly silently skips them.

harl.models.value_function_models.continuous_q_net.get_combined_dim(cent_obs_feature_dim, act_spaces)

Get the combined dimension of central observation and individual actions.
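As an illustration, the sketch below computes such a combined dimension under the assumption that a continuous (Box-like) action contributes its vector length and a discrete action contributes its number of choices n (its one-hot width). Box, Discrete, and combined_dim_sketch are hypothetical stand-ins defined here so the example runs without gym; they are not the library's own objects.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple, Union


@dataclass
class Box:
    """Hypothetical stand-in for a continuous action space (like gym's Box)."""
    shape: Tuple[int, ...]


@dataclass
class Discrete:
    """Hypothetical stand-in for a discrete action space (like gym's Discrete)."""
    n: int


def combined_dim_sketch(cent_obs_feature_dim: int,
                        act_spaces: Sequence[Union[Box, Discrete]]) -> int:
    """Sum the observation feature size and every agent's action size.

    Assumption: a continuous action contributes its vector length, and a
    discrete action contributes n because it is fed in one-hot encoded.
    """
    total = cent_obs_feature_dim
    for space in act_spaces:
        if isinstance(space, Box):
            total += space.shape[0]
        elif isinstance(space, Discrete):
            total += space.n
        else:
            raise TypeError(f"unsupported action space: {space!r}")
    return total


# Two agents: one with a 2-D continuous action, one with 5 discrete actions.
print(combined_dim_sketch(64, [Box(shape=(2,)), Discrete(n=5)]))  # 71
```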

harl.models.value_function_models.dueling_q_net module

class harl.models.value_function_models.dueling_q_net.DuelingQNet(args, obs_space, output_dim, device=device(type='cpu'))

Bases: Module

Dueling Q network for discrete action spaces.
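A dueling network typically splits its head into a state value V(s) and per-action advantages A(s, a), then recombines them as Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a'). The NumPy sketch below shows only that aggregation step; whether DuelingQNet uses exactly this mean-subtracted form is an assumption, and dueling_aggregate is a hypothetical helper.

```python
import numpy as np


def dueling_aggregate(value, advantages):
    """Combine a state value and per-action advantages into Q values.

    Assumes the common mean-subtracted dueling form:
        Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')
    value:      array of shape (batch, 1)
    advantages: array of shape (batch, num_actions)
    """
    value = np.asarray(value, dtype=np.float64)
    advantages = np.asarray(advantages, dtype=np.float64)
    # Subtracting the mean advantage makes V and A identifiable: otherwise a
    # constant added to A would be indistinguishable from one added to V.
    return value + advantages - advantages.mean(axis=-1, keepdims=True)


# One state, three discrete actions.
q = dueling_aggregate([[2.0]], [[1.0, 0.0, -1.0]])
print(q)  # [[3. 2. 1.]]
```

With this form, the mean of the Q values over actions equals V(s), which keeps the two streams from drifting against each other.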

forward(obs)

Compute the Q values for the given observations.

Note

Although the forward pass is defined in this method, call the DuelingQNet instance itself rather than forward(): calling the instance runs any registered hooks, while calling forward() directly silently skips them.

harl.models.value_function_models.v_net module

class harl.models.value_function_models.v_net.VNet(args, cent_obs_space, device=device(type='cpu'))

Bases: Module

V Network. Outputs value function predictions given global states.

forward(cent_obs, rnn_states, masks)

Compute value function predictions from the given inputs.

Parameters:

cent_obs (np.ndarray / torch.Tensor): centralized observation inputs to the network.
rnn_states (np.ndarray / torch.Tensor): hidden states for the RNN, if the network is recurrent.
masks (np.ndarray / torch.Tensor): mask tensor indicating whether the RNN states should be reset to zeros.

Returns:

values (torch.Tensor): value function predictions.
rnn_states (torch.Tensor): updated RNN hidden states.

Return type:

tuple(torch.Tensor, torch.Tensor)
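A masks argument of this kind is commonly applied by multiplying it into the hidden state, so that a 0 entry (an episode boundary) resets that sample's RNN state to zeros before the recurrent step. The NumPy sketch below assumes masks of shape (batch, 1) with 1 meaning keep and 0 meaning reset, and hidden states of shape (batch, recurrent_n, hidden_size); reset_rnn_states is a hypothetical helper, not part of VNet.

```python
import numpy as np


def reset_rnn_states(rnn_states, masks):
    """Zero the hidden states of samples whose episode just ended.

    Assumed shapes and convention:
        rnn_states: (batch, recurrent_n, hidden_size)
        masks:      (batch, 1), 1.0 = keep state, 0.0 = reset to zeros
    """
    rnn_states = np.asarray(rnn_states, dtype=np.float32)
    masks = np.asarray(masks, dtype=np.float32)
    # Broadcast each sample's mask over the layer and hidden dimensions.
    return rnn_states * masks.reshape(-1, 1, 1)


h = np.ones((2, 1, 3), dtype=np.float32)
masks = np.array([[1.0], [0.0]], dtype=np.float32)
out = reset_rnn_states(h, masks)
print(out[:, 0, :])  # first sample kept, second sample zeroed
```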

Module contents