"""Tools for HARL."""importcopyimportmathimporttorchimporttorch.nnasnn
def init_device(args):
    """Init device.

    Args:
        args: (dict) arguments
    Returns:
        device: (torch.device) device
    """
    if args["cuda"] and torch.cuda.is_available():
        print("choose to use gpu...")
        device = torch.device("cuda:0")
        if args["cuda_deterministic"]:
            # trade cuDNN autotuning speed for reproducible kernels
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        print("choose to use cpu...")
        device = torch.device("cpu")
    torch.set_num_threads(args["torch_threads"])
    return device
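
# Hedged usage sketch (not part of the original module): the dict below carries
# only the keys init_device actually reads; a real HARL config has more entries.
def _example_init_device():
    args = {"cuda": False, "cuda_deterministic": False, "torch_threads": 1}
    device = init_device(args)
    print(device)  # -> "cpu" here, since cuda is disabled
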
def get_active_func(activation_func):
    """Get the activation function.

    Args:
        activation_func: (str) activation function
    Returns:
        activation function: (torch.nn) activation function
    """
    if activation_func == "sigmoid":
        return nn.Sigmoid()
    elif activation_func == "tanh":
        return nn.Tanh()
    elif activation_func == "relu":
        return nn.ReLU()
    elif activation_func == "leaky_relu":
        return nn.LeakyReLU()
    elif activation_func == "selu":
        return nn.SELU()
    elif activation_func == "hardswish":
        return nn.Hardswish()
    elif activation_func == "identity":
        return nn.Identity()
    else:
        assert False, "activation function not supported!"
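
# Hedged usage sketch (not part of the original module): "relu" is just one of
# the supported names; any other string trips the assertion above.
def _example_get_active_func():
    act = get_active_func("relu")
    x = torch.tensor([-1.0, 0.0, 1.0])
    print(act(x))  # tensor([0., 0., 1.])
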
def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    """Decrease the learning rate linearly.

    Args:
        optimizer: (torch.optim) optimizer
        epoch: (int) current epoch
        total_num_epochs: (int) total number of epochs
        initial_lr: (float) initial learning rate
    """
    # anneals from initial_lr at epoch 1 down to
    # initial_lr / total_num_epochs at the final epoch
    learning_rate = initial_lr - (initial_lr * ((epoch - 1) / float(total_num_epochs)))
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate
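
# Hedged usage sketch (not part of the original module): a toy Linear model and
# Adam optimizer stand in for a real HARL actor/critic; under the schedule
# above the lr decays from 1e-3 at epoch 1 to 1e-4 at epoch 10.
def _example_update_linear_schedule():
    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    total_num_epochs = 10
    for epoch in range(1, total_num_epochs + 1):
        update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr=1e-3)
        print(epoch, optimizer.param_groups[0]["lr"])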