Source code for harl.models.base.plain_mlp

import torch.nn as nn
from harl.utils.models_tools import get_active_func


class PlainMLP(nn.Module):
    """Feed-forward network made of stacked Linear + activation pairs.

    For every consecutive pair of widths ``(sizes[k], sizes[k + 1])`` an
    ``nn.Linear`` layer is created, immediately followed by the layer that
    ``get_active_func`` returns for the chosen activation name. All pairs
    except the last use ``activation_func``; the last pair uses
    ``final_activation_func``. The layers are stored, in order, in a single
    ``nn.Sequential`` under ``self.mlp``.
    """

    def __init__(self, sizes, activation_func, final_activation_func="identity"):
        """Assemble the layer stack.

        Args:
            sizes: indexable sequence of layer widths, input width first and
                output width last. ``len(sizes) - 1`` Linear layers are built;
                with fewer than two entries no layers are built at all.
            activation_func: activation name handed to ``get_active_func``
                after every Linear layer except the last one.
            final_activation_func: activation name handed to
                ``get_active_func`` after the last Linear layer
                (default ``"identity"``).
        """
        super().__init__()
        depth = len(sizes) - 1
        modules = []
        for idx in range(depth):
            # Only the output layer gets the "final" activation.
            is_last = idx == depth - 1
            act_name = final_activation_func if is_last else activation_func
            # Linear is constructed before the activation, matching the
            # layer order inside the Sequential container.
            modules.append(nn.Linear(sizes[idx], sizes[idx + 1]))
            modules.append(get_active_func(act_name))
        self.mlp = nn.Sequential(*modules)

    def forward(self, x):
        """Apply every layer of ``self.mlp`` to ``x`` in order.

        Args:
            x: input tensor; presumably its last dimension equals
                ``sizes[0]`` (``nn.Linear`` acts on the last dimension).

        Returns:
            The output of the final activation layer.
        """
        out = self.mlp(x)
        return out