Train
To train and evaluate the top-level RL agent, run the train_rl_agent.py script with the desired configuration files:
python train_rl_agent.py \
--sim-config configs/env/sim_config.yaml \
--reward-config configs/env/reward_config.yaml \
--dc-config configs/env/datacenters.yaml \
--algo-config configs/env/algorithm_config.yaml \
[--tag <run_tag>] \
[--seed <random_seed>] \
[--enable-logger true|false]
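Every option has a default, so running python train_rl_agent.py with no arguments starts a run with the stock configuration files. For quick experiments the same entry point can also be driven from Python; the sketch below is illustrative only and assumes it is executed from the repository root (the tag and seed values are arbitrary):

import sys

# parse_args() reads sys.argv, so set the arguments before calling train().
sys.argv = ["train_rl_agent.py", "--tag", "demo_run", "--seed", "7"]

from train_rl_agent import train
train()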
Command-Line Arguments
The script accepts the following options:
- --sim-config SIM_CONFIG: Path to the simulation configuration YAML file. Default: configs/env/sim_config.yaml.
- --reward-config REWARD_CONFIG: Path to the reward configuration YAML file. Default: configs/env/reward_config.yaml.
- --dc-config DC_CONFIG: Path to the datacenter configuration YAML file. Default: configs/env/datacenters.yaml.
- --algo-config ALGO_CONFIG: Path to the reinforcement-learning algorithm configuration YAML file (the keys the script expects are sketched after this list). Default: configs/env/algorithm_config.yaml.
- --tag TAG: Optional run tag used to distinguish logs and checkpoints. Default: empty string.
- --seed SEED: Integer random seed for the environment and training. Default: 42.
- --enable-logger {yes,true,t,1}/{no,false,f,0}: Whether to enable debug-level logger output (values are case-insensitive). Default: True.
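For orientation, the snippet below lists the keys that train_rl_agent.py reads from the algorithm section of the --algo-config file, written as the Python dictionary returned by load_yaml(args.algo_config)["algorithm"]. The key names come from the script in the next section; the values are illustrative placeholders, not the repository's defaults:

# Keys consumed by train_rl_agent.py; values below are placeholders only.
algo_cfg = {
    "device": "auto",              # "auto" selects CUDA if available, otherwise CPU
    "hidden_dim": 256,             # hidden width for ActorNet / CriticNet
    "actor_learning_rate": 3e-4,
    "critic_learning_rate": 3e-4,
    "replay_buffer_size": 100000,  # FastReplayBuffer capacity
    "max_tasks": 300,              # passed to FastReplayBuffer as max_tasks
    "total_steps": 200000,         # number of training-loop iterations
    "warmup_steps": 1000,          # random actions before the actor is used
    "save_interval": 10000,        # periodic checkpoint frequency, in steps
}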
Training Script
Below are the full contents of train_rl_agent.py:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import trange
import logging
import os
from collections import deque
import argparse
import datetime
from envs.task_scheduling_env import TaskSchedulingEnv
from rl_components.agent_net import ActorNet, CriticNet
from rl_components.replay_buffer import FastReplayBuffer
from rewards.predefined.composite_reward import CompositeReward
from utils.checkpoint_manager import save_checkpoint, load_checkpoint
from utils.config_loader import load_yaml
from utils.config_logger import setup_logger
from torch.utils.tensorboard import SummaryWriter
# Online estimate of the reward mean/variance (Welford-style update),
# used to normalize rewards before they are stored in the replay buffer.
class RunningStats:
    def __init__(self, eps=1e-5):
        self.mean = 0.0
        self.var = 1.0
        self.count = eps

    def update(self, x):
        x = float(x)
        self.count += 1
        last_mean = self.mean
        self.mean += (x - self.mean) / self.count
        self.var += (x - last_mean) * (x - self.mean)

    def normalize(self, x):
        std = max(np.sqrt(self.var / self.count), 1e-6)
        return (x - self.mean) / std


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def parse_args():
    parser = argparse.ArgumentParser(description="SustainCluster Training")
    parser.add_argument(
        "--sim-config",
        type=str,
        default="configs/env/sim_config.yaml"
    )
    parser.add_argument(
        "--reward-config",
        type=str,
        default="configs/env/reward_config.yaml"
    )
    parser.add_argument(
        "--dc-config",
        type=str,
        default="configs/env/datacenters.yaml"
    )
    parser.add_argument(
        "--algo-config",
        type=str,
        default="configs/env/algorithm_config.yaml"
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="",
        help="Optional run tag"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42
    )
    parser.add_argument(
        "--enable-logger",
        type=str2bool,
        default=True,
        help="Enable logger"
    )
    return parser.parse_args()


# Build the TaskSchedulingEnv: load the simulation, datacenter and reward
# configs, construct the DatacenterClusterManager, and wrap it with the
# composite reward function.
def make_env(sim_cfg_path, dc_cfg_path, reward_cfg_path, writer=None, logger=None):
    import pandas as pd
    from simulation.cluster_manager import DatacenterClusterManager

    sim_cfg = load_yaml(sim_cfg_path)["simulation"]
    dc_cfg = load_yaml(dc_cfg_path)["datacenters"]
    reward_cfg = load_yaml(reward_cfg_path)["reward"]

    start = pd.Timestamp(
        datetime.datetime(
            sim_cfg["year"],
            sim_cfg["month"],
            sim_cfg["init_day"],
            sim_cfg["init_hour"],
            tzinfo=datetime.timezone.utc
        )
    )
    end = start + datetime.timedelta(days=sim_cfg["duration_days"])

    cluster = DatacenterClusterManager(
        config_list=dc_cfg,
        simulation_year=sim_cfg["year"],
        init_day=int(sim_cfg["month"] * 30.5),
        init_hour=sim_cfg["init_hour"],
        strategy=sim_cfg["strategy"],
        tasks_file_path=sim_cfg["workload_path"],
        shuffle_datacenter_order=sim_cfg["shuffle_datacenters"],
        cloud_provider=sim_cfg["cloud_provider"],
        logger=logger
    )

    reward_fn = CompositeReward(
        components=reward_cfg["components"],
        normalize=reward_cfg.get("normalize", False)
    )

    return TaskSchedulingEnv(
        cluster_manager=cluster,
        start_time=start,
        end_time=end,
        reward_fn=reward_fn,
        writer=writer if sim_cfg.get("use_tensorboard", False) else None
    )


def train():
    args = parse_args()

    # Per-run directories for logs, TensorBoard summaries and checkpoints.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_id = f"{args.tag}_{timestamp}" if args.tag else timestamp
    log_dir = f"logs/train_{run_id}"
    tb_dir = f"runs/train_{run_id}"
    ckpt_dir = f"checkpoints/train_{run_id}"
    os.makedirs(ckpt_dir, exist_ok=True)

    writer = SummaryWriter(log_dir=tb_dir)
    print(f"Enable logger: {args.enable_logger}")
    logger = setup_logger(log_dir, enable_logger=args.enable_logger)

    algo_cfg = load_yaml(args.algo_config)["algorithm"]
    if algo_cfg["device"] == "auto":
        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        DEVICE = torch.device(algo_cfg["device"])

    env = make_env(
        args.sim_config,
        args.dc_config,
        args.reward_config,
        writer,
        logger
    )

    # Step the environment until the first observation with pending tasks.
    obs, _ = env.reset(seed=args.seed)
    while len(obs) == 0:
        obs, _, done, _, _ = env.step([])
        if done:
            obs, _ = env.reset(seed=args.seed)

    obs_dim = len(obs[0])
    act_dim = env.num_dcs + 1

    actor = ActorNet(obs_dim, act_dim,
                     hidden_dim=algo_cfg["hidden_dim"]).to(DEVICE)
    critic = CriticNet(obs_dim, act_dim,
                       hidden_dim=algo_cfg["hidden_dim"]).to(DEVICE)
    target_critic = CriticNet(obs_dim, act_dim,
                              hidden_dim=algo_cfg["hidden_dim"]).to(DEVICE)
    target_critic.load_state_dict(critic.state_dict())

    actor_opt = torch.optim.Adam(
        actor.parameters(),
        lr=float(algo_cfg["actor_learning_rate"])
    )
    critic_opt = torch.optim.Adam(
        critic.parameters(),
        lr=float(algo_cfg["critic_learning_rate"])
    )

    buffer = FastReplayBuffer(
        capacity=algo_cfg["replay_buffer_size"],
        max_tasks=algo_cfg["max_tasks"],
        obs_dim=obs_dim
    )

    reward_stats = RunningStats()
    episode_reward = 0
    episode_steps = 0
    episode_reward_buffer = deque(maxlen=10)
    best_avg_reward = float("-inf")
    q_loss = policy_loss = None  # populated by the (omitted) RL update step

    # Main loop: act, store transitions, and periodically checkpoint.
    pbar = trange(algo_cfg["total_steps"])
    for global_step in pbar:
        obs_tensor = torch.FloatTensor(obs).to(DEVICE)

        # Action selection: no-op when no tasks are pending, uniform random
        # during warm-up, otherwise sample from the actor's policy.
        if not obs:
            actions = []
        elif global_step < algo_cfg["warmup_steps"]:
            actions = [np.random.randint(act_dim) for _ in obs]
        else:
            with torch.no_grad():
                logits = actor(obs_tensor)
                probs = F.softmax(logits, dim=-1)
                dist = torch.distributions.Categorical(probs)
                actions = dist.sample().cpu().tolist()
            assert all(0 <= a < act_dim for a in actions), \
                f"Invalid action: {actions}"
            writer.add_histogram("Actor/logits", logits, global_step)
            writer.add_histogram("Actor/probs", probs, global_step)

        next_obs, reward, done, truncated, _ = env.step(actions)

        reward_stats.update(reward)
        normalized_reward = reward_stats.normalize(reward)
        if actions:
            buffer.add(obs, actions, normalized_reward,
                       next_obs, done or truncated)

        obs = next_obs
        episode_reward += reward
        episode_steps += 1

        if done or truncated:
            avg = episode_reward / episode_steps
            episode_reward_buffer.append(avg)
            if logger:
                logger.info(f"[Episode End] total_reward={avg:.2f}")
            writer.add_scalar("Reward/Episode", avg, global_step)
            pbar.write(f"Episode reward: {avg:.2f} (steps: {episode_steps})")

            obs, _ = env.reset(seed=args.seed + global_step)
            episode_reward = 0
            episode_steps = 0

            # Track the running 10-episode average and keep the best checkpoint.
            if len(episode_reward_buffer) == 10:
                avg10 = np.mean(episode_reward_buffer)
                writer.add_scalar("Reward/Avg10", avg10, global_step)
                pbar.write(f"Avg reward: {avg10:.2f}")
                if avg10 > best_avg_reward:
                    best_avg_reward = avg10
                    save_checkpoint(
                        global_step, actor, critic,
                        actor_opt, critic_opt,
                        ckpt_dir, best=True
                    )
                    pbar.write(
                        f"[BEST] Saved checkpoint at step {global_step} "
                        f"(avg10 reward={avg10:.2f})"
                    )

        # RL updates (Q- and policy-loss, backward, optimizer steps) omitted for brevity

        if global_step % algo_cfg["save_interval"] == 0 and global_step > 0:
            save_checkpoint(global_step, actor, critic,
                            actor_opt, critic_opt, ckpt_dir)
            pbar.write(f"Saved checkpoint at step {global_step}")

    writer.close()


if __name__ == "__main__":
    train()
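As a usage illustration only (not the repository's evaluation workflow), the sketch below rolls out one episode greedily with a trained actor, reusing make_env and the environment API exactly as they appear above. The helper name greedy_rollout and the argmax action choice are introduced here for illustration; a trained actor is assumed to be in memory.

import torch

def greedy_rollout(actor,
                   sim_cfg="configs/env/sim_config.yaml",
                   dc_cfg="configs/env/datacenters.yaml",
                   reward_cfg="configs/env/reward_config.yaml",
                   seed=123):
    """Illustrative sketch: run one episode with argmax actions from a trained actor."""
    env = make_env(sim_cfg, dc_cfg, reward_cfg)  # no writer/logger needed here
    device = next(actor.parameters()).device
    obs, _ = env.reset(seed=seed)
    done = truncated = False
    total_reward, steps = 0.0, 0
    while not (done or truncated):
        if len(obs) == 0:
            actions = []  # no pending tasks this step
        else:
            with torch.no_grad():
                logits = actor(torch.FloatTensor(obs).to(device))
                actions = torch.argmax(logits, dim=-1).cpu().tolist()  # greedy choice
        obs, reward, done, truncated, _ = env.step(actions)
        total_reward += reward
        steps += 1
    return total_reward / max(steps, 1)

Sampling from the softmax (as during training) would work equally well; taking the argmax simply makes the rollout deterministic.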