"""Tools for loading and updating configs."""
import time
import os
import json
import yaml
from uu import Error
[docs]
def get_defaults_yaml_args(algo, env):
    """Read the default YAML configs for the requested algorithm and environment.

    The config root is the ``configs`` folder located in the parent of the
    directory that contains this file. The algorithm config is read from
    ``configs/algos_cfgs/<algo>.yaml`` and the environment config from
    ``configs/envs_cfgs/<env>.yaml``.

    Args:
        algo: (str) Algorithm name.
        env: (str) Environment name.
    Returns:
        algo_args: (dict) Algorithm config.
        env_args: (dict) Environment config.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    # os.path.split(...)[0] drops the last path component, i.e. moves one level up.
    configs_root = os.path.join(os.path.split(module_dir)[0], "configs")

    # Build both paths up front, then read them in order (algorithm first).
    cfg_paths = (
        os.path.join(configs_root, "algos_cfgs", f"{algo}.yaml"),
        os.path.join(configs_root, "envs_cfgs", f"{env}.yaml"),
    )

    parsed = []
    for cfg_path in cfg_paths:
        with open(cfg_path, "r", encoding="utf-8") as stream:
            # NOTE(review): FullLoader is not meant for untrusted input; the
            # config files are presumably shipped with the project - confirm.
            parsed.append(yaml.load(stream, Loader=yaml.FullLoader))

    algo_args, env_args = parsed
    return algo_args, env_args
[docs]
def update_args(unparsed_dict, *args):
    """Update loaded config with unparsed command-line arguments.

    Every config dict in ``args`` is modified in place. For each non-dict
    entry found at any nesting depth, the value is replaced by
    ``unparsed_dict[key]`` when ``unparsed_dict`` has that key. Matching is by
    key name only (not by nested path), so one command-line key overrides
    every leaf with that name. Keys of ``unparsed_dict`` that do not occur in
    a config are ignored; no new keys are added.

    Args:
        unparsed_dict: (dict) Unparsed command-line arguments.
        *args: (list[dict]) argument dicts to be updated.
    """

    def update_dict(overrides, target):
        for key, value in target.items():
            # isinstance (rather than an exact type check) so that dict
            # subclasses such as OrderedDict are descended into as well.
            if isinstance(value, dict):
                update_dict(overrides, value)
            elif key in overrides:
                # Re-assigning an existing key does not resize the dict, so
                # this is safe while iterating over it.
                target[key] = overrides[key]

    for args_dict in args:
        update_dict(unparsed_dict, args_dict)
[docs]
def get_task_name(env, env_args):
    """Return the task name for the given environment.

    Only the ``'sustaindc'`` environment is recognised; its task name is the
    ``"location"`` entry of ``env_args``.

    Args:
        env: (str) Environment name.
        env_args: (dict) Environment config; must contain ``"location"``.
    Returns:
        The value stored under ``env_args["location"]``.
    Raises:
        ValueError: If ``env`` is not ``'sustaindc'``.
    """
    # Reject unknown environments first so the happy path stays flat.
    if env != 'sustaindc':
        raise ValueError('Environment not defined')
    return env_args["location"]
[docs]
def init_dir(env, env_args, algo, exp_name, seed, logger_path):
    """Create the result directories for a run and open a TensorBoard writer.

    The run directory is
    ``<logger_path>/<env>/<task>/<algo>/<exp_name>/seed-<seed>-<timestamp>``,
    where ``<task>`` comes from ``get_task_name``, ``<seed>`` is left-padded
    with zeros to at least five characters and ``<timestamp>`` is the local
    time formatted as ``%Y-%m-%d-%H-%M-%S``. Inside it, ``logs`` and
    ``models`` sub-directories are created.

    Args:
        env: (str) Environment name.
        env_args: (dict) Environment config, forwarded to ``get_task_name``.
        algo: (str) Algorithm name.
        exp_name: (str) Experiment name.
        seed: Random seed, used in the run directory name.
        logger_path: (str) Root directory under which results are stored.
    Returns:
        Tuple ``(results_path, log_path, models_path, writer)`` where
        ``writer`` is a ``tensorboardX.SummaryWriter`` created on ``log_path``.
    """
    task = get_task_name(env, env_args)
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    seed_tag = "seed-{:0>5}".format(seed)
    run_name = "-".join((seed_tag, timestamp))

    results_path = os.path.join(logger_path, env, task, algo, exp_name, run_name)
    log_path = os.path.join(results_path, "logs")
    models_path = os.path.join(results_path, "models")

    os.makedirs(log_path, exist_ok=True)
    # Imported here rather than at module level, so tensorboardX is only
    # required when a run directory is actually initialised.
    from tensorboardX import SummaryWriter

    writer = SummaryWriter(log_path)
    os.makedirs(models_path, exist_ok=True)
    return results_path, log_path, models_path, writer
[docs]
def is_json_serializable(value):
    """Check if value is JSON serializable.

    Args:
        value: Object to probe.
    Returns:
        (bool) True if ``json.dumps(value)`` succeeds, False otherwise.
    """
    try:
        json.dumps(value)
    except (TypeError, ValueError):
        # json.dumps raises TypeError for unsupported types / dict keys and
        # ValueError for circular references.
        return False
    return True


def _json_key(key):
    """Return a form of key that json.dumps accepts as an object key."""
    converted = convert_json(key)
    # bool is a subclass of int, so it is covered by the isinstance check.
    if converted is None or isinstance(converted, (str, int, float)):
        return converted
    # Tuples, dicts, etc. are not valid (or not hashable) as JSON object keys.
    return str(key)


def convert_json(obj):
    """Convert obj to a version which can be serialized with JSON.

    Conversion rules, applied in order:
      * already serializable objects are returned unchanged;
      * dicts are rebuilt with converted keys and values;
      * tuples and lists are rebuilt element by element, keeping their type;
      * objects with a ``__name__`` not containing ``"lambda"`` (functions,
        classes, ...) become that name;
      * objects with a non-empty ``__dict__`` become
        ``{str(obj): <converted __dict__>}``;
      * anything else becomes ``str(obj)``.

    Self-referencing containers are not supported (the recursion would not
    terminate).

    Args:
        obj: Arbitrary object.
    Returns:
        A JSON-serializable counterpart of ``obj``.
    """
    if is_json_serializable(obj):
        return obj
    if isinstance(obj, dict):
        return {_json_key(k): convert_json(v) for k, v in obj.items()}
    if isinstance(obj, tuple):
        # Materialise the tuple: a bare generator expression is itself not
        # JSON serializable.
        return tuple(convert_json(x) for x in obj)
    if isinstance(obj, list):
        return [convert_json(x) for x in obj]
    if hasattr(obj, "__name__") and "lambda" not in obj.__name__:
        return convert_json(obj.__name__)
    if hasattr(obj, "__dict__") and obj.__dict__:
        obj_dict = {_json_key(k): convert_json(v) for k, v in obj.__dict__.items()}
        return {str(obj): obj_dict}
    return str(obj)
[docs]
def save_config(args, algo_args, env_args, run_dir):
    """Write the run configuration to ``<run_dir>/config.json``.

    The three argument collections are grouped under the keys
    ``"main_args"``, ``"algo_args"`` and ``"env_args"``, passed through
    ``convert_json`` and dumped with sorted keys, 4-space indentation and a
    colon followed by a tab between each key and its value.

    Args:
        args: Main program arguments.
        algo_args: Algorithm config.
        env_args: Environment config.
        run_dir: (str) Directory in which ``config.json`` is written.
    """
    serializable = convert_json(
        dict(main_args=args, algo_args=algo_args, env_args=env_args)
    )
    # Serialise fully before opening the file, so a failure here does not
    # leave a truncated config.json behind.
    text = json.dumps(
        serializable,
        separators=(",", ":\t"),
        indent=4,
        sort_keys=True,
    )
    target = os.path.join(run_dir, "config.json")
    with open(target, "w", encoding="utf-8") as handle:
        handle.write(text)