import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from torchvision.models import vgg16
from model_fcn8s import VGG
# Per-stage configuration passed to model_fcn8s.VGG. Presumably each tuple is
# (number of conv layers, output channels), mirroring VGG16's 13-conv layout
# -- TODO confirm against model_fcn8s.VGG.
struct = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]
expected_model = VGG(num_classes=21, struct=struct)

# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=VGG16_Weights.DEFAULT`; kept for compatibility with older
# versions. Downloads the weights on first use.
pretrained_vgg16 = vgg16(pretrained=True)
# Compare only the convolutional feature extractor (the `.features` submodule).
pretrained_vgg16 = pretrained_vgg16.features

# Materialize both parameter lists so their lengths can be checked: a bare
# zip() silently stops at the shorter iterable, so a custom model with too few
# parameters would otherwise pass without every pretrained tensor being checked.
expected_params = list(expected_model.parameters())
loaded_params = list(pretrained_vgg16.parameters())

# Every pretrained feature parameter must have a positional counterpart in the
# custom model. Extra trailing parameters on the custom side are still allowed,
# as before (presumably the FCN head layers -- confirm).
if len(expected_params) < len(loaded_params):
    raise AssertionError(
        f"Parameter count mismatch! Custom model has {len(expected_params)} "
        f"parameters, pretrained VGG16 features have {len(loaded_params)}."
    )

for index, (expected_param, loaded_param) in enumerate(
    zip(expected_params, loaded_params)
):
    # Explicit raise rather than `assert` so the check is not stripped when the
    # script runs under `python -O`.
    if expected_param.shape != loaded_param.shape:
        raise AssertionError(
            f"Parameter shape mismatch! Parameter #{index}: "
            f"expected {tuple(expected_param.shape)}, "
            f"pretrained {tuple(loaded_param.shape)}"
        )

print("Model structure matches the expected structure.")