# Source code for envs.task_scheduling_env

import numpy as np
import pandas as pd
from datetime import datetime
import gymnasium as gym
from gymnasium import spaces
from rewards.base_reward import BaseReward
from torch.utils.tensorboard import SummaryWriter  # optional TensorBoard logging
from data.network_cost.network_delay import get_transmission_delay

class TaskSchedulingEnv(gym.Env):
    """
    RL environment for global task scheduling across distributed datacenters.

    This environment wraps around DatacenterClusterManager and exposes a
    Gym-compatible interface. It manages task-level actions (assignment or
    defer), computes observations, and tracks rewards via a modular reward
    function. RL agents interact with this class.

    Notes:
        * ``action_space`` is ``None`` because the action is a variable-length
          batch of ints (one per pending task), which ``gym.spaces`` cannot
          express directly.
        * ``_get_obs`` returns a *list* of per-task observation vectors, not a
          single array matching ``observation_space`` — callers must handle a
          variable number of rows (possibly zero).
    """

    def __init__(self, cluster_manager, start_time, end_time,
                 reward_fn: BaseReward, writer: SummaryWriter = None):
        """
        Args:
            cluster_manager: DatacenterClusterManager-like object; must expose
                ``datacenters`` (name -> DC mapping), ``strategy``, ``reset``,
                ``step``, ``get_tasks_for_timestep``, ``get_dc_location`` and
                ``cloud_provider``.
            start_time: Simulation start (pandas Timestamp).
            end_time: Simulation end (pandas Timestamp).
            reward_fn: Modular reward callable (see BaseReward).
            writer: Optional TensorBoard SummaryWriter for reward logging.
        """
        super().__init__()
        self.cluster_manager = cluster_manager
        # Reuse the cluster manager's logger if it has one.
        self.logger = getattr(self.cluster_manager, "logger", None)
        self.start_time = start_time
        self.end_time = end_time
        self.time_step = pd.Timedelta(minutes=15)  # simulation step size
        self.current_time = self.start_time
        self.reward_fn = reward_fn
        self.writer = writer

        self.pending_tasks = []
        self.deferred_tasks = []
        # Queue of (arrival_time: Timestamp, task: Task, dest_dc_name: str)
        # for tasks currently being transmitted between datacenters.
        self.in_transit_tasks = []
        self.current_task = None
        self.global_step = 0  # Used to track time for TensorBoard logs

        # Set dynamically based on number of DCs.
        self.num_dcs = len(self.cluster_manager.datacenters)

        # Observation space: 4 sin/cos time features + 5 task features
        # + 5 per-DC features.
        obs_dim = 4 + 5 + 5 * self.num_dcs
        self.observation_space = spaces.Box(
            low=0, high=np.inf, shape=(obs_dim,), dtype=np.float32
        )
        self.action_space = None  # Variable-length batch of int (one per task)

    def reset(self, seed=None, options=None):
        """Reset the environment and return (initial observation, info dict)."""
        super().reset(seed=seed)
        self.current_time = self.start_time
        # Reset the cluster manager and all datacenters.
        self.cluster_manager.reset(seed=seed)
        # Load the first batch of tasks.
        self._load_new_tasks()
        return self._get_obs(), {}

    def step(self, actions):
        """
        Advance the simulation by one time step.

        Args:
            actions: list[int] of length == len(self.current_tasks).
                ``0`` defers the task in time; ``k`` in [1..num_dcs] routes the
                task to the k-th datacenter (1-based over the DC dict order).

        Returns:
            (obs, reward, done, truncated, info) in the Gymnasium convention.
            ``truncated`` mirrors ``done`` (fixed-horizon episode).
        """
        # === Deliver any in-flight (transmitting) tasks whose arrival time has come ===
        still_in_transit = []
        for arrival_time, task, dc_name in self.in_transit_tasks:
            if arrival_time <= self.current_time:
                # Now it appears in the destination DC's pending queue.
                self.cluster_manager.datacenters[dc_name].pending_tasks.append(task)
                if self.logger:
                    self.logger.info(f"[{self.current_time}] Task {task.job_name} arrived at {dc_name}")
            else:
                still_in_transit.append((arrival_time, task, dc_name))
        self.in_transit_tasks = still_in_transit

        if self.cluster_manager.strategy == "manual_rl":
            assert len(actions) == len(self.current_tasks), \
                f"Expected {len(self.current_tasks)} actions, got {len(actions)}"
            dc_list = list(self.cluster_manager.datacenters.values())

            # === Route each task to its assigned destination DC ===
            for task, action in zip(self.current_tasks, actions):
                # Tasks past their SLA deadline are forced to run at their
                # origin DC regardless of the agent's action.
                if self.current_time > task.sla_deadline:
                    origin_dc = next(dc for dc in self.cluster_manager.datacenters.values()
                                     if dc.dc_id == task.origin_dc_id)
                    origin_dc.pending_tasks.append(task)
                    task.dest_dc_id = origin_dc.dc_id
                    task.dest_dc = origin_dc
                    if self.logger:
                        self.logger.info(
                            f"[{self.current_time}] Task {task.job_name} exceeded SLA deadline. "
                            f"Forced to origin DC{origin_dc.dc_id}."
                        )
                    continue

                # === Temporal deferral ===
                if action == 0:
                    self.deferred_tasks.append(task)
                    task.temporarily_deferred = True
                    if self.logger:
                        self.logger.info(
                            f"[{self.current_time}] Task {task.job_name}, with origin DC{task.origin_dc_id}, "
                            "has been deferred in time (not assigned destination DC)."
                        )
                    continue

                # === Geographical routing ===
                dest_dc = dc_list[action - 1]  # action ∈ [1..num_dcs]

                # Assign the destination to the task info.
                task.dest_dc_id = dest_dc.dc_id
                task.dest_dc = dest_dc

                # Compute network delay for transferring the task's data.
                origin_loc = self.cluster_manager.get_dc_location(task.origin_dc_id)
                dest_loc = dest_dc.location
                provider = self.cluster_manager.cloud_provider  # 'aws' or 'azure'
                size_gb = task.bandwidth_gb
                delay_s = get_transmission_delay(origin_loc, dest_loc, provider, size_gb)
                arrival_ts = self.current_time + pd.to_timedelta(delay_s, unit='s')

                # Enqueue for later delivery once transmission completes.
                dc_name = next(name for name, dc in self.cluster_manager.datacenters.items()
                               if dc.dc_id == task.dest_dc_id)
                self.in_transit_tasks.append((arrival_ts, task, dc_name))
                if self.logger:
                    self.logger.info(
                        f"[{self.current_time}] Routed task {task.job_name} from DC{task.origin_dc_id} to DC{task.dest_dc_id}, requiring a bandwidth of {task.bandwidth_gb:.2f} GB. "
                        f"(delay={delay_s:.1f}s, will arrive at {arrival_ts})"
                    )

        # === Step all datacenters (releases, schedules, updates) ===
        results = self.cluster_manager.step(self.current_time, logger=self.logger)

        # Placeholders kept for the info-dict interface; aggregation of
        # emissions/energy happens downstream of `results` (reward fn, logs).
        emissions_total = 0.0
        energy_total = 0.0

        if self.reward_fn:
            reward = self.reward_fn(
                cluster_info=results,
                current_tasks=self.current_tasks,
                current_time=self.current_time
            )
        else:
            reward = 0.0

        # === TensorBoard logging of reward components ===
        if self.writer and self.reward_fn:
            if hasattr(self.reward_fn, "get_last_components"):
                # Composite reward: log each component separately.
                for name, value in self.reward_fn.get_last_components().items():
                    self.writer.add_scalar(f"RewardComponents/{name}", value, self.global_step)
            elif hasattr(self.reward_fn, "get_last_value"):
                # Individual reward.
                self.writer.add_scalar(f"Reward/{str(self.reward_fn)}", self.reward_fn.get_last_value(), self.global_step)
            self.global_step += 1

        # === Advance time and load next tasks ===
        # Use the step size defined once in __init__ instead of re-hard-coding
        # pd.Timedelta(minutes=15) here.
        self.current_time += self.time_step
        self._load_new_tasks()

        done = self.current_time >= self.end_time
        truncated = done
        obs = self._get_obs()

        info = {
            "total_energy_kwh": energy_total,
            "total_emissions_kg": emissions_total,
            "scheduled_tasks": len(actions),
            "datacenter_infos": results["datacenter_infos"],
            "transmission_cost_total_usd": results["transmission_cost_total_usd"],
        }
        return obs, reward, done, truncated, info

    def _load_new_tasks(self):
        """Load tasks for the current time step.

        Deferred tasks from the previous step are re-queued first; newly
        arriving tasks are appended after them. Under a rule-based strategy
        the cluster manager handles tasks internally, so the list stays empty.
        """
        self.current_tasks = self.deferred_tasks  # first pick leftovers
        self.deferred_tasks = []

        # Only load tasks manually if using an RL agent.
        if self.cluster_manager.strategy == "manual_rl":
            new_tasks = self.cluster_manager.get_tasks_for_timestep(self.current_time)
            self.current_tasks += new_tasks
            if self.logger:
                self.logger.info(f"[{self.current_time}] Loaded {len(new_tasks)} new tasks + {len(self.current_tasks) - len(new_tasks)} total.")
        else:
            # RBC loads and handles tasks internally.
            self.current_tasks = []

    def _get_obs(self):
        """Build one observation vector per pending task.

        Returns:
            list[list[float]]: one row per task in ``self.current_tasks``
            (empty list when there are no tasks). Each row is
            [day_sin, day_cos, hour_sin, hour_cos,
             origin_dc_id, cores_req, gpu_req, duration, time_to_deadline_min,
             then 5 features per DC (cores/gpu/mem availability ratios,
             carbon intensity / 1000, price / 100)].
        """
        obs = []
        dc_infos = []

        # === Time encoding (sine/cosine of day of year and hour) ===
        day_of_year = self.current_time.dayofyear
        hour_of_day = self.current_time.hour + self.current_time.minute / 60.0
        day_sin = np.sin(2 * np.pi * day_of_year / 365.0)
        day_cos = np.cos(2 * np.pi * day_of_year / 365.0)
        hour_sin = np.sin(2 * np.pi * hour_of_day / 24.0)
        hour_cos = np.cos(2 * np.pi * hour_of_day / 24.0)

        # === DC resource and sustainability features ===
        # (A previously-computed normalized price array and one-hot "cheapest
        # DC" feature were dead code and have been removed; the price is still
        # included per-DC below.)
        for dc in self.cluster_manager.datacenters.values():
            dc_infos.append([
                dc.available_cores / dc.total_cores,
                dc.available_gpus / dc.total_gpus,
                dc.available_mem / dc.total_mem_GB,
                float(dc.ci_manager.get_current_ci(norm=False) / 1000),  # carbon intensity
                float(dc.price_manager.get_current_price()) / 100,       # energy price
            ])
        dc_state_features = [value for dc_info in dc_infos for value in dc_info]

        # === Per-task observation rows ===
        for task in self.current_tasks:
            time_to_deadline = max(0.0, (task.sla_deadline - self.current_time).total_seconds() / 60.0)
            task_features = [
                task.origin_dc_id,
                task.cores_req,
                task.gpu_req,
                task.duration,
                time_to_deadline
            ]
            obs.append(
                [day_sin, day_cos, hour_sin, hour_cos]
                + task_features
                + dc_state_features
            )
        return obs