harl.models.value_function_models package
Submodules
harl.models.value_function_models.continuous_q_net module
- class harl.models.value_function_models.continuous_q_net.ContinuousQNet(args, cent_obs_space, act_spaces, device=device(type='cpu'))
Bases: Module
Q network for continuous and discrete action spaces. Outputs the Q-value given global states and actions. The name ContinuousQNet reflects its structure: it takes observations and actions as input and outputs Q-values. This structure is most commonly used for continuous action spaces, but it also works for discrete ones.
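The input/output structure described above can be sketched as follows. This is a minimal illustration, not HARL's implementation: TinyQNet, its layer sizes, the concatenation of observations and actions, and the one-hot encoding of discrete actions are all assumptions made for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyQNet(nn.Module):
    """Hypothetical stand-in for a critic that maps (state, action) to a Q-value."""

    def __init__(self, obs_dim, act_dim, hidden_dim=64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, cent_obs, actions):
        # Assumption: state and action features are concatenated on the last axis.
        return self.mlp(torch.cat([cent_obs, actions], dim=-1))


q_net = TinyQNet(obs_dim=8, act_dim=3)
cent_obs = torch.randn(5, 8)

# Continuous case: pass the action vectors directly.
q_cont = q_net(cent_obs, torch.randn(5, 3))

# Discrete case (assumed encoding): one-hot encode the action indices first.
action_idx = torch.tensor([0, 2, 1, 1, 0])
q_disc = q_net(cent_obs, F.one_hot(action_idx, num_classes=3).float())
```

Both calls return one Q-value per batch row, which is what lets the same structure serve either kind of action space.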
- forward(cent_obs, actions)
Computes the Q-values for the given global states (cent_obs) and actions.
Note
Although the forward pass is defined in this method, call the Module instance itself rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.
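The difference between calling the instance and calling forward() can be seen with a forward hook. The sketch below uses a plain nn.Linear layer rather than any HARL class; the hook function and the calls list are names invented for the illustration.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []


def record_output_shape(module, inputs, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    calls.append(tuple(output.shape))


handle = layer.register_forward_hook(record_output_shape)
x = torch.zeros(3, 4)

layer.forward(x)              # direct call: the hook is skipped
n_after_forward = len(calls)

layer(x)                      # instance call: the hook runs
n_after_call = len(calls)

handle.remove()               # detach the hook when it is no longer needed
```

After the direct forward() call the list is still empty; only the instance call records an entry.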
harl.models.value_function_models.dueling_q_net module
- class harl.models.value_function_models.dueling_q_net.DuelingQNet(args, obs_space, output_dim, device=device(type='cpu'))
Bases: Module
Dueling Q Network for discrete action space.
- forward(obs)
Computes the Q-values for the given observations (obs).
Note
Although the forward pass is defined in this method, call the Module instance itself rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.
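For readers unfamiliar with the dueling structure, here is a minimal sketch of the general idea: a shared feature layer feeds a state-value head and an advantage head, which are then combined into per-action Q-values. TinyDuelingQNet, its layer sizes, and the mean-subtracted aggregation are assumptions for the example; HARL's DuelingQNet may be built differently.

```python
import torch
import torch.nn as nn


class TinyDuelingQNet(nn.Module):
    """Hypothetical dueling network; not HARL's DuelingQNet."""

    def __init__(self, obs_dim, output_dim, hidden_dim=64):
        super().__init__()
        self.feature = nn.Sequential(nn.Linear(obs_dim, hidden_dim), nn.ReLU())
        self.value_head = nn.Linear(hidden_dim, 1)               # V(s)
        self.advantage_head = nn.Linear(hidden_dim, output_dim)  # A(s, a)

    def forward(self, obs):
        h = self.feature(obs)
        value = self.value_head(h)
        advantage = self.advantage_head(h)
        # Assumed aggregation: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).
        return value + advantage - advantage.mean(dim=-1, keepdim=True)


net = TinyDuelingQNet(obs_dim=6, output_dim=4)
obs = torch.randn(2, 6)
q_values = net(obs)                       # one Q-value per discrete action
greedy_actions = q_values.argmax(dim=-1)  # index of the best action per row
```

Subtracting the mean advantage makes the split between value and advantage identifiable: the Q-values for a state average to its value estimate.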
harl.models.value_function_models.v_net module
- class harl.models.value_function_models.v_net.VNet(args, cent_obs_space, device=device(type='cpu'))
Bases: Module
V Network. Outputs value function predictions given global states.
- forward(cent_obs, rnn_states, masks)
Compute value function predictions from the given inputs.
- Parameters:
cent_obs – (np.ndarray / torch.Tensor) observation inputs into network.
rnn_states – (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
masks – (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.
- Returns:
values: (torch.Tensor) value function predictions.
rnn_states: (torch.Tensor) updated RNN hidden states.
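The role of rnn_states and masks in a recurrent value network can be sketched as follows. TinyVNet is a hypothetical stand-in, not HARL's VNet: the single GRUCell, the tensor shapes, and the convention that a mask of 0 resets the hidden state (and 1 keeps it) are assumptions made for the example.

```python
import torch
import torch.nn as nn


class TinyVNet(nn.Module):
    """Hypothetical recurrent value network; not HARL's VNet."""

    def __init__(self, obs_dim, hidden_dim=32):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, hidden_dim)
        self.gru = nn.GRUCell(hidden_dim, hidden_dim)
        self.v_out = nn.Linear(hidden_dim, 1)

    def forward(self, cent_obs, rnn_states, masks):
        x = torch.relu(self.encoder(cent_obs))
        # Assumed convention: mask 0 zeroes the hidden state, mask 1 keeps it.
        rnn_states = self.gru(x, rnn_states * masks)
        values = self.v_out(rnn_states)
        return values, rnn_states


v_net = TinyVNet(obs_dim=8)
cent_obs = torch.randn(3, 8)
rnn_states = torch.ones(3, 32)
masks = torch.tensor([[1.0], [0.0], [1.0]])  # reset only the second row

values, new_states = v_net(cent_obs, rnn_states, masks)
```

In this sketch the second row behaves as if its episode had just started: its updated hidden state matches what a zero initial state would produce.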