# Custom MindSpore network models (MindSpore自定义网络模型)

# 各类网络层都在nn里面
import mindspore.nn as nn
# 参数初始化的方式
from mindspore.common.initializer import TruncatedNormal


# 再重写卷积函数
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build an unbiased 2D convolution with TruncatedNormal(0.02) weight init.

    Bug fix: the original passed the caller's ``padding`` while hard-coding
    ``pad_mode="same"``; MindSpore's ``nn.Conv2d`` requires ``padding == 0``
    when ``pad_mode`` is ``"same"``, so any explicit non-zero padding raised a
    ValueError.  We now switch to ``pad_mode="pad"`` whenever an explicit
    padding is requested.  The default (``padding=0``) keeps the original
    "same"-padding behavior, so existing callers are unaffected.
    """
    # Truncated-normal weight initialization, sigma = 0.02.
    weight = TruncatedNormal(0.02)
    # "pad" honors the explicit padding; "same" preserves spatial size.
    pad_mode = "pad" if padding > 0 else "same"
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode=pad_mode)


# 重写全连接函数
def fc_with_initialize(input_channels, out_channels):
    """Fully-connected (Dense) layer whose weight and bias are both drawn
    from a truncated normal distribution with sigma = 0.02."""
    # Pass the initializers by keyword; one fresh initializer per tensor.
    return nn.Dense(input_channels, out_channels,
                    weight_init=TruncatedNormal(0.02),
                    bias_init=TruncatedNormal(0.02))


# 定义网络
class LeNet5a(nn.Cell):
    """Plain LeNet-5 variant: two conv blocks (conv -> ReLU -> 2x2 max-pool)
    followed by three fully-connected layers producing class logits.

    Args:
        num_class: number of output classes (default 10).
        channel: number of input image channels (default 3).
    """

    def __init__(self, num_class=10, channel=3):
        super(LeNet5a, self).__init__()
        self.num_class = num_class
        # Feature extractor: two 5x5 convolutions.
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # Classifier head.  16*8*8 flattened features — presumably a 32x32
        # input halved by each of the two pools; confirm against the dataset.
        self.fc1 = fc_with_initialize(16 * 8 * 8, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Forward pass; returns logits of shape (batch, num_class)."""
        out = self.max_pool2d(self.relu(self.conv1(x)))
        out = self.max_pool2d(self.relu(self.conv2(out)))
        out = self.flatten(out)
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        return self.fc3(out)


class LeNet5b(nn.Cell):
    """Deeper LeNet-5 variant with four 3x3 conv layers and BatchNorm.

    NOTE(review): five of the six BatchNorm layers (bn2_2, bn2_3, bn2_4,
    bn1_1, bn1_2) are constructed in __init__ but their calls in construct()
    are commented out, so they hold parameters that never participate in the
    forward pass.  Only bn2_1 is active.  Confirm whether they should be
    enabled or removed — removing them changes the checkpoint layout.
    """

    def __init__(self, num_class=10, channel=3):
        super(LeNet5b, self).__init__()
        # num_class: number of output classes; channel: input image channels.
        self.num_class = num_class
        self.conv1_1 = conv(channel, 8, 3)
        # Batch-normalizes 4-D input of shape 'batch_size x num_features x height x width'.
        self.bn2_1 = nn.BatchNorm2d(num_features=8)
        self.conv1_2 = conv(8, 16, 3)
        self.bn2_2 = nn.BatchNorm2d(num_features=16)
        self.conv2_1 = conv(16, 32, 3)
        self.bn2_3 = nn.BatchNorm2d(num_features=32)
        self.conv2_2 = conv(32, 64, 3)
        self.bn2_4 = nn.BatchNorm2d(num_features=64)
        # 64*8*8 flattened features — presumably a 32x32 input halved by each
        # of the two pools; TODO confirm against the dataset.
        self.fc1 = fc_with_initialize(64*8*8, 120)
        # Batch-normalizes 2-D or 3-D input of shape 'batch_size x num_features [x width]'.
        self.bn1_1 = nn.BatchNorm1d(num_features=120)
        self.fc2 = fc_with_initialize(120, 84)
        self.bn1_2 = nn.BatchNorm1d(num_features=84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Forward pass; returns logits of shape (batch, num_class)."""
        x = self.conv1_1(x)
        x = self.bn2_1(x)
        x = self.relu(x)
        x = self.conv1_2(x)
        # x = self.bn2_2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2_1(x)
        # x = self.bn2_3(x)
        x = self.relu(x)
        x = self.conv2_2(x)
        # x = self.bn2_4(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        # x = self.bn1_1(x)
        x = self.relu(x)
        x = self.fc2(x)
        # x = self.bn1_2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值