Source code for harl.models.base.plain_cnn
import torch.nn as nn
from harl.utils.models_tools import get_active_func
from harl.models.base.flatten import Flatten
[docs]
class PlainCNN(nn.Module):
"""Plain CNN"""
def __init__(
self, obs_shape, hidden_size, activation_func, kernel_size=3, stride=1
):
super().__init__()
input_channel = obs_shape[0]
input_width = obs_shape[1]
input_height = obs_shape[2]
layers = [
nn.Conv2d(
in_channels=input_channel,
out_channels=hidden_size // 4,
kernel_size=kernel_size,
stride=stride,
),
get_active_func(activation_func),
Flatten(),
nn.Linear(
hidden_size
// 4
* (input_width - kernel_size + stride)
* (input_height - kernel_size + stride),
hidden_size,
),
get_active_func(activation_func),
]
self.cnn = nn.Sequential(*layers)
[docs]
def forward(self, x):
x = x / 255.0
x = self.cnn(x)
return x