def get_combined_dim(cent_obs_feature_dim, act_spaces):
    """Compute the width of the central-observation features plus all actions.

    Starting from ``cent_obs_feature_dim``, each entry of ``act_spaces`` adds
    its own size, selected by the entry's class name:

    * ``"Box"``      -> ``space.shape[0]``
    * ``"Discrete"`` -> ``space.n``
    * anything else  -> every element of ``space.nvec``

    Args:
        cent_obs_feature_dim: Size of the central observation feature vector.
        act_spaces: Iterable of action-space objects, one per agent.

    Returns:
        The summed dimension of the observation features and all actions.
    """
    total = cent_obs_feature_dim
    for act_space in act_spaces:
        space_kind = act_space.__class__.__name__
        if space_kind == "Box":
            total += act_space.shape[0]
            continue
        if space_kind == "Discrete":
            total += act_space.n
            continue
        # Fallback branch: the space is expected to expose ``nvec``; using
        # ``total`` as the start value keeps the additions in element order.
        total = sum(act_space.nvec, total)
    return total
class ContinuousQNet(nn.Module):
    """Q network that scores (global state, joint action) pairs.

    The name ContinuousQNet refers to the architecture: observations and
    actions are both fed in as input and a q value comes out. That layout is
    the usual choice for continuous action spaces, but it can also be used
    with discrete action spaces.
    """

    def __init__(self, args, cent_obs_space, act_spaces, device=torch.device("cpu")):
        """Build the optional feature extractor and the q-value MLP.

        Args:
            args: Mapping providing ``"activation_func"`` and
                ``"hidden_sizes"``.
            cent_obs_space: Central observation space; its shape is obtained
                through ``get_shape_from_obs_space``.
            act_spaces: Action spaces forwarded to ``get_combined_dim``.
            device: Device the module is moved to after construction.
        """
        super().__init__()
        activation_func = args["activation_func"]
        hidden_sizes = args["hidden_sizes"]
        obs_shape = get_shape_from_obs_space(cent_obs_space)

        if len(obs_shape) != 3:
            # Non 3-dimensional observations skip the extractor; the first
            # shape entry is used directly as the feature width.
            self.feature_extractor = None
            feature_dim = obs_shape[0]
        else:
            # 3-dimensional observations go through a CNN whose output width
            # is the first hidden size.
            feature_dim = hidden_sizes[0]
            self.feature_extractor = PlainCNN(obs_shape, feature_dim, activation_func)

        # Layer widths: combined observation/action input, the hidden layers,
        # then a single output unit for the q value.
        input_dim = get_combined_dim(feature_dim, act_spaces)
        layer_sizes = [input_dim, *hidden_sizes, 1]
        self.mlp = PlainMLP(layer_sizes, activation_func)
        self.to(device)