Source code for harl.common.buffers.on_policy_critic_buffer_ep
"""On-policy buffer for critic that uses Environment-Provided (EP) state."""importtorchimportnumpyasnpfromharl.utils.envs_toolsimportget_shape_from_obs_spacefromharl.utils.trans_toolsimport_flatten,_sa_cast
class OnPolicyCriticBufferEP:
    """On-policy buffer for critic that uses Environment-Provided (EP) state."""

    def __init__(self, args, share_obs_space):
        """Initialize on-policy critic buffer.

        Args:
            args: (dict) arguments
            share_obs_space: (gym.Space or list) share observation space
        """
        self.episode_length = args["episode_length"]
        self.n_rollout_threads = args["n_rollout_threads"]
        self.hidden_sizes = args["hidden_sizes"]
        self.rnn_hidden_size = self.hidden_sizes[-1]
        self.recurrent_n = args["recurrent_n"]
        self.gamma = args["gamma"]
        self.gae_lambda = args["gae_lambda"]
        self.use_gae = args["use_gae"]
        self.use_proper_time_limits = args["use_proper_time_limits"]

        share_obs_shape = get_shape_from_obs_space(share_obs_space)
        if isinstance(share_obs_shape[-1], list):
            share_obs_shape = share_obs_shape[:1]

        # Buffer for share observations
        self.share_obs = np.zeros(
            (self.episode_length + 1, self.n_rollout_threads, *share_obs_shape),
            dtype=np.float32,
        )

        # Buffer for RNN states of critic
        self.rnn_states_critic = np.zeros(
            (
                self.episode_length + 1,
                self.n_rollout_threads,
                self.recurrent_n,
                self.rnn_hidden_size,
            ),
            dtype=np.float32,
        )

        # Buffer for value predictions made by this critic
        self.value_preds = np.zeros(
            (self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32
        )

        # Buffer for returns calculated at each timestep
        self.returns = np.zeros(
            (self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32
        )

        # Buffer for rewards received by agents at each timestep
        self.rewards = np.zeros(
            (self.episode_length, self.n_rollout_threads, 1), dtype=np.float32
        )

        # Buffer for masks indicating whether an episode is done at each timestep
        self.masks = np.ones(
            (self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32
        )

        # Buffer for bad masks distinguishing truncation from termination.
        # If 0, the episode was truncated; if 1 while masks is 0, the episode terminated; otherwise it is not done yet.
        self.bad_masks = np.ones_like(self.masks)

        self.step = 0
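As a usage sketch, the buffer could be constructed as follows. The hyperparameter values and the Box state space are illustrative assumptions, not defaults from the repository, and get_shape_from_obs_space is assumed to map a Box space to its shape tuple.

import numpy as np
from gym import spaces
from harl.common.buffers.on_policy_critic_buffer_ep import OnPolicyCriticBufferEP

# Illustrative arguments only; real configs come from HARL's yaml files.
example_args = {
    "episode_length": 200,
    "n_rollout_threads": 8,
    "hidden_sizes": [128, 128],   # rnn_hidden_size becomes 128
    "recurrent_n": 1,
    "gamma": 0.99,
    "gae_lambda": 0.95,
    "use_gae": True,
    "use_proper_time_limits": False,
}
share_obs_space = spaces.Box(low=-np.inf, high=np.inf, shape=(54,), dtype=np.float32)
critic_buffer = OnPolicyCriticBufferEP(example_args, share_obs_space)

print(critic_buffer.share_obs.shape)          # (201, 8, 54)
print(critic_buffer.rnn_states_critic.shape)  # (201, 8, 1, 128)
print(critic_buffer.rewards.shape)            # (200, 8, 1)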
    def insert(
        self, share_obs, rnn_states_critic, value_preds, rewards, masks, bad_masks
    ):
        """Insert data into buffer."""
        self.share_obs[self.step + 1] = share_obs.copy()
        self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy()
        self.value_preds[self.step] = value_preds.copy()
        self.rewards[self.step] = rewards.copy()
        self.masks[self.step + 1] = masks.copy()
        self.bad_masks[self.step + 1] = bad_masks.copy()
        self.step = (self.step + 1) % self.episode_length
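Continuing the instantiation sketch above, one environment step could be pushed into the buffer as shown below; the random arrays merely stand in for real rollout outputs, and their shapes follow from the constructor.

# One illustrative environment step; data is random, shapes follow the sketch above.
step_share_obs = np.random.randn(8, 54).astype(np.float32)   # (n_rollout_threads, *share_obs_shape)
step_rnn_states = np.zeros((8, 1, 128), dtype=np.float32)    # (n_rollout_threads, recurrent_n, rnn_hidden_size)
step_value_preds = np.zeros((8, 1), dtype=np.float32)
step_rewards = np.ones((8, 1), dtype=np.float32)
step_masks = np.ones((8, 1), dtype=np.float32)                # 0 where the episode is done
step_bad_masks = np.ones((8, 1), dtype=np.float32)            # 0 where the episode was truncated

critic_buffer.insert(
    step_share_obs, step_rnn_states, step_value_preds,
    step_rewards, step_masks, step_bad_masks,
)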
    def after_update(self):
        """After an update, copy the data at the last step to the first position of the buffer."""
        self.share_obs[0] = self.share_obs[-1].copy()
        self.rnn_states_critic[0] = self.rnn_states_critic[-1].copy()
        self.masks[0] = self.masks[-1].copy()
        self.bad_masks[0] = self.bad_masks[-1].copy()
    def get_mean_rewards(self):
        """Get mean rewards for logging."""
        return np.mean(self.rewards)
    def compute_returns(self, next_value, value_normalizer=None):
        """Compute returns either as discounted sum of rewards, or using GAE.

        Args:
            next_value: (np.ndarray) value predictions for the step after the last episode step.
            value_normalizer: (ValueNorm) If not None, ValueNorm value normalizer instance.
        """
        if self.use_proper_time_limits:  # consider the difference between truncation and termination
            if self.use_gae:  # use GAE
                self.value_preds[-1] = next_value
                gae = 0
                for step in reversed(range(self.rewards.shape[0])):
                    if value_normalizer is not None:  # use ValueNorm
                        delta = (
                            self.rewards[step]
                            + self.gamma
                            * value_normalizer.denormalize(self.value_preds[step + 1])
                            * self.masks[step + 1]
                            - value_normalizer.denormalize(self.value_preds[step])
                        )
                        gae = (
                            delta
                            + self.gamma * self.gae_lambda * self.masks[step + 1] * gae
                        )
                        gae = self.bad_masks[step + 1] * gae
                        self.returns[step] = gae + value_normalizer.denormalize(
                            self.value_preds[step]
                        )
                    else:  # do not use ValueNorm
                        delta = (
                            self.rewards[step]
                            + self.gamma
                            * self.value_preds[step + 1]
                            * self.masks[step + 1]
                            - self.value_preds[step]
                        )
                        gae = (
                            delta
                            + self.gamma * self.gae_lambda * self.masks[step + 1] * gae
                        )
                        gae = self.bad_masks[step + 1] * gae
                        self.returns[step] = gae + self.value_preds[step]
            else:  # do not use GAE
                self.returns[-1] = next_value
                for step in reversed(range(self.rewards.shape[0])):
                    if value_normalizer is not None:  # use ValueNorm
                        self.returns[step] = (
                            self.returns[step + 1] * self.gamma * self.masks[step + 1]
                            + self.rewards[step]
                        ) * self.bad_masks[step + 1] + (
                            1 - self.bad_masks[step + 1]
                        ) * value_normalizer.denormalize(
                            self.value_preds[step]
                        )
                    else:  # do not use ValueNorm
                        self.returns[step] = (
                            self.returns[step + 1] * self.gamma * self.masks[step + 1]
                            + self.rewards[step]
                        ) * self.bad_masks[step + 1] + (
                            1 - self.bad_masks[step + 1]
                        ) * self.value_preds[step]
        else:  # do not consider the difference between truncation and termination, i.e. all done episodes are terminated
            if self.use_gae:  # use GAE
                self.value_preds[-1] = next_value
                gae = 0
                for step in reversed(range(self.rewards.shape[0])):
                    if value_normalizer is not None:  # use ValueNorm
                        delta = (
                            self.rewards[step]
                            + self.gamma
                            * value_normalizer.denormalize(self.value_preds[step + 1])
                            * self.masks[step + 1]
                            - value_normalizer.denormalize(self.value_preds[step])
                        )
                        gae = (
                            delta
                            + self.gamma * self.gae_lambda * self.masks[step + 1] * gae
                        )
                        self.returns[step] = gae + value_normalizer.denormalize(
                            self.value_preds[step]
                        )
                    else:  # do not use ValueNorm
                        delta = (
                            self.rewards[step]
                            + self.gamma
                            * self.value_preds[step + 1]
                            * self.masks[step + 1]
                            - self.value_preds[step]
                        )
                        gae = (
                            delta
                            + self.gamma * self.gae_lambda * self.masks[step + 1] * gae
                        )
                        self.returns[step] = gae + self.value_preds[step]
            else:  # do not use GAE
                self.returns[-1] = next_value
                for step in reversed(range(self.rewards.shape[0])):
                    self.returns[step] = (
                        self.returns[step + 1] * self.gamma * self.masks[step + 1]
                        + self.rewards[step]
                    )
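For reference, the GAE branches above implement the recursion delta_t = r_t + gamma * V_{t+1} * mask_{t+1} - V_t and A_t = delta_t + gamma * lambda * mask_{t+1} * A_{t+1}, with the critic target returns[t] = A_t + V_t; when use_proper_time_limits is enabled, bad_masks additionally zeroes the accumulated advantage across truncation boundaries. Below is a minimal, self-contained sketch of the same recursion on toy numbers (all values are made up for illustration; no ValueNorm, no truncation handling).

import numpy as np

gamma, gae_lambda = 0.99, 0.95
rewards = np.array([1.0, 0.0, 1.0])           # r_0 .. r_{T-1}
value_preds = np.array([0.5, 0.4, 0.6, 0.2])  # V(s_0) .. V(s_T); the last entry plays the role of next_value
masks = np.array([1.0, 1.0, 1.0, 1.0])        # masks[t + 1] is 0 if the episode ended at step t

returns = np.zeros(3)
gae = 0.0
for step in reversed(range(3)):
    delta = rewards[step] + gamma * value_preds[step + 1] * masks[step + 1] - value_preds[step]
    gae = delta + gamma * gae_lambda * masks[step + 1] * gae
    returns[step] = gae + value_preds[step]
print(returns)  # GAE-based return targets for the critic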
    def feed_forward_generator_critic(
        self, critic_num_mini_batch=None, mini_batch_size=None
    ):
        """Training data generator for critic that uses MLP network.

        Args:
            critic_num_mini_batch: (int) Number of mini batches for critic.
            mini_batch_size: (int) Size of mini batch for critic.
        """
        # get episode_length, n_rollout_threads, mini_batch_size
        episode_length, n_rollout_threads = self.rewards.shape[0:2]
        batch_size = n_rollout_threads * episode_length

        if mini_batch_size is None:
            assert batch_size >= critic_num_mini_batch, (
                f"The number of processes ({n_rollout_threads}) "
                f"* number of steps ({episode_length}) = {n_rollout_threads * episode_length} "
                f"is required to be greater than or equal to the number of critic mini batches ({critic_num_mini_batch})."
            )
            mini_batch_size = batch_size // critic_num_mini_batch

        # shuffle indices
        rand = torch.randperm(batch_size).numpy()
        sampler = [
            rand[i * mini_batch_size : (i + 1) * mini_batch_size]
            for i in range(critic_num_mini_batch)
        ]

        # Combine the first two dimensions (episode_length and n_rollout_threads) to form batch.
        # Take share_obs shape as an example:
        # (episode_length + 1, n_rollout_threads, *share_obs_shape) --> (episode_length, n_rollout_threads, *share_obs_shape)
        # --> (episode_length * n_rollout_threads, *share_obs_shape)
        share_obs = self.share_obs[:-1].reshape(-1, *self.share_obs.shape[2:])
        rnn_states_critic = self.rnn_states_critic[:-1].reshape(
            -1, *self.rnn_states_critic.shape[2:]
        )  # actually not used, just for consistency
        value_preds = self.value_preds[:-1].reshape(-1, 1)
        returns = self.returns[:-1].reshape(-1, 1)
        masks = self.masks[:-1].reshape(-1, 1)

        for indices in sampler:
            # share_obs shape:
            # (episode_length * n_rollout_threads, *share_obs_shape) --> (mini_batch_size, *share_obs_shape)
            share_obs_batch = share_obs[indices]
            rnn_states_critic_batch = rnn_states_critic[indices]
            value_preds_batch = value_preds[indices]
            return_batch = returns[indices]
            masks_batch = masks[indices]

            yield share_obs_batch, rnn_states_critic_batch, value_preds_batch, return_batch, masks_batch
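Continuing the illustrative setup from the sketches above, a trainer might consume this generator as follows after a rollout and compute_returns; critic_num_mini_batch=2 is an arbitrary choice, and the actual value-loss computation is omitted.

# With episode_length=200 and n_rollout_threads=8, batch_size is 1600 and each of
# the 2 mini-batches holds 800 samples.
for (
    share_obs_batch,
    rnn_states_critic_batch,
    value_preds_batch,
    return_batch,
    masks_batch,
) in critic_buffer.feed_forward_generator_critic(critic_num_mini_batch=2):
    # share_obs_batch: (800, 54); value_preds_batch, return_batch, masks_batch: (800, 1)
    pass  # a trainer would compute the value loss from return_batch here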
    def naive_recurrent_generator_critic(self, critic_num_mini_batch):
        """Training data generator for critic that uses RNN network.

        This generator does not split the trajectories into chunks,
        and may therefore be less efficient than recurrent_generator_critic in training.

        Args:
            critic_num_mini_batch: (int) Number of mini batches for critic.
        """
        # get n_rollout_threads and num_envs_per_batch
        n_rollout_threads = self.rewards.shape[1]
        assert n_rollout_threads >= critic_num_mini_batch, (
            f"The number of processes ({n_rollout_threads}) "
            f"has to be greater than or equal to the number of "
            f"mini batches ({critic_num_mini_batch})."
        )
        num_envs_per_batch = n_rollout_threads // critic_num_mini_batch

        # shuffle indices
        perm = torch.randperm(n_rollout_threads).numpy()

        T, N = self.episode_length, num_envs_per_batch

        for batch_id in range(critic_num_mini_batch):
            start_id = batch_id * num_envs_per_batch
            ids = perm[start_id : start_id + num_envs_per_batch]
            share_obs_batch = _flatten(T, N, self.share_obs[:-1, ids])
            value_preds_batch = _flatten(T, N, self.value_preds[:-1, ids])
            return_batch = _flatten(T, N, self.returns[:-1, ids])
            masks_batch = _flatten(T, N, self.masks[:-1, ids])
            rnn_states_critic_batch = self.rnn_states_critic[0, ids]

            yield share_obs_batch, rnn_states_critic_batch, value_preds_batch, return_batch, masks_batch
    def recurrent_generator_critic(self, critic_num_mini_batch, data_chunk_length):
        """Training data generator for critic that uses RNN network.

        This generator splits the trajectories into chunks of length data_chunk_length,
        and may therefore be more efficient than naive_recurrent_generator_critic in training.

        Args:
            critic_num_mini_batch: (int) Number of mini batches for critic.
            data_chunk_length: (int) Length of data chunks.
        """
        # get episode_length, n_rollout_threads, and mini_batch_size
        episode_length, n_rollout_threads = self.rewards.shape[0:2]
        batch_size = n_rollout_threads * episode_length
        data_chunks = batch_size // data_chunk_length
        mini_batch_size = data_chunks // critic_num_mini_batch

        assert (
            episode_length % data_chunk_length == 0
        ), f"episode length ({episode_length}) must be a multiple of data chunk length ({data_chunk_length})."
        assert data_chunks >= 2, "need larger batch size"

        # shuffle indices
        rand = torch.randperm(data_chunks).numpy()
        sampler = [
            rand[i * mini_batch_size : (i + 1) * mini_batch_size]
            for i in range(critic_num_mini_batch)
        ]

        # The following data operations first transpose the first two dimensions of the data (episode_length, n_rollout_threads)
        # to (n_rollout_threads, episode_length), then reshape the data to (n_rollout_threads * episode_length, *dim).
        # Take share_obs shape as an example:
        # (episode_length + 1, n_rollout_threads, *share_obs_shape) --> (episode_length, n_rollout_threads, *share_obs_shape)
        # --> (n_rollout_threads, episode_length, *share_obs_shape) --> (n_rollout_threads * episode_length, *share_obs_shape)
        if len(self.share_obs.shape) > 3:
            share_obs = (
                self.share_obs[:-1]
                .transpose(1, 0, 2, 3, 4)
                .reshape(-1, *self.share_obs.shape[2:])
            )
        else:
            share_obs = _sa_cast(self.share_obs[:-1])
        value_preds = _sa_cast(self.value_preds[:-1])
        returns = _sa_cast(self.returns[:-1])
        masks = _sa_cast(self.masks[:-1])
        rnn_states_critic = (
            self.rnn_states_critic[:-1]
            .transpose(1, 0, 2, 3)
            .reshape(-1, *self.rnn_states_critic.shape[2:])
        )

        # generate mini-batches
        for indices in sampler:
            share_obs_batch = []
            rnn_states_critic_batch = []
            value_preds_batch = []
            return_batch = []
            masks_batch = []

            for index in indices:
                ind = index * data_chunk_length
                share_obs_batch.append(share_obs[ind : ind + data_chunk_length])
                value_preds_batch.append(value_preds[ind : ind + data_chunk_length])
                return_batch.append(returns[ind : ind + data_chunk_length])
                masks_batch.append(masks[ind : ind + data_chunk_length])
                rnn_states_critic_batch.append(
                    rnn_states_critic[ind]
                )  # only the beginning rnn states are needed

            L, N = data_chunk_length, mini_batch_size

            # These are all ndarrays of size (data_chunk_length, mini_batch_size, *dim)
            share_obs_batch = np.stack(share_obs_batch, axis=1)
            value_preds_batch = np.stack(value_preds_batch, axis=1)
            return_batch = np.stack(return_batch, axis=1)
            masks_batch = np.stack(masks_batch, axis=1)

            # rnn_states_critic_batch is a (mini_batch_size, *dim) ndarray
            rnn_states_critic_batch = np.stack(rnn_states_critic_batch).reshape(
                N, *self.rnn_states_critic.shape[2:]
            )

            # Flatten the (data_chunk_length, mini_batch_size, *dim) ndarrays to (data_chunk_length * mini_batch_size, *dim)
            share_obs_batch = _flatten(L, N, share_obs_batch)
            value_preds_batch = _flatten(L, N, value_preds_batch)
            return_batch = _flatten(L, N, return_batch)
            masks_batch = _flatten(L, N, masks_batch)

            yield share_obs_batch, rnn_states_critic_batch, value_preds_batch, return_batch, masks_batch
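Continuing the same illustrative setup, the chunked mini-batches for an RNN critic could be drawn as sketched below; data_chunk_length=10 is an arbitrary choice that divides the assumed episode_length of 200.

# batch_size = 1600, so data_chunks = 1600 // 10 = 160 and, with 2 mini-batches,
# each mini-batch holds 80 chunks of 10 consecutive steps.
for batch in critic_buffer.recurrent_generator_critic(
    critic_num_mini_batch=2, data_chunk_length=10
):
    share_obs_batch, rnn_states_critic_batch, value_preds_batch, return_batch, masks_batch = batch
    # share_obs_batch: (10 * 80, 54) flattened chunks;
    # rnn_states_critic_batch: (80, 1, 128), the RNN state at the start of each chunk
    pass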