"""Source code for harl.models.value_function_models.dueling_q_net."""

import torch
import torch.nn as nn
from harl.models.base.plain_cnn import PlainCNN
from harl.models.base.plain_mlp import PlainMLP
from harl.utils.envs_tools import get_shape_from_obs_space


class DuelingQNet(nn.Module):
    """Dueling Q Network for discrete action space.

    The network is a shared trunk (an optional CNN feature extractor followed
    by an MLP) feeding two heads: a value head ``dueling_v`` and an advantage
    head ``dueling_a``. ``forward`` combines them as
    ``a - mean(a, dim=-1) + v``.
    """

    def __init__(self, args, obs_space, output_dim, device=torch.device("cpu")):
        """Build the trunk and the two dueling heads, then move them to ``device``.

        Args:
            args: Mapping that must contain the keys ``"base_hidden_sizes"``,
                ``"base_activation_func"``, ``"dueling_v_hidden_sizes"``,
                ``"dueling_v_activation_func"``, ``"dueling_a_hidden_sizes"``
                and ``"dueling_a_activation_func"``. ``"base_hidden_sizes"``
                is indexed with ``[0]`` / ``[-1]`` and therefore must be
                non-empty.
            obs_space: Observation space; its shape is obtained through
                ``get_shape_from_obs_space``.
            output_dim: Last entry of the advantage head's size list
                (presumably the number of discrete actions -- confirm with
                the caller).
            device: Device the module is moved to at the end of construction.
        """
        super().__init__()
        # dtype/device kwargs kept on the instance; not used inside this class.
        self.tpdv = dict(dtype=torch.float32, device=device)

        base_hidden_sizes = args["base_hidden_sizes"]
        base_activation_func = args["base_activation_func"]
        dueling_v_hidden_sizes = args["dueling_v_hidden_sizes"]
        dueling_v_activation_func = args["dueling_v_activation_func"]
        dueling_a_hidden_sizes = args["dueling_a_hidden_sizes"]
        dueling_a_activation_func = args["dueling_a_activation_func"]

        obs_shape = get_shape_from_obs_space(obs_space)

        # feature extractor: only built for rank-3 observation shapes
        if len(obs_shape) == 3:
            self.feature_extractor = PlainCNN(
                obs_shape, base_hidden_sizes[0], base_activation_func
            )
            feature_dim = base_hidden_sizes[0]
        else:
            self.feature_extractor = None
            feature_dim = obs_shape[0]

        # base
        base_sizes = [feature_dim] + list(base_hidden_sizes)
        # NOTE(review): the third PlainMLP argument is presumably the
        # activation applied after the last layer -- confirm against PlainMLP.
        self.base = PlainMLP(base_sizes, base_activation_func, base_activation_func)

        # Both heads take the last trunk size as the first entry of their size list.
        base_out_dim = base_hidden_sizes[-1]

        # dueling v: size list ends in 1
        dueling_v_sizes = [base_out_dim] + list(dueling_v_hidden_sizes) + [1]
        self.dueling_v = PlainMLP(dueling_v_sizes, dueling_v_activation_func)

        # dueling a: size list ends in output_dim
        dueling_a_sizes = [base_out_dim] + list(dueling_a_hidden_sizes) + [output_dim]
        self.dueling_a = PlainMLP(dueling_a_sizes, dueling_a_activation_func)

        self.to(device)

    def forward(self, obs):
        """Compute Q-values from observations.

        Args:
            obs: Input passed to the feature extractor when one exists,
                otherwise straight to the base MLP. No dtype or device
                conversion is performed here.

        Returns:
            ``a - a.mean(dim=-1, keepdim=True) + v`` where ``v`` and ``a`` are
            the outputs of the value and advantage heads.
        """
        if self.feature_extractor is not None:
            x = self.feature_extractor(obs)
        else:
            x = obs
        x = self.base(x)
        v = self.dueling_v(x)
        a = self.dueling_a(x)
        # Centre the advantages over the last dimension before adding v.
        return a - a.mean(dim=-1, keepdim=True) + v