"""Tools for HARL."""importosimportrandomimportnumpyasnpimporttorchfromharl.envs.env_wrappersimportShareSubprocVecEnv,ShareDummyVecEnv
def check(value):
    """Convert a NumPy array to a torch tensor; pass anything else through.

    Args:
        value: any object. Only ``np.ndarray`` instances are converted.

    Returns:
        ``torch.from_numpy(value)`` when *value* is a NumPy array (the tensor
        shares memory with the array), otherwise *value* itself, unchanged.
    """
    if not isinstance(value, np.ndarray):
        return value
    return torch.from_numpy(value)
def get_shape_from_obs_space(obs_space):
    """Get shape from observation space.

    Dispatch is on the class *name*, so any object whose class is called
    ``Box`` (e.g. ``gym.spaces.Box``) or a plain ``list`` is accepted.

    Args:
        obs_space: (gym.spaces or list) observation space.

    Returns:
        obs_shape: ``obs_space.shape`` for a Box, or the list itself.

    Raises:
        NotImplementedError: for any other observation-space type.
    """
    space_kind = obs_space.__class__.__name__
    if space_kind == "Box":
        return obs_space.shape
    if space_kind == "list":
        # A list is already treated as the shape description.
        return obs_space
    raise NotImplementedError
def get_shape_from_act_space(act_space):
    """Get shape from action space.

    Dispatch is on the class *name*, so any object whose class is called
    ``Discrete``, ``MultiDiscrete``, ``Box`` or ``MultiBinary`` (e.g. the
    ``gym.spaces`` classes) is accepted.

    Args:
        act_space: (gym.spaces) action space.

    Returns:
        act_shape: (int) 1 for ``Discrete``; ``act_space.shape[0]`` for
        ``MultiDiscrete``, ``Box`` and ``MultiBinary``.

    Raises:
        NotImplementedError: if the action-space type is not supported
            (previously this case crashed with ``UnboundLocalError``).
    """
    space_kind = act_space.__class__.__name__
    if space_kind == "Discrete":
        return 1
    if space_kind in ("MultiDiscrete", "Box", "MultiBinary"):
        return act_space.shape[0]
    raise NotImplementedError(f"Unsupported action space type: {space_kind}")
def make_train_env(env_name, seed, n_threads, env_args):
    """Make vectorized env for training.

    Args:
        env_name: (str) environment name; only ``'sustaindc'`` is supported.
        seed: (int) base seed; the env of rank ``r`` is seeded with
            ``seed + r * 1000``.
        n_threads: (int) number of parallel envs. 1 uses
            ``ShareDummyVecEnv``; anything else uses ``ShareSubprocVecEnv``.
        env_args: (dict) environment config. It is NOT modified: each env
            receives its own shallow copy, so the per-rank ``'month'`` chosen
            below never leaks back to the caller (or into a later
            ``make_eval_env`` call that reuses the same dict).

    Returns:
        A ``ShareDummyVecEnv`` or ``ShareSubprocVecEnv`` wrapping the env
        factories.

    Raises:
        NotImplementedError: from the env factory when it is invoked
            (presumably by the vec-env wrapper) if ``env_name`` is unsupported.
    """

    def get_env_fn(rank):
        def init_env():
            import copy

            if env_name != "sustaindc":
                msg = "Can not support the " + env_name + " environment."
                print(msg)
                raise NotImplementedError(msg)
            from harl.envs.sustaindc.harlsustaindc_env import HARLSustainDCEnv

            # Private copy: assigning 'month' must not mutate the shared dict
            # that every rank's factory (and the caller) closes over.
            rank_env_args = copy.copy(env_args)
            if "month" not in rank_env_args:
                if rank < 12:
                    # The first 12 ranks cover each month index 0-11.
                    rank_env_args["month"] = rank % 12
                else:
                    # 33% June (5), 33% July (6), 33% August (7)
                    rank_env_args["month"] = rank % 3 + 5
            env = HARLSustainDCEnv(rank_env_args)
            env.seed(seed + rank * 1000)
            return env

        return init_env

    if n_threads == 1:
        return ShareDummyVecEnv([get_env_fn(0)])
    return ShareSubprocVecEnv([get_env_fn(i) for i in range(n_threads)])
def make_eval_env(env_name, seed, n_threads, env_args):
    """Make vectorized env for evaluation.

    Args:
        env_name: (str) environment name; only ``'sustaindc'`` is supported.
        seed: (int) base seed; the env of rank ``r`` is seeded with
            ``seed * 50000 + r * 10000``.
        n_threads: (int) number of parallel envs. 1 uses
            ``ShareDummyVecEnv``; anything else uses ``ShareSubprocVecEnv``.
        env_args: (dict) environment config. It is NOT modified: each env
            receives its own shallow copy, so the per-rank ``'month'`` chosen
            below never leaks back to the caller.

    Returns:
        A ``ShareDummyVecEnv`` or ``ShareSubprocVecEnv`` wrapping the env
        factories.

    Raises:
        NotImplementedError: from the env factory when it is invoked
            (presumably by the vec-env wrapper) if ``env_name`` is unsupported.
    """

    def get_env_fn(rank):
        def init_env():
            import copy

            if env_name != "sustaindc":
                msg = "Can not support the " + env_name + " environment."
                print(msg)
                raise NotImplementedError(msg)
            from harl.envs.sustaindc.harlsustaindc_env import HARLSustainDCEnv

            # Private copy: assigning 'month' must not mutate the shared dict
            # that every rank's factory (and the caller) closes over.
            rank_env_args = copy.copy(env_args)
            if "month" not in rank_env_args:
                if rank < 12:
                    # The first 12 ranks cover each month index 0-11.
                    rank_env_args["month"] = rank % 12
                else:
                    # 33% June (5), 33% July (6), 33% August (7)
                    rank_env_args["month"] = rank % 3 + 5
            env = HARLSustainDCEnv(rank_env_args)
            env.seed(seed * 50000 + rank * 10000)
            return env

        return init_env

    if n_threads == 1:
        return ShareDummyVecEnv([get_env_fn(0)])
    return ShareSubprocVecEnv([get_env_fn(i) for i in range(n_threads)])
def make_render_env(env_name, seed, env_args):
    """Make env for rendering.

    No environment currently supports rendering, so this always raises.
    An implementation is expected to return the tuple
    ``(env, manual_render, manual_expand_dims, manual_delay, env_num)``:

    - ``manual_render``: whether to manually call the ``render()`` function.
    - ``manual_expand_dims``: whether to manually expand the
      num_of_parallel_envs dimension.
    - ``manual_delay``: whether to manually delay rendering via ``time.sleep()``.
    - ``env_num``: number of parallel envs.

    Args:
        env_name: (str) environment name.
        seed: (int) seed (unused until an env is supported).
        env_args: (dict) environment config (unused until an env is supported).

    Raises:
        NotImplementedError: always.
    """
    msg = "Can not support the " + env_name + " environment."
    print(msg)
    raise NotImplementedError(msg)
def set_seed(args):
    """Seed the program's random number generators.

    If ``args["seed_specify"]`` is falsy, a new seed in ``[1000, 10000)`` is
    drawn from NumPy's global RNG and written back to ``args["seed"]`` so the
    caller can see which seed was used.

    The seed is then applied to Python's ``random``, NumPy's global RNG,
    torch (CPU, current CUDA device, all CUDA devices), and the
    ``PYTHONHASHSEED`` environment variable.

    Args:
        args: (dict) must contain ``"seed_specify"``; ``"seed"`` is read,
            and written when ``seed_specify`` is falsy.
    """
    if not args["seed_specify"]:
        args["seed"] = np.random.randint(1000, 10000)
    chosen_seed = args["seed"]

    seeders = (random.seed, np.random.seed)
    for seeder in seeders:
        seeder(chosen_seed)
    os.environ["PYTHONHASHSEED"] = str(chosen_seed)
    for seeder in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(chosen_seed)
def get_num_agents(env, env_args, envs):
    """Get the number of agents in the environment.

    Args:
        env: (str) environment name; only ``'sustaindc'`` is supported.
        env_args: (dict) environment config (currently unused).
        envs: vectorized env object exposing ``n_agents``.

    Returns:
        ``envs.n_agents``.

    Raises:
        ValueError: if ``env`` is not a supported environment name.
    """
    if env != 'sustaindc':
        raise ValueError(
            f"Unsupported environment type: '{env}'. Check the environment name and try again."
        )
    return envs.n_agents