728x90
Model Definition
class VGG11(nn.Module):
    """VGG-11 style convolutional feature extractor with batch normalization.

    Eight 3x3 convolutions (padding 1, so spatial size is preserved), each
    followed by BatchNorm and ReLU, are arranged in five stages with
    64, 128, 256, 512 and 512 output channels.

    ``forward`` applies 2x2 max pooling after stages 1-4 only. ``pool5`` is
    constructed here but is NOT applied by ``forward``; the classifier that
    wraps this backbone calls ``backbone.pool5`` itself on the returned
    feature map.

    Args:
        in_channels: Number of channels of the input image. Defaults to 3.
    """

    def __init__(self, in_channels: int = 3):
        super().__init__()
        # One stateless activation module shared by every conv/bn pair.
        self.relu = nn.ReLU(inplace=True)

        # Convolution feature extraction part.
        # Attribute names and creation order are kept stable so that
        # state_dict keys and seeded initialization do not change.
        # Stage 1
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3_1 = nn.BatchNorm2d(256)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn3_2 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 4
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4_1 = nn.BatchNorm2d(512)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn4_2 = nn.BatchNorm2d(512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 5 (pool5 is registered here but applied by the caller).
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn5_1 = nn.BatchNorm2d(512)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn5_2 = nn.BatchNorm2d(512)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return the stage-5 feature map, before ``pool5``.

        The four pooling layers applied here each halve the spatial size
        (rounding down); the output has 512 channels.
        """
        # Stage 1
        x = self.pool1(self.relu(self.bn1(self.conv1(x))))

        # Stage 2
        x = self.pool2(self.relu(self.bn2(self.conv2(x))))

        # Stage 3
        x = self.relu(self.bn3_1(self.conv3_1(x)))
        x = self.relu(self.bn3_2(self.conv3_2(x)))
        x = self.pool3(x)

        # Stage 4
        x = self.relu(self.bn4_1(self.conv4_1(x)))
        x = self.relu(self.bn4_2(self.conv4_2(x)))
        x = self.pool4(x)

        # Stage 5: intentionally no pooling here, see the class docstring.
        x = self.relu(self.bn5_1(self.conv5_1(x)))
        x = self.relu(self.bn5_2(self.conv5_2(x)))
        return x
class VGG11Classification(nn.Module):
    """Image classifier built on the VGG11 backbone.

    Pipeline: backbone features -> backbone's ``pool5`` -> global average
    pooling to 1x1 -> flatten -> linear layer producing ``num_classes``
    outputs.

    Args:
        num_classes: Number of output units of the final linear layer.
            Defaults to 7.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        self.backbone = VGG11()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc_out = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the linear layer's output for the input batch ``x``."""
        features = self.backbone(x)
        # The backbone's forward stops before its last pooling layer,
        # so that layer is applied here.
        pooled = self.gap(self.backbone.pool5(features))
        flat = torch.flatten(pooled, 1)
        return self.fc_out(flat)
Get the number of parameters in a module
def get_module_params_num(module):
    """Return the number of trainable parameters in ``module``.

    Only parameters with ``requires_grad`` set are counted; frozen
    parameters are skipped.

    Args:
        module: An ``nn.Module`` (parameters of nested submodules are
            included).

    Returns:
        The total element count as an int; 0 if there are no trainable
        parameters.
    """
    return sum(p.numel() for p in module.parameters() if p.requires_grad)
Get the number of parameters in the model
def get_model_params_num(model):
    """Return the number of trainable parameters in ``model``.

    Counts every parameter reachable from ``model`` that has
    ``requires_grad`` set. Unlike a sum over the model's direct children,
    this also includes parameters registered directly on ``model`` itself
    (so a leaf module such as ``nn.Linear`` is counted correctly) and counts
    a parameter shared between several submodules only once.

    Args:
        model: The ``nn.Module`` to inspect.

    Returns:
        The total element count as an int; 0 if there are no trainable
        parameters.
    """
    # Module.parameters() recurses through submodules and yields each
    # parameter tensor once, without touching the private _modules dict.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
Main
# Build the classifier with its default num_classes and print its
# trainable parameter count.
model = VGG11Classification()
num_params = get_model_params_num(model)
print(num_params)
728x90