# utils/checkpoint_manager.py
import logging
import os

import torch

logger = logging.getLogger(__name__)  # Optional: used for logging during load
def save_checkpoint(
    step,
    actor,
    critic,
    actor_opt,
    critic_opt,
    save_dir,
    is_best=False,
    filename=None,  # Optional filename override
    **kwargs,  # Accept extra keyword arguments (like running stats)
):
    """Saves model and optimizer states, plus optional extra data."""
    os.makedirs(save_dir, exist_ok=True)
    checkpoint = {
        "step": step,
        "actor_state_dict": actor.state_dict(),
        "critic_state_dict": critic.state_dict(),
        "actor_optimizer_state_dict": actor_opt.state_dict(),
        "critic_optimizer_state_dict": critic_opt.state_dict(),
    }
    # Add any extra keyword arguments provided (e.g., running stats states).
    checkpoint.update(kwargs)

    # Determine the filename: an explicit override wins; otherwise "best"
    # checkpoints get a fixed name and regular ones are tagged by step.
    if filename is None:
        filename = "best_checkpoint.pth" if is_best else f"checkpoint_step_{step}.pth"
    path = os.path.join(save_dir, filename)

    try:
        torch.save(checkpoint, path)
        print(f"Checkpoint saved at step {step} -> {path}")  # Keep simple print for progress.
    except Exception as e:
        print(f"Error saving checkpoint to {path}: {e}")
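
# Both functions assume "running stats" objects expose get_state()/set_state()
# (set_state is called in load_checkpoint below; get_state produces the snapshot
# passed to save_checkpoint via **kwargs). The class below is a minimal sketch
# of such a helper -- a parallel Welford-style mean/variance tracker -- included
# as an illustrative assumption, not this project's actual implementation.
class RunningMeanStd:
    """Tracks a running mean and variance over a stream of tensors."""

    def __init__(self, shape=(), epsilon=1e-4):
        self.mean = torch.zeros(shape)
        self.var = torch.ones(shape)
        self.count = epsilon  # Small prior count avoids division by zero.

    def update(self, x):
        """Fold a batch of samples (first dim = batch) into the statistics."""
        batch_mean = x.mean(dim=0)
        batch_var = x.var(dim=0, unbiased=False)
        batch_count = x.shape[0]

        # Chan et al. parallel-merge formulas for combining two sample sets.
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + delta.pow(2) * self.count * batch_count / total
        self.mean, self.var, self.count = new_mean, m2 / total, total

    def get_state(self):
        """Serializable snapshot, suitable for save_checkpoint(..., reward_stats=...)."""
        return {"mean": self.mean, "var": self.var, "count": self.count}

    def set_state(self, state):
        """Restore a snapshot produced by get_state()."""
        self.mean = state["mean"]
        self.var = state["var"]
        self.count = state["count"]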
def load_checkpoint(
    path,
    actor,
    critic,
    actor_opt=None,
    critic_opt=None,
    device="cpu",
    # Optional running-stats objects; if given, their states are restored too.
    reward_stats=None,
    critic_obs_stats=None,
):
    """Loads model and optimizer states, and optionally running stats states."""
    if not os.path.exists(path):
        logger.error(f"Checkpoint file not found: {path}")
        return 0  # Return step 0 or raise an error.

    try:
        # NOTE: PyTorch >= 2.6 defaults torch.load to weights_only=True; pass
        # weights_only=False if the checkpoint stores non-tensor Python objects.
        checkpoint = torch.load(path, map_location=device)

        # Load core components.
        actor.load_state_dict(checkpoint["actor_state_dict"])
        critic.load_state_dict(checkpoint["critic_state_dict"])
        logger.info(f"Loaded actor and critic state dicts from {path}")

        # Load optimizers if provided.
        if actor_opt is not None and "actor_optimizer_state_dict" in checkpoint:
            actor_opt.load_state_dict(checkpoint["actor_optimizer_state_dict"])
            logger.info("Loaded actor optimizer state dict.")
        if critic_opt is not None and "critic_optimizer_state_dict" in checkpoint:
            critic_opt.load_state_dict(checkpoint["critic_optimizer_state_dict"])
            logger.info("Loaded critic optimizer state dict.")

        # Load running stats if provided and present in the checkpoint.
        if reward_stats is not None and "reward_stats" in checkpoint:
            try:
                reward_stats.set_state(checkpoint["reward_stats"])
                logger.info("Loaded reward running stats.")
            except Exception as e:
                logger.warning(f"Could not load reward_stats: {e}. Stats might be reset.")
        elif reward_stats is not None:
            logger.warning("reward_stats object provided, but no reward_stats found in checkpoint.")

        if critic_obs_stats is not None and "critic_obs_stats" in checkpoint:
            try:
                critic_obs_stats.set_state(checkpoint["critic_obs_stats"])
                logger.info("Loaded critic observation running stats.")
            except Exception as e:
                logger.warning(f"Could not load critic_obs_stats: {e}. Stats might be reset.")
        elif critic_obs_stats is not None:
            logger.warning("critic_obs_stats object provided, but no critic_obs_stats found in checkpoint.")

        loaded_step = checkpoint.get("step", 0)
        logger.info(f"Checkpoint loaded successfully from step {loaded_step}.")
        return loaded_step
    except Exception as e:
        logger.error(f"Error loading checkpoint from {path}: {e}")
        return 0  # Return step 0 or raise an error.
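
# Minimal round-trip demo (a sketch under assumed shapes: the tiny Linear
# "actor"/"critic" stand in for the real networks, and RunningMeanStd above is
# a hypothetical stats helper). Run this module directly to exercise
# save_checkpoint/load_checkpoint, including the running-stats path via **kwargs.
if __name__ == "__main__":
    import tempfile

    import torch.nn as nn

    logging.basicConfig(level=logging.INFO)

    actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)
    actor_opt = torch.optim.Adam(actor.parameters())
    critic_opt = torch.optim.Adam(critic.parameters())

    reward_stats = RunningMeanStd()
    reward_stats.update(torch.randn(128))

    with tempfile.TemporaryDirectory() as save_dir:
        # Extra keyword arguments (here, the stats snapshot) are stored
        # verbatim in the checkpoint dict via checkpoint.update(kwargs).
        save_checkpoint(
            step=1000,
            actor=actor,
            critic=critic,
            actor_opt=actor_opt,
            critic_opt=critic_opt,
            save_dir=save_dir,
            reward_stats=reward_stats.get_state(),
        )

        restored_stats = RunningMeanStd()
        step = load_checkpoint(
            os.path.join(save_dir, "checkpoint_step_1000.pth"),
            actor,
            critic,
            actor_opt=actor_opt,
            critic_opt=critic_opt,
            reward_stats=restored_stats,
        )
        assert step == 1000 and torch.allclose(restored_stats.mean, reward_stats.mean)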