ShuffleNet Image Classification
This example does not support running in static graph mode on GPU devices; all other modes are supported.
Introduction to ShuffleNet
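ShuffleNetV1 is a computationally efficient CNN proposed by Megvii. Like MobileNet and SqueezeNet, it mainly targets mobile devices, so its design goal is the best possible accuracy under a limited compute budget. Its design centers on two operations, Pointwise Group Convolution and Channel Shuffle, which greatly reduce computation while preserving accuracy. Like MobileNet, ShuffleNetV1 therefore compresses and accelerates the model through a more efficient network structure.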
For more details on ShuffleNet, see the ShuffleNet paper.
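As shown in the figure below, ShuffleNet keeps competitive accuracy while having nearly the smallest parameter count among the compared models. It therefore runs fast, and each parameter contributes a great deal to model accuracy.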
Image source: Bianco S, Cadene R, Celona L, et al. Benchmark analysis of representative deep neural network architectures[J]. IEEE Access, 2018, 6: 64270-64277.
Model Architecture
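ShuffleNet's most distinctive feature is that it reorders (shuffles) channels across groups to overcome the main drawback of group convolution. It does this by modifying the Bottleneck unit of ResNet, which lets it reach high accuracy at a small computational cost.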
Pointwise Group Convolution
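The principle of Group Convolution is shown in the figure below. In an ordinary convolution, each kernel spans all in_channels input channels. In a group convolution with g groups, each kernel has size in_channels/g × k × k. All groups together therefore have (in_channels/g × k × k) × out_channels parameters, which is 1/g of an ordinary convolution. Because each kernel processes only a subset of the input channels, the parameter count drops, while the number of output channels still equals the number of kernels.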
Image source: Huang G, Liu S, Van der Maaten L, et al. CondenseNet: An efficient DenseNet using learned group convolutions[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 2752-2761.
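The parameter counts above can be checked with plain arithmetic. In this Python sketch, conv_params is a hypothetical helper, bias terms are ignored, and the layer sizes are illustrative: 240 input and output channels, 3×3 kernels, and g = 3.

```python
def conv_params(in_channels, out_channels, k, groups=1):
    """Weight count of a k x k convolution split into `groups` groups (bias ignored)."""
    assert in_channels % groups == 0 and out_channels % groups == 0
    # Each of the out_channels kernels sees only in_channels / groups input channels
    return (in_channels // groups) * k * k * out_channels

standard = conv_params(240, 240, 3)           # ordinary convolution
grouped = conv_params(240, 240, 3, groups=3)  # group convolution, g = 3
print(standard, grouped, standard // grouped)  # 518400 172800 3
```

With g = 3, the group convolution needs exactly one third of the weights, which matches the 1/g ratio stated above.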
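Depthwise Convolution is the per-channel step of a depthwise separable convolution. It sets the number of groups g equal to the number of input channels, in_channels, and convolves each input channel separately, so each kernel processes exactly one channel. With a kernel of size 1 × k × k, the kernels have in_channels × k × k parameters in total, and the output feature maps have as many channels as the input.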
Pointwise Group Convolution builds on group convolution by setting each group's kernel size to $1\times 1$. Its kernels have (in_channels/g × 1 × 1) × out_channels parameters.
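The two special cases can be compared the same way. In this sketch, the helper names are hypothetical, the 240-channel layer is illustrative, and bias terms are ignored. It compares three layers:
- a 3×3 depthwise convolution
- a dense 1×1 convolution
- a 1×1 pointwise group convolution with g = 3

```python
def depthwise_params(in_channels, k):
    # groups == in_channels: one 1 x k x k kernel per input channel
    return in_channels * k * k

def pointwise_group_params(in_channels, out_channels, groups):
    # 1 x 1 kernels, each seeing in_channels / groups input channels
    assert in_channels % groups == 0 and out_channels % groups == 0
    return (in_channels // groups) * 1 * 1 * out_channels

c = 240
print(depthwise_params(c, 3))           # 2160
print(pointwise_group_params(c, c, 1))  # 57600 (dense 1x1, g = 1)
print(pointwise_group_params(c, c, 3))  # 19200 (g = 3)
```

The dense 1×1 layer dominates the cost here. This is why grouping the pointwise convolutions, rather than only the 3×3 layer, matters for a lightweight network.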
%%capture captured_output
# The environment comes with mindspore==2.2.14 preinstalled; to use another version, change the version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# Check the installed mindspore version
!pip show mindspore
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by:
from mindspore import nn
import mindspore.ops as ops
from mindspore import Tensor
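class GroupConv(nn.Cell):
    """Group convolution built from `groups` independent Conv2d layers.

    The input is split along the channel axis, each slice is convolved by its
    own Conv2d, and the results are concatenated back along the channel axis.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
        super(GroupConv, self).__init__()
        self.groups = groups
        self.convs = nn.CellList()
        for _ in range(groups):
            # Each group maps in_channels/groups -> out_channels/groups channels
            self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
                                        kernel_size=kernel_size, stride=stride, has_bias=has_bias,
                                        padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))

    def construct(self, x):
        # len(x[0]) is the channel count; split axis 1 into `groups` equal slices
        features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)
        outputs = ()
        for i in range(self.groups):
            outputs = outputs + (self.convs[i](features[i].astype("float32")),)
        # Concatenate the per-group outputs back along the channel axis
        out = ops.cat(outputs, axis=1)
        return out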
Channel Shuffle
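The drawback of group convolution is that channels in different groups cannot exchange information. Once GConv layers are stacked, the feature maps of different groups never communicate. It is as if the network were split into g unrelated roads, with everyone walking their own way, and this can weaken the network's ability to extract features. It is also why networks such as Xception and MobileNet use dense 1x1 convolutions (Dense Pointwise Convolution).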
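The isolation between groups is easy to see on a single pixel, because a $1\times 1$ group convolution only mixes channels within the same group. This plain-Python sketch uses no MindSpore. The sizes (g = 3 with 2 channels per group) and the random weights are arbitrary illustrative choices. Changing the input channels of one group changes only that group's outputs.

```python
import random

def pointwise_group_conv(x, weights, groups):
    """1x1 group convolution applied to one pixel's channel vector x.

    weights[g] is an (out_per_group x in_per_group) matrix used only by group g.
    """
    in_per_group = len(x) // groups
    out = []
    for g in range(groups):
        xs = x[g * in_per_group:(g + 1) * in_per_group]  # this group's inputs only
        for row in weights[g]:
            out.append(sum(w * v for w, v in zip(row, xs)))
    return out

random.seed(0)
groups, in_per_group, out_per_group = 3, 2, 2
weights = [[[random.uniform(-1, 1) for _ in range(in_per_group)]
            for _ in range(out_per_group)]
           for _ in range(groups)]

x = [1.0] * (groups * in_per_group)
y = pointwise_group_conv(x, weights, groups)

x2 = list(x)
x2[2:4] = [5.0, -3.0]  # perturb only the channels of input group 1
y2 = pointwise_group_conv(x2, weights, groups)

# Only output group 1 (indices 2-3) changes; groups 0 and 2 never see group 1
print([a == b for a, b in zip(y, y2)])
```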
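Channels that stay inside their own group suffer a kind of "inbreeding." ShuffleNet addresses this in two ways. First, it replaces the many expensive dense 1x1 convolutions; in ResNeXt, for example, these pointwise convolutions account for a striking 93.4% of the multiply-adds. Second, it introduces the Channel Shuffle mechanism. Intuitively, this operation spreads each group's channels evenly and regroups them, so the next layer can process information from channels of different groups.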
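As shown in the figure below, channel shuffle works on a feature map with g groups of n channels each in three steps:
1. Reshape the channel dimension into a matrix with g rows and n columns.
2. Transpose it into n rows and g columns.
3. Flatten it to obtain the new channel order.

These operations are differentiable and cheap to compute. They enable information exchange across groups while keeping ShuffleNet lightweight.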
For readability, the Channel Shuffle implementation is included in the ShuffleNet module code below.
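The reshape-transpose-flatten steps can be traced on channel indices alone. This plain-Python sketch uses 12 channels in g = 3 groups (n = 4), and its helper names are hypothetical. It also reproduces the order produced by channel_shuffle in ShuffleV1Block, which reshapes to (n, g) and transposes to (g, n). That order is the inverse permutation of the one described above, and both orders interleave channels from different groups.

```python
def channel_shuffle_indices(num_channels, groups):
    """Order described in the text: reshape to (g, n), transpose to (n, g), flatten."""
    n = num_channels // groups
    matrix = [[g * n + i for i in range(n)] for g in range(groups)]         # g rows, n columns
    transposed = [[matrix[g][i] for g in range(groups)] for i in range(n)]  # n rows, g columns
    return [c for row in transposed for c in row]

def tutorial_shuffle_indices(num_channels, groups):
    """Order used by ShuffleV1Block.channel_shuffle: reshape to (n, g), transpose to (g, n)."""
    n = num_channels // groups
    matrix = [[i * groups + g for g in range(groups)] for i in range(n)]
    transposed = [[matrix[i][g] for i in range(n)] for g in range(groups)]
    return [c for row in transposed for c in row]

paper = channel_shuffle_indices(12, 3)
tutorial = tutorial_shuffle_indices(12, 3)
print(paper)     # [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
print(tutorial)  # [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11]
# Applying one order after the other restores the original channel order
print(all(paper[tutorial[p]] == p for p in range(12)))  # True
# Each new group of 4 channels draws from every original group (channel // 4)
print([sorted({c // 4 for c in paper[k * 4:(k + 1) * 4]}) for k in range(3)])
```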
ShuffleNet Module
As shown in the figure below, ShuffleNet modifies the Bottleneck structure of ResNet from (a) into (b) and (c):
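- The first and last $1\times 1$ convolutions (channel reduction and expansion) are replaced with Pointwise Group Convolution.
- To exchange information between channels, Channel Shuffle is applied after the channel reduction.
- In the downsampling unit, the stride of the $3\times 3$ Depthwise Convolution is set to 2, halving the height and width. The shortcut therefore uses a $3\times 3$ average pooling with stride 2, and element-wise addition is replaced by concatenation.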
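The channel and size bookkeeping of the two unit types can be checked with plain arithmetic. This sketch uses a hypothetical helper and no MindSpore, and it mirrors what the ShuffleV1Block code below does. For stride 2, the main branch produces only oup - inp channels, so concatenating the pooled shortcut (inp channels) restores oup. The example sizes (24 to 240 channels at 56×56) correspond to the first unit of Stage2 for g = 3 and model size '1.0x'.

```python
def shuffle_unit_output(in_channels, out_channels, height, width, stride):
    """Output (channels, height, width) of a ShuffleNet unit.

    stride 1: the main branch keeps in_channels and is added to the identity shortcut.
    stride 2: the main branch produces out_channels - in_channels, concatenated with a
    3x3 / stride-2 average-pooled shortcut that keeps in_channels.
    """
    if stride == 1:
        assert in_channels == out_channels, "addition needs matching channels"
        return out_channels, height, width
    main = out_channels - in_channels  # channels from the grouped 1x1 expansion
    shortcut = in_channels             # average pooling does not change channels
    # Stride 2 with 'same'-style padding gives ceil(size / 2)
    return main + shortcut, -(-height // 2), -(-width // 2)

print(shuffle_unit_output(24, 240, 56, 56, stride=2))   # (240, 28, 28)
print(shuffle_unit_output(240, 240, 28, 28, stride=1))  # (240, 28, 28)
```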
class ShuffleV1Block(nn.Cell):
    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        self.stride = stride
        pad = ksize // 2
        self.group = group
        # With stride 2 the shortcut is concatenated, so the main branch only
        # produces the extra oup - inp channels
        if stride == 2:
            outputs = oup - inp
        else:
            outputs = oup
        self.relu = nn.ReLU()
        # 1x1 pointwise group conv (channel reduction) + BN + ReLU
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        # 3x3 depthwise conv + BN, then 1x1 pointwise group conv (expansion) + BN
        branch_main_2 = [
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            # Shortcut of the downsampling unit: 3x3 average pooling with stride 2
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        left = old_x
        right = old_x
        out = old_x
        right = self.branch_main_1(right)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            # Identity shortcut: element-wise addition
            out = self.relu(left + right)
        elif self.stride == 2:
            # Downsampling: pool the shortcut and concatenate along channels
            left = self.branch_proj(left)
            out = ops.cat((left, right), 1)
            out = self.relu(out)
        return out

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = ops.shape(x)
        group_channels = num_channels // self.group
        # Reshape channels to (group_channels, group), swap the two axes, and flatten.
        # This is the inverse permutation of the (g, n) -> (n, g) order described in
        # the text; both orders interleave channels from different groups.
        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (batchsize, num_channels, height, width))
        return x
Building the ShuffleNet Network
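The ShuffleNet architecture is shown in the figure below. Take a $224\times 224$ input image and 3 groups (g = 3) as an example:
1. A convolution layer with 24 kernels of size $3\times 3$ and stride 2 produces a $112\times 112$ feature map with 24 channels.
2. A max-pooling layer with stride 2 reduces it to $56\times 56$, with the channel count unchanged.
3. Three stages of ShuffleNet units (Stage2, Stage3, Stage4) follow, with 4, 8 and 4 units respectively. Each stage begins with a downsampling unit ((c) in the ShuffleNet module figure above), which halves the height and width and doubles the number of channels. The exception is the downsampling unit of Stage2, which changes the channel count from 24 to 240.
4. Global average pooling produces a $1\times 1\times 960$ output.
5. A fully connected layer followed by softmax gives the class probabilities. In the code below, the classifier outputs logits, and softmax is applied inside the cross-entropy loss during training.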
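The shape bookkeeping above can be traced with a short plain-Python sketch, using a hypothetical helper. The channel widths (24, 240, 480, 960) are the stage_out_channels entries for group=3 and model_size='1.0x' in the ShuffleNetV1 code below. The sketch assumes every stride-2 layer rounds the spatial size up, which holds here: the first conv uses padding 1, and the pooling and depthwise layers use 'same'-style padding.

```python
import math

def shufflenet_v1_shapes(image_size=224, stage_channels=(24, 240, 480, 960), repeats=(4, 8, 4)):
    """Trace (layer name, channels, spatial size) through ShuffleNetV1."""
    size = math.ceil(image_size / 2)  # 3x3 conv, stride 2, padding 1
    trace = [("conv1", stage_channels[0], size)]
    size = math.ceil(size / 2)        # 3x3 max pool, stride 2, 'same'
    trace.append(("maxpool", stage_channels[0], size))
    for idx, (channels, n) in enumerate(zip(stage_channels[1:], repeats)):
        size = math.ceil(size / 2)    # the first unit of each stage has stride 2
        trace.append((f"stage{idx + 2} (x{n})", channels, size))
    trace.append(("globalpool", stage_channels[-1], 1))
    return trace

for name, c, s in shufflenet_v1_shapes():
    print(f"{name:16s} {s}x{s}x{c}")
```

The final stage ends at 7×7, which is why the code can use nn.AvgPool2d(7) as a global pool for 224×224 inputs. Passing (48, 480, 960, 1920) traces the '2.0x' variant used for training.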
class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        # Number of ShuffleV1Block units in Stage2, Stage3 and Stage4
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        # Index 1: output channels of the first conv; indices 2-4: Stage2-Stage4
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        else:
            # Only g = 3 and g = 8 channel configurations are defined
            raise NotImplementedError
        input_channel = self.stage_out_channels[1]
        # First layer: 3x3 conv with stride 2, followed by BN and ReLU
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                # The first unit of every stage downsamples (stride 2)
                stride = 2 if i == 0 else 1
                # The very first unit uses an ungrouped first 1x1 conv
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        # A 224x224 input yields a 7x7 final feature map, so a 7x7 average pool acts as global pooling
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x
Model Training and Evaluation
ShuffleNet is pretrained on the CIFAR-10 dataset.
Preparing and Loading the Training Set
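CIFAR-10 contains 60,000 color images of size 32×32, evenly divided into 10 classes, with 50,000 images in the training set and 10,000 in the test set. The following example downloads the dataset and then uses the mindspore.dataset.Cifar10Dataset interface to load the CIFAR-10 training set. Only the CIFAR-10 binary version is currently supported.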
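For reference, the binary version stores each image as a fixed-size record:
- 1 label byte
- 1024 red pixel bytes
- 1024 green pixel bytes
- 1024 blue pixel bytes

Each colour plane is a 32×32 image in row-major order. Cifar10Dataset does this parsing internally. The plain-Python sketch below only illustrates the layout using a synthetic record, and its helper names are hypothetical.

```python
PLANE_BYTES = 32 * 32                # one 32x32 colour plane
RECORD_BYTES = 1 + 3 * PLANE_BYTES   # 1 label byte + 3072 pixel bytes = 3073

def parse_cifar10_record(record: bytes):
    """Split one CIFAR-10 binary record into (label, [red, green, blue])."""
    if len(record) != RECORD_BYTES:
        raise ValueError(f"expected {RECORD_BYTES} bytes, got {len(record)}")
    label = record[0]  # class index 0-9
    planes = [record[1 + c * PLANE_BYTES: 1 + (c + 1) * PLANE_BYTES] for c in range(3)]
    return label, planes

def pixel(planes, row, col):
    """(R, G, B) value of one pixel from the row-major planes."""
    idx = row * 32 + col
    return tuple(p[idx] for p in planes)

# Synthetic record for illustration: label 7, flat colour (10, 20, 30)
fake = bytes([7]) + bytes([10]) * PLANE_BYTES + bytes([20]) * PLANE_BYTES + bytes([30]) * PLANE_BYTES
label, planes = parse_cifar10_record(fake)
print(label, pixel(planes, 5, 9))  # 7 (10, 20, 30)
```

To inspect a real record, you can open an extracted file such as ./dataset/cifar-10-batches-bin/data_batch_1.bin in binary mode and pass f.read(RECORD_BYTES) to the helper; each training batch file holds 10,000 records.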
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"
download(url, "./dataset", kind="tar.gz", replace=True)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz (162.2 MB)
file_sizes: 100%|████████████████████████████| 170M/170M [00:01<00:00, 93.7MB/s]
Extracting tar.gz file...
Successfully downloaded / unzipped to ./dataset
'./dataset'
import mindspore as ms
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transforms
def get_dataset(train_dataset_path, batch_size, usage):
    image_trans = []
    if usage == "train":
        image_trans = [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    elif usage == "test":
        image_trans = [
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    label_trans = transforms.TypeCast(ms.int32)
    dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True)
    dataset = dataset.map(image_trans, 'image')
    dataset = dataset.map(label_trans, 'label')
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset
dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "train")
batches_per_epoch = dataset.get_dataset_size()
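Because get_dataset uses drop_remainder=True, batches_per_epoch is the floor of 50,000 / 128. The arithmetic below explains why the training log counts 390 steps per epoch. It also gives the total step count that the training code passes to cosine_decay_lr for 250 epochs.

```python
train_images = 50000  # CIFAR-10 training images
batch_size = 128      # batch size passed to get_dataset

# drop_remainder=True discards the final partial batch
steps_per_epoch = train_images // batch_size
dropped = train_images - steps_per_epoch * batch_size
print(steps_per_epoch, dropped)  # 390 80
print(steps_per_epoch * 250)     # 97500
```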
Model Training
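This section pretrains the network from randomly initialized parameters. The setup is:
- Network: ShuffleNetV1 with model size "2.0x".
- Loss: cross-entropy with label smoothing 0.1.
- Learning rate: the intended schedule is a 4-epoch warmup followed by cosine annealing; the code builds the cosine part with nn.cosine_decay_lr.
- Optimizer: Momentum, with mixed precision (amp_level="O3") and a fixed loss scale of 1024.

The Model interface from mindspore.train wraps the network, loss function and optimizer into model, and model.train() trains the network. ModelCheckpoint, CheckpointConfig, TimeMonitor and LossMonitor are passed as callbacks. They print the epoch, step loss and time during training and save ckpt files in the current directory.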
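nn.cosine_decay_lr provides only the cosine part of the schedule, not the warmup. The sketch below builds a per-step schedule with a warmup in plain Python. warmup_cosine_lr is a hypothetical helper, and the linear warmup shape is an assumption because the text does not specify it. The other values (base_lr 0.05, min_lr 0.0005, 390 steps per epoch, 250 epochs) come from the training code. MindSpore optimizers accept a list of per-step values as a dynamic learning rate, so a list like this could be passed to nn.Momentum as learning_rate.

```python
import math

def warmup_cosine_lr(base_lr, min_lr, steps_per_epoch, warmup_epochs, total_epochs):
    """Per-step learning rates: linear warmup to base_lr, then cosine annealing to min_lr."""
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    lrs = []
    for step in range(total_steps):
        if step < warmup_steps:
            # Linear ramp that reaches base_lr at the last warmup step
            lrs.append(base_lr * (step + 1) / warmup_steps)
        else:
            progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
            lrs.append(min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress)))
    return lrs

lrs = warmup_cosine_lr(base_lr=0.05, min_lr=0.0005, steps_per_epoch=390,
                       warmup_epochs=4, total_epochs=250)
print(len(lrs), lrs[0], max(lrs), lrs[-1])
```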
import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy


def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    min_lr = 0.0005
    base_lr = 0.05
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                                base_lr,
                                                batches_per_epoch*250,
                                                batches_per_epoch,
                                                decay_epoch=250)
    # Note: cosine_decay_lr has no warmup phase, and [-1] keeps only the last value of
    # the schedule (close to min_lr), so the optimizer below uses a constant learning
    # rate. Passing lr_scheduler itself would apply the full per-step cosine decay.
    lr = Tensor(lr_scheduler[-1])
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
    # Fixed loss scaling for mixed-precision (O3) training; with drop_overflow_update=False
    # it must match the optimizer's loss_scale
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
    callback = [TimeMonitor(), LossMonitor()]
    save_ckpt_path = "./"
    # Save a checkpoint once per epoch and keep at most the 5 most recent ones
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
    callback += [ckpt_callback]
    print("============== Starting Training ==============")
    start_time = time.time()
    # Only 5 epochs are run here to save time; adjust as needed
    model.train(5, dataset, callbacks=callback)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    print("total time:" + hour + "h " + minute + "m " + second + "s")
    print("============== Train Success ==============")


if __name__ == '__main__':
    train()
model size is 2.0x
============== Starting Training ==============
epoch: 1 step: 1, loss is 2.637444257736206
epoch: 1 step: 2, loss is 2.5519354343414307
epoch: 1 step: 3, loss is 2.417301893234253
epoch: 1 step: 4, loss is 2.41119122505188
epoch: 1 step: 5, loss is 2.4165942668914795
epoch: 1 step: 6, loss is 2.3340768814086914
epoch: 1 step: 7, loss is 2.5514841079711914
epoch: 1 step: 8, loss is 2.5256190299987793
epoch: 1 step: 9, loss is 2.32163405418396
epoch: 1 step: 10, loss is 2.248215913772583
epoch: 1 step: 11, loss is 2.3002939224243164
epoch: 1 step: 12, loss is 2.3838281631469727
epoch: 1 step: 13, loss is 2.396728515625
epoch: 1 step: 14, loss is 2.373441457748413
epoch: 1 step: 15, loss is 2.338778495788574
epoch: 1 step: 16, loss is 2.266740322113037
epoch: 1 step: 17, loss is 2.257406234741211
epoch: 1 step: 18, loss is 2.2662010192871094
epoch: 1 step: 19, loss is 2.285461187362671
epoch: 1 step: 20, loss is 2.370849370956421
epoch: 1 step: 21, loss is 2.2432849407196045
epoch: 1 step: 22, loss is 2.2536284923553467
epoch: 1 step: 23, loss is 2.2301814556121826
epoch: 1 step: 24, loss is 2.210026264190674
epoch: 1 step: 25, loss is 2.1990232467651367
epoch: 1 step: 26, loss is 2.332028865814209
epoch: 1 step: 27, loss is 2.3111772537231445
epoch: 1 step: 28, loss is 2.2818078994750977
epoch: 1 step: 29, loss is 2.233259677886963
epoch: 1 step: 30, loss is 2.286853075027466
epoch: 1 step: 31, loss is 2.259453773498535
epoch: 1 step: 32, loss is 2.252955675125122
epoch: 1 step: 33, loss is 2.2320759296417236
epoch: 1 step: 34, loss is 2.271742105484009
epoch: 1 step: 35, loss is 2.17651629447937
epoch: 1 step: 36, loss is 2.1235527992248535
epoch: 1 step: 37, loss is 2.1571192741394043
epoch: 1 step: 38, loss is 2.2133190631866455
epoch: 1 step: 39, loss is 2.309330940246582
epoch: 1 step: 40, loss is 2.1286849975585938
epoch: 1 step: 41, loss is 2.1421730518341064
epoch: 1 step: 42, loss is 2.2582814693450928
epoch: 1 step: 43, loss is 2.2963032722473145
epoch: 1 step: 44, loss is 2.141162633895874
epoch: 1 step: 45, loss is 2.18499755859375
epoch: 1 step: 46, loss is 2.13185977935791
epoch: 1 step: 47, loss is 2.1182143688201904
epoch: 1 step: 48, loss is 2.1624720096588135
epoch: 1 step: 49, loss is 2.122936248779297
epoch: 1 step: 50, loss is 2.094850540161133
epoch: 1 step: 51, loss is 2.147307872772217
epoch: 1 step: 52, loss is 2.159871816635132
epoch: 1 step: 53, loss is 2.2226274013519287
epoch: 1 step: 54, loss is 2.1584930419921875
epoch: 1 step: 55, loss is 2.101642370223999
epoch: 1 step: 56, loss is 2.062920093536377
epoch: 1 step: 57, loss is 2.146536350250244
epoch: 1 step: 58, loss is 2.1339383125305176
epoch: 1 step: 59, loss is 2.174254894256592
epoch: 1 step: 60, loss is 2.0403919219970703
epoch: 1 step: 61, loss is 2.1350648403167725
epoch: 1 step: 62, loss is 2.0602900981903076
epoch: 1 step: 63, loss is 2.0428242683410645
epoch: 1 step: 64, loss is 2.1254398822784424
epoch: 1 step: 65, loss is 2.1450183391571045
epoch: 1 step: 66, loss is 2.128354787826538
epoch: 1 step: 67, loss is 2.1080520153045654
epoch: 1 step: 68, loss is 2.110933542251587
epoch: 1 step: 69, loss is 2.163707733154297
epoch: 1 step: 70, loss is 2.171631097793579
epoch: 1 step: 71, loss is 2.0731008052825928
epoch: 1 step: 72, loss is 2.1859469413757324
epoch: 1 step: 73, loss is 1.951612949371338
epoch: 1 step: 74, loss is 2.036221981048584
epoch: 1 step: 75, loss is 2.0642924308776855
epoch: 1 step: 76, loss is 2.019010543823242
epoch: 1 step: 77, loss is 2.1194803714752197
epoch: 1 step: 78, loss is 2.0619945526123047
epoch: 1 step: 79, loss is 1.9623534679412842
epoch: 1 step: 80, loss is 2.0492002964019775
epoch: 1 step: 81, loss is 2.0460927486419678
epoch: 1 step: 82, loss is 2.0349040031433105
epoch: 1 step: 83, loss is 2.1013803482055664
epoch: 1 step: 84, loss is 1.9986183643341064
epoch: 1 step: 85, loss is 2.0838303565979004
epoch: 1 step: 86, loss is 2.019192934036255
epoch: 1 step: 87, loss is 2.0144875049591064
epoch: 1 step: 88, loss is 2.000406265258789
epoch: 1 step: 89, loss is 2.110612630844116
epoch: 1 step: 90, loss is 2.058392286300659
epoch: 1 step: 91, loss is 2.0305466651916504
epoch: 1 step: 92, loss is 1.9904029369354248
epoch: 1 step: 93, loss is 2.030336856842041
epoch: 1 step: 94, loss is 1.971001386642456
epoch: 1 step: 95, loss is 2.0079398155212402
epoch: 1 step: 96, loss is 2.030968427658081
epoch: 1 step: 97, loss is 1.9490876197814941
epoch: 1 step: 98, loss is 1.9908173084259033
epoch: 1 step: 99, loss is 2.0175390243530273
epoch: 1 step: 100, loss is 2.003598690032959
epoch: 1 step: 101, loss is 2.0132622718811035
epoch: 1 step: 102, loss is 2.067610263824463
epoch: 1 step: 103, loss is 1.9815086126327515
epoch: 1 step: 104, loss is 1.9610518217086792
epoch: 1 step: 105, loss is 2.0739641189575195
epoch: 1 step: 106, loss is 2.082470655441284
epoch: 1 step: 107, loss is 2.033511161804199
epoch: 1 step: 108, loss is 2.0472147464752197
epoch: 1 step: 109, loss is 2.0449018478393555
epoch: 1 step: 110, loss is 2.0509555339813232
epoch: 1 step: 111, loss is 2.0655226707458496
epoch: 1 step: 112, loss is 1.9426164627075195
epoch: 1 step: 113, loss is 2.052448272705078
epoch: 1 step: 114, loss is 2.0476441383361816
epoch: 1 step: 115, loss is 1.9755520820617676
epoch: 1 step: 116, loss is 1.9978209733963013
epoch: 1 step: 117, loss is 2.0231034755706787
epoch: 1 step: 118, loss is 1.9953503608703613
epoch: 1 step: 119, loss is 2.026078462600708
epoch: 1 step: 120, loss is 1.9307632446289062
epoch: 1 step: 121, loss is 2.0748066902160645
epoch: 1 step: 122, loss is 2.011425018310547
epoch: 1 step: 123, loss is 1.952304482460022
epoch: 1 step: 124, loss is 2.1205146312713623
epoch: 1 step: 125, loss is 2.0550355911254883
epoch: 1 step: 126, loss is 2.0679333209991455
epoch: 1 step: 127, loss is 2.0083532333374023
epoch: 1 step: 128, loss is 2.0059330463409424
epoch: 1 step: 129, loss is 2.0808990001678467
epoch: 1 step: 130, loss is 2.0757577419281006
epoch: 1 step: 131, loss is 2.000009536743164
epoch: 1 step: 132, loss is 2.0516088008880615
epoch: 1 step: 133, loss is 2.1288304328918457
epoch: 1 step: 134, loss is 1.980472445487976
epoch: 1 step: 135, loss is 1.9199179410934448
epoch: 1 step: 136, loss is 2.0529212951660156
epoch: 1 step: 137, loss is 1.9572110176086426
epoch: 1 step: 138, loss is 2.0123071670532227
epoch: 1 step: 139, loss is 1.859535813331604
epoch: 1 step: 140, loss is 1.9985237121582031
epoch: 1 step: 141, loss is 1.9810898303985596
epoch: 1 step: 142, loss is 1.9665652513504028
epoch: 1 step: 143, loss is 1.9929875135421753
epoch: 1 step: 144, loss is 1.931304693222046
epoch: 1 step: 145, loss is 2.0240116119384766
epoch: 1 step: 146, loss is 1.9483182430267334
epoch: 1 step: 147, loss is 1.9971871376037598
epoch: 1 step: 148, loss is 2.077157497406006
epoch: 1 step: 149, loss is 1.9434341192245483
epoch: 1 step: 150, loss is 1.9218116998672485
epoch: 1 step: 151, loss is 1.9227429628372192
epoch: 1 step: 152, loss is 2.0026254653930664
epoch: 1 step: 153, loss is 1.9405245780944824
epoch: 1 step: 154, loss is 1.9563802480697632
epoch: 1 step: 155, loss is 1.9225423336029053
epoch: 1 step: 156, loss is 2.1118626594543457
epoch: 1 step: 157, loss is 2.084660530090332
epoch: 1 step: 158, loss is 2.031627655029297
epoch: 1 step: 159, loss is 2.022655725479126
epoch: 1 step: 160, loss is 1.8749659061431885
epoch: 1 step: 161, loss is 1.9071351289749146
epoch: 1 step: 162, loss is 2.0629053115844727
epoch: 1 step: 163, loss is 1.9555968046188354
epoch: 1 step: 164, loss is 1.9368162155151367
epoch: 1 step: 165, loss is 2.057159423828125
epoch: 1 step: 166, loss is 2.006042003631592
epoch: 1 step: 167, loss is 1.939833641052246
epoch: 1 step: 168, loss is 2.0822463035583496
epoch: 1 step: 169, loss is 1.9786700010299683
epoch: 1 step: 170, loss is 2.0766124725341797
epoch: 1 step: 171, loss is 1.9605066776275635
epoch: 1 step: 172, loss is 2.059980630874634
epoch: 1 step: 173, loss is 1.9184414148330688
epoch: 1 step: 174, loss is 1.9574657678604126
epoch: 1 step: 175, loss is 2.1406052112579346
epoch: 1 step: 176, loss is 1.9175764322280884
epoch: 1 step: 177, loss is 2.095923900604248
epoch: 1 step: 178, loss is 1.9527654647827148
epoch: 1 step: 179, loss is 1.9658640623092651
epoch: 1 step: 180, loss is 2.0261096954345703
epoch: 1 step: 181, loss is 1.9398032426834106
epoch: 1 step: 182, loss is 1.924546480178833
epoch: 1 step: 183, loss is 2.087576389312744
epoch: 1 step: 184, loss is 2.0642359256744385
epoch: 1 step: 185, loss is 2.016495943069458
epoch: 1 step: 186, loss is 1.9951825141906738
epoch: 1 step: 187, loss is 1.964837908744812
epoch: 1 step: 188, loss is 2.0414724349975586
epoch: 1 step: 189, loss is 1.8743633031845093
epoch: 1 step: 190, loss is 1.9245572090148926
epoch: 1 step: 191, loss is 2.0721426010131836
epoch: 1 step: 192, loss is 2.003959894180298
epoch: 1 step: 193, loss is 1.9853498935699463
epoch: 1 step: 194, loss is 1.9460020065307617
epoch: 1 step: 195, loss is 1.9167520999908447
epoch: 1 step: 196, loss is 2.0020596981048584
epoch: 1 step: 197, loss is 2.049525499343872
epoch: 1 step: 198, loss is 2.053671360015869
epoch: 1 step: 199, loss is 1.8961541652679443
epoch: 1 step: 200, loss is 1.9797005653381348
epoch: 1 step: 201, loss is 1.9717200994491577
epoch: 1 step: 202, loss is 1.9439831972122192
epoch: 1 step: 203, loss is 2.043377161026001
epoch: 1 step: 204, loss is 2.003532886505127
epoch: 1 step: 205, loss is 1.9102176427841187
epoch: 1 step: 206, loss is 2.00426983833313
epoch: 1 step: 207, loss is 1.8981976509094238
epoch: 1 step: 208, loss is 1.9090827703475952
epoch: 1 step: 209, loss is 1.924512267112732
epoch: 1 step: 210, loss is 1.9105455875396729
epoch: 1 step: 211, loss is 1.921006441116333
epoch: 1 step: 212, loss is 1.9841300249099731
epoch: 1 step: 213, loss is 2.0343875885009766
epoch: 1 step: 214, loss is 1.987828254699707
epoch: 1 step: 215, loss is 1.9013984203338623
epoch: 1 step: 216, loss is 1.8945801258087158
epoch: 1 step: 217, loss is 1.941464900970459
epoch: 1 step: 218, loss is 1.9070900678634644
epoch: 1 step: 219, loss is 1.9665577411651611
epoch: 1 step: 220, loss is 1.888040542602539
epoch: 1 step: 221, loss is 1.9154324531555176
epoch: 1 step: 222, loss is 2.00649094581604
epoch: 1 step: 223, loss is 2.015157699584961
epoch: 1 step: 224, loss is 1.978245735168457
epoch: 1 step: 225, loss is 1.8778069019317627
epoch: 1 step: 226, loss is 2.0120866298675537
epoch: 1 step: 227, loss is 1.9353729486465454
epoch: 1 step: 228, loss is 1.9539164304733276
epoch: 1 step: 229, loss is 1.9758312702178955
epoch: 1 step: 230, loss is 1.9771440029144287
epoch: 1 step: 231, loss is 1.9934016466140747
epoch: 1 step: 232, loss is 1.9074777364730835
epoch: 1 step: 233, loss is 1.957679033279419
epoch: 1 step: 234, loss is 1.9348262548446655
epoch: 1 step: 235, loss is 1.891913652420044
epoch: 1 step: 236, loss is 1.9891177415847778
epoch: 1 step: 237, loss is 1.9034689664840698
epoch: 1 step: 238, loss is 1.8976025581359863
epoch: 1 step: 239, loss is 2.0854296684265137
epoch: 1 step: 240, loss is 1.9493427276611328
epoch: 1 step: 241, loss is 1.8900543451309204
epoch: 1 step: 242, loss is 1.963498830795288
epoch: 1 step: 243, loss is 1.8827524185180664
epoch: 1 step: 244, loss is 1.9306913614273071
epoch: 1 step: 245, loss is 1.8937281370162964
epoch: 1 step: 246, loss is 1.932215929031372
epoch: 1 step: 247, loss is 1.9226175546646118
epoch: 1 step: 248, loss is 2.104743003845215
epoch: 1 step: 249, loss is 1.9672844409942627
epoch: 1 step: 250, loss is 1.9009824991226196
epoch: 1 step: 251, loss is 1.9491523504257202
epoch: 1 step: 252, loss is 2.0299384593963623
epoch: 1 step: 253, loss is 1.9272692203521729
epoch: 1 step: 254, loss is 1.9573429822921753
epoch: 1 step: 255, loss is 1.9414104223251343
epoch: 1 step: 256, loss is 1.9749751091003418
epoch: 1 step: 257, loss is 1.9846627712249756
epoch: 1 step: 258, loss is 1.9152398109436035
epoch: 1 step: 259, loss is 1.9870069026947021
epoch: 1 step: 260, loss is 1.8829617500305176
epoch: 1 step: 261, loss is 2.0034732818603516
epoch: 1 step: 262, loss is 1.8597939014434814
epoch: 1 step: 263, loss is 1.9629346132278442
epoch: 1 step: 264, loss is 1.8727396726608276
epoch: 1 step: 265, loss is 1.8768718242645264
epoch: 1 step: 266, loss is 2.029008388519287
epoch: 1 step: 267, loss is 2.0169904232025146
epoch: 1 step: 268, loss is 1.9506852626800537
epoch: 1 step: 269, loss is 1.909484624862671
epoch: 1 step: 270, loss is 2.0111072063446045
epoch: 1 step: 271, loss is 2.0693612098693848
epoch: 1 step: 272, loss is 1.9460484981536865
epoch: 1 step: 273, loss is 1.8770350217819214
epoch: 1 step: 274, loss is 1.9787135124206543
epoch: 1 step: 275, loss is 1.9805867671966553
epoch: 1 step: 276, loss is 2.007831573486328
epoch: 1 step: 277, loss is 1.983500599861145
epoch: 1 step: 278, loss is 1.9957048892974854
epoch: 1 step: 279, loss is 2.001767158508301
epoch: 1 step: 280, loss is 1.9380371570587158
epoch: 1 step: 281, loss is 2.016770839691162
epoch: 1 step: 282, loss is 2.0466959476470947
epoch: 1 step: 283, loss is 1.9340555667877197
epoch: 1 step: 284, loss is 1.9068975448608398
epoch: 1 step: 285, loss is 1.9217978715896606
epoch: 1 step: 286, loss is 1.9618436098098755
epoch: 1 step: 287, loss is 1.9785856008529663
epoch: 1 step: 288, loss is 1.9123224020004272
epoch: 1 step: 289, loss is 1.8668341636657715
epoch: 1 step: 290, loss is 2.049499750137329
epoch: 1 step: 291, loss is 1.8257873058319092
epoch: 1 step: 292, loss is 1.901406168937683
epoch: 1 step: 293, loss is 1.8525774478912354
epoch: 1 step: 294, loss is 1.9189528226852417
epoch: 1 step: 295, loss is 1.8322581052780151
epoch: 1 step: 296, loss is 1.971509337425232
epoch: 1 step: 297, loss is 1.950898289680481
epoch: 1 step: 298, loss is 1.8920232057571411
epoch: 1 step: 299, loss is 2.0213394165039062
epoch: 1 step: 300, loss is 1.8809759616851807
epoch: 1 step: 301, loss is 1.9179118871688843
epoch: 1 step: 302, loss is 1.8198785781860352
epoch: 1 step: 303, loss is 1.8565402030944824
epoch: 1 step: 304, loss is 1.8512489795684814
epoch: 1 step: 305, loss is 1.9837815761566162
epoch: 1 step: 306, loss is 1.8517332077026367
epoch: 1 step: 307, loss is 1.8945205211639404
epoch: 1 step: 308, loss is 1.9475098848342896
epoch: 1 step: 309, loss is 1.9193038940429688
epoch: 1 step: 310, loss is 1.8788681030273438
epoch: 1 step: 311, loss is 1.869295358657837
epoch: 1 step: 312, loss is 1.914243459701538
epoch: 1 step: 313, loss is 1.8895680904388428
epoch: 1 step: 314, loss is 1.8317034244537354
epoch: 1 step: 315, loss is 1.9353530406951904
epoch: 1 step: 316, loss is 1.8667919635772705
epoch: 1 step: 317, loss is 1.8767848014831543
epoch: 1 step: 318, loss is 1.8771603107452393
epoch: 1 step: 319, loss is 1.9101288318634033
epoch: 1 step: 320, loss is 1.9494779109954834
epoch: 1 step: 321, loss is 1.979146122932434
epoch: 1 step: 322, loss is 1.8751119375228882
epoch: 1 step: 323, loss is 1.8926434516906738
epoch: 1 step: 324, loss is 1.8737379312515259
epoch: 1 step: 325, loss is 1.91633939743042
epoch: 1 step: 326, loss is 1.892572283744812
epoch: 1 step: 327, loss is 1.8479312658309937
epoch: 1 step: 328, loss is 1.9368888139724731
epoch: 1 step: 329, loss is 1.985150694847107
epoch: 1 step: 330, loss is 1.8574414253234863
epoch: 1 step: 331, loss is 2.007861614227295
epoch: 1 step: 332, loss is 1.8954650163650513
epoch: 1 step: 333, loss is 1.7576824426651
epoch: 1 step: 334, loss is 1.9236505031585693
epoch: 1 step: 335, loss is 1.8303214311599731
epoch: 1 step: 336, loss is 1.9429856538772583
epoch: 1 step: 337, loss is 1.935878872871399
epoch: 1 step: 338, loss is 1.8800196647644043
epoch: 1 step: 339, loss is 1.8589588403701782
epoch: 1 step: 340, loss is 1.8783851861953735
epoch: 1 step: 341, loss is 1.904998779296875
epoch: 1 step: 342, loss is 1.8523179292678833
epoch: 1 step: 343, loss is 1.8451892137527466
epoch: 1 step: 344, loss is 1.8797779083251953
epoch: 1 step: 345, loss is 1.8365544080734253
epoch: 1 step: 346, loss is 1.9472622871398926
epoch: 1 step: 347, loss is 1.783460259437561
epoch: 1 step: 348, loss is 1.9557878971099854
epoch: 1 step: 349, loss is 1.9549705982208252
epoch: 1 step: 350, loss is 1.910333275794983
epoch: 1 step: 351, loss is 2.056870698928833
epoch: 1 step: 352, loss is 1.8559565544128418
epoch: 1 step: 353, loss is 1.849716067314148
epoch: 1 step: 354, loss is 1.8249409198760986
epoch: 1 step: 355, loss is 1.8520405292510986
epoch: 1 step: 356, loss is 1.923189640045166
epoch: 1 step: 357, loss is 1.8970928192138672
epoch: 1 step: 358, loss is 2.0142710208892822
epoch: 1 step: 359, loss is 1.8832510709762573
epoch: 1 step: 360, loss is 1.817199468612671
epoch: 1 step: 361, loss is 1.988936185836792
epoch: 1 step: 362, loss is 1.8172954320907593
epoch: 1 step: 363, loss is 1.8937242031097412
epoch: 1 step: 364, loss is 1.9423609972000122
epoch: 1 step: 365, loss is 1.8992043733596802
epoch: 1 step: 366, loss is 1.8474411964416504
epoch: 1 step: 367, loss is 1.8042991161346436
epoch: 1 step: 368, loss is 1.8353846073150635
epoch: 1 step: 369, loss is 1.8351342678070068
epoch: 1 step: 370, loss is 1.8681347370147705
epoch: 1 step: 371, loss is 1.8676278591156006
epoch: 1 step: 372, loss is 1.8536510467529297
epoch: 1 step: 373, loss is 1.8018980026245117
epoch: 1 step: 374, loss is 1.8365895748138428
epoch: 1 step: 375, loss is 1.9459197521209717
epoch: 1 step: 376, loss is 1.8942005634307861
epoch: 1 step: 377, loss is 1.8253778219223022
epoch: 1 step: 378, loss is 1.862713098526001
epoch: 1 step: 379, loss is 1.926077127456665
epoch: 1 step: 380, loss is 1.9486587047576904
epoch: 1 step: 381, loss is 1.8932850360870361
epoch: 1 step: 382, loss is 1.8795984983444214
epoch: 1 step: 383, loss is 1.7956209182739258
epoch: 1 step: 384, loss is 1.9587161540985107
epoch: 1 step: 385, loss is 1.9208011627197266
epoch: 1 step: 386, loss is 1.9988421201705933
epoch: 1 step: 387, loss is 1.892421007156372
epoch: 1 step: 388, loss is 1.7995481491088867
epoch: 1 step: 389, loss is 1.9330435991287231
epoch: 1 step: 390, loss is 1.897336721420288
epoch: 2 step: 165, loss is 1.8130674362182617
epoch: 2 step: 166, loss is 1.8186718225479126
epoch: 2 step: 167, loss is 1.7915493249893188
epoch: 2 step: 168, loss is 1.7617254257202148
epoch: 2 step: 169, loss is 1.8228213787078857
epoch: 2 step: 170, loss is 1.7954456806182861
epoch: 2 step: 171, loss is 1.840272068977356
epoch: 2 step: 172, loss is 1.8518497943878174
epoch: 2 step: 173, loss is 1.8658790588378906
epoch: 2 step: 174, loss is 1.7962815761566162
epoch: 2 step: 175, loss is 1.8879177570343018
epoch: 2 step: 176, loss is 1.8493015766143799
epoch: 2 step: 177, loss is 1.807173252105713
epoch: 2 step: 178, loss is 1.8115954399108887
epoch: 2 step: 179, loss is 1.7684533596038818
epoch: 2 step: 180, loss is 1.8132580518722534
epoch: 2 step: 181, loss is 1.8477783203125
epoch: 2 step: 182, loss is 1.9814928770065308
epoch: 2 step: 183, loss is 1.8754057884216309
epoch: 2 step: 184, loss is 1.7704211473464966
epoch: 2 step: 185, loss is 2.013533592224121
epoch: 2 step: 186, loss is 1.8321852684020996
epoch: 2 step: 187, loss is 1.8288037776947021
epoch: 2 step: 188, loss is 1.8278642892837524
epoch: 2 step: 189, loss is 1.7732826471328735
epoch: 2 step: 190, loss is 1.9314385652542114
epoch: 2 step: 191, loss is 1.9096839427947998
epoch: 2 step: 192, loss is 1.8106887340545654
epoch: 2 step: 193, loss is 1.8480473756790161
epoch: 2 step: 194, loss is 1.8274892568588257
epoch: 2 step: 195, loss is 1.7898247241973877
epoch: 2 step: 196, loss is 1.8128597736358643
epoch: 2 step: 197, loss is 1.8097501993179321
epoch: 2 step: 198, loss is 1.8818532228469849
epoch: 2 step: 199, loss is 1.8435938358306885
epoch: 2 step: 200, loss is 1.6888142824172974
epoch: 2 step: 201, loss is 1.733601450920105
epoch: 2 step: 202, loss is 1.9149181842803955
epoch: 2 step: 203, loss is 1.7859376668930054
epoch: 2 step: 204, loss is 1.8781681060791016
epoch: 2 step: 205, loss is 1.8167052268981934
epoch: 2 step: 206, loss is 1.6948245763778687
epoch: 2 step: 207, loss is 1.8925691843032837
epoch: 2 step: 208, loss is 1.824390172958374
epoch: 2 step: 209, loss is 1.861390471458435
epoch: 2 step: 210, loss is 1.760713815689087
epoch: 2 step: 211, loss is 1.8318405151367188
epoch: 2 step: 212, loss is 1.8412034511566162
epoch: 2 step: 213, loss is 1.8712043762207031
epoch: 2 step: 214, loss is 1.8271808624267578
epoch: 2 step: 215, loss is 1.7555245161056519
epoch: 2 step: 216, loss is 1.8896197080612183
epoch: 2 step: 217, loss is 1.8013564348220825
epoch: 2 step: 218, loss is 1.8132814168930054
epoch: 2 step: 219, loss is 1.8457341194152832
epoch: 2 step: 220, loss is 1.7321116924285889
epoch: 2 step: 221, loss is 1.8054625988006592
epoch: 2 step: 222, loss is 1.82883620262146
epoch: 2 step: 223, loss is 1.7912166118621826
epoch: 2 step: 224, loss is 1.717684268951416
epoch: 2 step: 225, loss is 1.8984098434448242
epoch: 2 step: 226, loss is 1.795362949371338
epoch: 2 step: 227, loss is 1.8655365705490112
epoch: 2 step: 228, loss is 1.7815207242965698
epoch: 2 step: 229, loss is 1.8507623672485352
epoch: 2 step: 230, loss is 1.6953002214431763
epoch: 2 step: 231, loss is 1.8273009061813354
epoch: 2 step: 232, loss is 1.6717973947525024
epoch: 2 step: 233, loss is 1.7421903610229492
epoch: 2 step: 234, loss is 1.7551933526992798
epoch: 2 step: 235, loss is 1.9098342657089233
epoch: 2 step: 236, loss is 1.7976480722427368
epoch: 2 step: 237, loss is 1.7673227787017822
epoch: 2 step: 238, loss is 1.802810788154602
epoch: 2 step: 239, loss is 1.7880874872207642
epoch: 2 step: 240, loss is 1.7560418844223022
epoch: 2 step: 241, loss is 1.7605935335159302
epoch: 2 step: 242, loss is 1.7119107246398926
epoch: 2 step: 243, loss is 1.647666335105896
epoch: 2 step: 244, loss is 1.8351585865020752
epoch: 2 step: 245, loss is 1.8196346759796143
epoch: 2 step: 246, loss is 1.7705540657043457
epoch: 2 step: 247, loss is 1.7069883346557617
epoch: 2 step: 248, loss is 1.8272579908370972
epoch: 2 step: 249, loss is 1.7012746334075928
epoch: 2 step: 250, loss is 1.7216856479644775
epoch: 2 step: 251, loss is 1.6781213283538818
epoch: 2 step: 252, loss is 1.889383316040039
epoch: 2 step: 253, loss is 1.799215316772461
epoch: 2 step: 254, loss is 1.698905110359192
epoch: 2 step: 255, loss is 1.8454126119613647
epoch: 2 step: 256, loss is 1.803262710571289
epoch: 2 step: 257, loss is 1.8735960721969604
epoch: 2 step: 258, loss is 1.9001646041870117
epoch: 2 step: 259, loss is 1.8101235628128052
epoch: 2 step: 260, loss is 1.7719933986663818
epoch: 2 step: 261, loss is 1.7371867895126343
epoch: 2 step: 262, loss is 1.7482385635375977
epoch: 2 step: 263, loss is 1.8558818101882935
epoch: 2 step: 264, loss is 1.8158581256866455
epoch: 2 step: 265, loss is 1.7993621826171875
epoch: 2 step: 266, loss is 1.7493383884429932
epoch: 2 step: 267, loss is 1.7819163799285889
epoch: 2 step: 268, loss is 1.7325220108032227
epoch: 2 step: 269, loss is 1.7179396152496338
epoch: 2 step: 270, loss is 1.7648624181747437
epoch: 2 step: 271, loss is 1.7591537237167358
epoch: 2 step: 272, loss is 1.801161289215088
epoch: 2 step: 273, loss is 1.7686935663223267
epoch: 2 step: 274, loss is 1.7774155139923096
epoch: 2 step: 275, loss is 1.7910723686218262
epoch: 2 step: 276, loss is 1.7400379180908203
epoch: 2 step: 277, loss is 1.7682232856750488
epoch: 2 step: 278, loss is 1.775517463684082
epoch: 2 step: 279, loss is 1.8035931587219238
epoch: 2 step: 280, loss is 1.840588092803955
epoch: 2 step: 281, loss is 1.8460052013397217
epoch: 2 step: 282, loss is 1.7497811317443848
epoch: 2 step: 283, loss is 1.7777003049850464
epoch: 2 step: 284, loss is 1.8939076662063599
epoch: 2 step: 285, loss is 1.7027175426483154
epoch: 2 step: 286, loss is 1.8455442190170288
epoch: 2 step: 287, loss is 1.820478916168213
epoch: 2 step: 288, loss is 1.8581182956695557
epoch: 2 step: 289, loss is 1.83876633644104
epoch: 2 step: 290, loss is 1.8466076850891113
epoch: 2 step: 291, loss is 1.734968900680542
epoch: 2 step: 292, loss is 1.7899976968765259
epoch: 2 step: 293, loss is 1.8471029996871948
epoch: 2 step: 294, loss is 1.7268532514572144
epoch: 2 step: 295, loss is 1.7490458488464355
epoch: 2 step: 296, loss is 1.7759581804275513
epoch: 2 step: 297, loss is 1.8594329357147217
epoch: 2 step: 298, loss is 1.8403699398040771
epoch: 2 step: 299, loss is 1.741633653640747
epoch: 2 step: 300, loss is 1.7870420217514038
epoch: 2 step: 301, loss is 1.6071045398712158
epoch: 2 step: 302, loss is 1.7415251731872559
epoch: 2 step: 303, loss is 1.9326395988464355
epoch: 2 step: 304, loss is 1.7688913345336914
epoch: 2 step: 305, loss is 1.8577096462249756
epoch: 2 step: 306, loss is 1.8196533918380737
epoch: 2 step: 307, loss is 1.7930805683135986
epoch: 2 step: 308, loss is 1.812243938446045
epoch: 2 step: 309, loss is 1.8090136051177979
epoch: 2 step: 310, loss is 1.8416868448257446
epoch: 2 step: 311, loss is 1.6595258712768555
epoch: 2 step: 312, loss is 1.7971715927124023
epoch: 2 step: 313, loss is 1.7158136367797852
epoch: 2 step: 314, loss is 1.7205451726913452
epoch: 2 step: 315, loss is 1.7460322380065918
epoch: 2 step: 316, loss is 1.806428074836731
epoch: 2 step: 317, loss is 1.854871153831482
epoch: 2 step: 318, loss is 1.7918264865875244
epoch: 2 step: 319, loss is 1.857133150100708
epoch: 2 step: 320, loss is 1.8619929552078247
epoch: 2 step: 321, loss is 1.8428970575332642
epoch: 2 step: 322, loss is 1.7478281259536743
epoch: 2 step: 323, loss is 1.7981542348861694
epoch: 2 step: 324, loss is 1.7388367652893066
epoch: 2 step: 325, loss is 1.8213354349136353
epoch: 2 step: 326, loss is 1.8172770738601685
epoch: 2 step: 327, loss is 1.8915636539459229
epoch: 2 step: 328, loss is 1.7727159261703491
epoch: 2 step: 329, loss is 1.672269344329834
epoch: 2 step: 330, loss is 1.7569057941436768
epoch: 2 step: 331, loss is 1.8469665050506592
epoch: 2 step: 332, loss is 1.7855149507522583
epoch: 2 step: 333, loss is 1.8072550296783447
epoch: 2 step: 334, loss is 1.7543696165084839
epoch: 2 step: 335, loss is 1.7541577816009521
epoch: 2 step: 336, loss is 1.7690460681915283
epoch: 2 step: 337, loss is 1.8139396905899048
epoch: 2 step: 338, loss is 1.790165662765503
epoch: 2 step: 339, loss is 1.8434514999389648
epoch: 2 step: 340, loss is 1.7970030307769775
epoch: 2 step: 341, loss is 1.709476351737976
epoch: 2 step: 342, loss is 1.6404789686203003
epoch: 2 step: 343, loss is 1.7340303659439087
epoch: 2 step: 344, loss is 1.7250415086746216
epoch: 2 step: 345, loss is 1.7795867919921875
epoch: 2 step: 346, loss is 1.7975354194641113
epoch: 2 step: 347, loss is 1.819098711013794
epoch: 2 step: 348, loss is 1.7630789279937744
epoch: 2 step: 349, loss is 1.8042688369750977
epoch: 2 step: 350, loss is 1.8980395793914795
epoch: 2 step: 351, loss is 1.6652143001556396
epoch: 2 step: 352, loss is 1.7915364503860474
epoch: 2 step: 353, loss is 1.7328461408615112
epoch: 2 step: 354, loss is 1.8572514057159424
epoch: 2 step: 355, loss is 1.7809345722198486
epoch: 2 step: 356, loss is 1.7125349044799805
epoch: 2 step: 357, loss is 1.739114761352539
epoch: 2 step: 358, loss is 1.7182954549789429
epoch: 2 step: 359, loss is 1.7481850385665894
epoch: 2 step: 360, loss is 1.8001917600631714
epoch: 2 step: 361, loss is 1.8127741813659668
epoch: 2 step: 362, loss is 1.743453025817871
epoch: 2 step: 363, loss is 1.9121342897415161
epoch: 2 step: 364, loss is 1.7708423137664795
epoch: 2 step: 365, loss is 1.8561033010482788
epoch: 2 step: 366, loss is 1.7166929244995117
epoch: 2 step: 367, loss is 1.7863839864730835
epoch: 2 step: 368, loss is 1.7913727760314941
epoch: 2 step: 369, loss is 1.8722734451293945
epoch: 2 step: 370, loss is 1.7932294607162476
epoch: 2 step: 371, loss is 1.767304539680481
epoch: 2 step: 372, loss is 1.7822246551513672
epoch: 2 step: 373, loss is 1.785788893699646
epoch: 2 step: 374, loss is 1.7219452857971191
epoch: 2 step: 375, loss is 1.7839324474334717
epoch: 2 step: 376, loss is 1.7981308698654175
epoch: 2 step: 377, loss is 1.7242521047592163
epoch: 2 step: 378, loss is 1.6721879243850708
epoch: 2 step: 379, loss is 1.729831576347351
epoch: 2 step: 380, loss is 1.8211615085601807
epoch: 2 step: 381, loss is 1.7687087059020996
epoch: 2 step: 382, loss is 1.6252415180206299
epoch: 2 step: 383, loss is 1.8604865074157715
epoch: 2 step: 384, loss is 1.7644091844558716
epoch: 2 step: 385, loss is 1.8360562324523926
epoch: 2 step: 386, loss is 1.708206295967102
epoch: 2 step: 387, loss is 1.7925338745117188
epoch: 2 step: 388, loss is 1.75800621509552
epoch: 2 step: 389, loss is 1.8554675579071045
epoch: 2 step: 390, loss is 1.727205514907837
Train epoch time: 146714.654 ms, per step time: 376.191 ms
epoch: 3 step: 1, loss is 1.7011620998382568
epoch: 3 step: 2, loss is 1.709585189819336
epoch: 3 step: 3, loss is 1.8116836547851562
epoch: 3 step: 4, loss is 1.6776773929595947
epoch: 3 step: 5, loss is 1.788543462753296
epoch: 3 step: 6, loss is 1.7659239768981934
epoch: 3 step: 7, loss is 1.7689886093139648
epoch: 3 step: 8, loss is 1.773811936378479
epoch: 3 step: 9, loss is 1.762629747390747
epoch: 3 step: 10, loss is 1.7795403003692627
epoch: 3 step: 11, loss is 1.635869026184082
epoch: 3 step: 12, loss is 1.7386810779571533
epoch: 3 step: 13, loss is 1.801726222038269
epoch: 3 step: 14, loss is 1.7301396131515503
epoch: 3 step: 15, loss is 1.758491039276123
epoch: 3 step: 16, loss is 1.7528533935546875
epoch: 3 step: 17, loss is 1.7193998098373413
epoch: 3 step: 18, loss is 1.7114794254302979
epoch: 3 step: 19, loss is 1.7316093444824219
epoch: 3 step: 20, loss is 1.7665354013442993
epoch: 3 step: 21, loss is 1.7559670209884644
epoch: 3 step: 22, loss is 1.7365370988845825
epoch: 3 step: 23, loss is 1.7789051532745361
epoch: 3 step: 24, loss is 1.7037755250930786
epoch: 3 step: 25, loss is 1.7753820419311523
epoch: 3 step: 26, loss is 1.7125160694122314
epoch: 3 step: 27, loss is 1.7486547231674194
epoch: 3 step: 28, loss is 1.7089030742645264
epoch: 3 step: 29, loss is 1.752793788909912
epoch: 3 step: 30, loss is 1.7486802339553833
epoch: 3 step: 31, loss is 1.8184881210327148
epoch: 3 step: 32, loss is 1.7876648902893066
epoch: 3 step: 33, loss is 1.9106379747390747
epoch: 3 step: 34, loss is 1.670501470565796
epoch: 3 step: 35, loss is 1.7234971523284912
epoch: 3 step: 36, loss is 1.7328954935073853
epoch: 3 step: 37, loss is 1.570358395576477
epoch: 3 step: 38, loss is 1.8367743492126465
epoch: 3 step: 39, loss is 1.677091360092163
epoch: 3 step: 40, loss is 1.7431758642196655
epoch: 3 step: 41, loss is 1.720537543296814
epoch: 3 step: 42, loss is 1.7712358236312866
epoch: 3 step: 43, loss is 1.8696388006210327
epoch: 3 step: 44, loss is 1.7342703342437744
epoch: 3 step: 45, loss is 1.8264657258987427
epoch: 3 step: 46, loss is 1.7379969358444214
epoch: 3 step: 47, loss is 1.7093908786773682
epoch: 3 step: 48, loss is 1.8232003450393677
epoch: 3 step: 49, loss is 1.6663192510604858
epoch: 3 step: 50, loss is 1.7718254327774048
epoch: 3 step: 51, loss is 1.7290639877319336
epoch: 3 step: 52, loss is 1.945021629333496
epoch: 3 step: 53, loss is 1.6861510276794434
epoch: 3 step: 54, loss is 1.6566307544708252
epoch: 3 step: 55, loss is 1.738956332206726
epoch: 3 step: 56, loss is 1.8005433082580566
epoch: 3 step: 57, loss is 1.6954911947250366
epoch: 3 step: 58, loss is 1.7621089220046997
epoch: 3 step: 59, loss is 1.6328601837158203
epoch: 3 step: 60, loss is 1.6871263980865479
epoch: 3 step: 61, loss is 1.7411242723464966
epoch: 3 step: 62, loss is 1.6336861848831177
epoch: 3 step: 63, loss is 1.7140171527862549
epoch: 3 step: 64, loss is 1.7042243480682373
epoch: 3 step: 65, loss is 1.7000048160552979
epoch: 3 step: 66, loss is 1.7662190198898315
epoch: 3 step: 67, loss is 1.6608006954193115
epoch: 3 step: 68, loss is 1.7941066026687622
epoch: 3 step: 69, loss is 1.763251543045044
epoch: 3 step: 70, loss is 1.741612195968628
epoch: 3 step: 71, loss is 1.6935865879058838
epoch: 3 step: 72, loss is 1.6148473024368286
epoch: 3 step: 73, loss is 1.7706575393676758
epoch: 3 step: 74, loss is 1.7908775806427002
epoch: 3 step: 75, loss is 1.7786062955856323
epoch: 3 step: 76, loss is 1.6840479373931885
epoch: 3 step: 77, loss is 1.6584070920944214
epoch: 3 step: 78, loss is 1.744150161743164
epoch: 3 step: 79, loss is 1.7607392072677612
epoch: 3 step: 80, loss is 1.7234543561935425
epoch: 3 step: 81, loss is 1.7030532360076904
epoch: 3 step: 82, loss is 1.6587423086166382
epoch: 3 step: 83, loss is 1.6688637733459473
epoch: 3 step: 84, loss is 1.6891658306121826
epoch: 3 step: 85, loss is 1.6793352365493774
epoch: 3 step: 86, loss is 1.8409228324890137
epoch: 3 step: 87, loss is 1.6615338325500488
epoch: 3 step: 88, loss is 1.7092816829681396
epoch: 3 step: 89, loss is 1.757968783378601
epoch: 3 step: 90, loss is 1.6591882705688477
epoch: 3 step: 91, loss is 1.8122813701629639
epoch: 3 step: 92, loss is 1.7917979955673218
epoch: 3 step: 93, loss is 1.7902584075927734
epoch: 3 step: 94, loss is 1.7348543405532837
epoch: 3 step: 95, loss is 1.7687939405441284
epoch: 3 step: 96, loss is 1.7550122737884521
epoch: 3 step: 97, loss is 1.7740013599395752
epoch: 3 step: 98, loss is 1.7521083354949951
epoch: 3 step: 99, loss is 1.703765630722046
epoch: 3 step: 100, loss is 1.7932658195495605
epoch: 3 step: 101, loss is 1.8472590446472168
epoch: 3 step: 102, loss is 1.79091215133667
epoch: 3 step: 103, loss is 1.673203706741333
epoch: 3 step: 104, loss is 1.884513258934021
epoch: 3 step: 105, loss is 1.7453030347824097
epoch: 3 step: 106, loss is 1.8236973285675049
epoch: 3 step: 107, loss is 1.8219428062438965
epoch: 3 step: 108, loss is 1.7206854820251465
epoch: 3 step: 109, loss is 1.7272353172302246
epoch: 3 step: 110, loss is 1.7556524276733398
epoch: 3 step: 111, loss is 1.6791669130325317
epoch: 3 step: 112, loss is 1.5839923620224
epoch: 3 step: 113, loss is 1.6792527437210083
epoch: 3 step: 114, loss is 1.7115216255187988
epoch: 3 step: 115, loss is 1.741018533706665
epoch: 3 step: 116, loss is 1.6030585765838623
epoch: 3 step: 117, loss is 1.7270311117172241
epoch: 3 step: 118, loss is 1.8239670991897583
epoch: 3 step: 119, loss is 1.7705371379852295
epoch: 3 step: 120, loss is 1.7693061828613281
epoch: 3 step: 121, loss is 1.7489126920700073
epoch: 3 step: 122, loss is 1.8197073936462402
epoch: 3 step: 123, loss is 1.6671526432037354
epoch: 3 step: 124, loss is 1.6562755107879639
epoch: 3 step: 125, loss is 1.8016430139541626
epoch: 3 step: 126, loss is 1.725387454032898
epoch: 3 step: 127, loss is 1.7734805345535278
epoch: 3 step: 128, loss is 1.7097951173782349
epoch: 3 step: 129, loss is 1.795531988143921
epoch: 3 step: 130, loss is 1.7153255939483643
epoch: 3 step: 131, loss is 1.7836263179779053
epoch: 3 step: 132, loss is 1.72203528881073
epoch: 3 step: 133, loss is 1.5638377666473389
epoch: 3 step: 134, loss is 1.674219012260437
epoch: 3 step: 135, loss is 1.7612909078598022
epoch: 3 step: 136, loss is 1.5711168050765991
epoch: 3 step: 137, loss is 1.756954550743103
epoch: 3 step: 138, loss is 1.8102957010269165
epoch: 3 step: 139, loss is 1.8309345245361328
epoch: 3 step: 140, loss is 1.6465873718261719
epoch: 3 step: 141, loss is 1.7128552198410034
epoch: 3 step: 142, loss is 1.7809803485870361
epoch: 3 step: 143, loss is 1.6547837257385254
epoch: 3 step: 144, loss is 1.7262263298034668
epoch: 3 step: 145, loss is 1.7470958232879639
epoch: 3 step: 146, loss is 1.7484445571899414
epoch: 3 step: 147, loss is 1.6253224611282349
epoch: 3 step: 148, loss is 1.6575924158096313
epoch: 3 step: 149, loss is 1.7495793104171753
epoch: 3 step: 150, loss is 1.6102755069732666
epoch: 3 step: 151, loss is 1.692030429840088
epoch: 3 step: 152, loss is 1.8191797733306885
epoch: 3 step: 153, loss is 1.7927303314208984
epoch: 3 step: 154, loss is 1.7503738403320312
epoch: 3 step: 155, loss is 1.5868735313415527
epoch: 3 step: 156, loss is 1.6121134757995605
epoch: 3 step: 157, loss is 1.6651256084442139
epoch: 3 step: 158, loss is 1.6786134243011475
epoch: 3 step: 159, loss is 1.6842710971832275
epoch: 3 step: 160, loss is 1.7657203674316406
epoch: 3 step: 161, loss is 1.7009081840515137
epoch: 3 step: 162, loss is 1.79721999168396
epoch: 3 step: 163, loss is 1.6888189315795898
epoch: 3 step: 164, loss is 1.6185917854309082
epoch: 3 step: 165, loss is 1.6772339344024658
epoch: 3 step: 166, loss is 1.8670486211776733
epoch: 3 step: 167, loss is 1.7415584325790405
epoch: 3 step: 168, loss is 1.7244024276733398
epoch: 3 step: 169, loss is 1.6642749309539795
epoch: 3 step: 170, loss is 1.7691022157669067
epoch: 3 step: 171, loss is 1.7506186962127686
epoch: 3 step: 172, loss is 1.8031151294708252
epoch: 3 step: 173, loss is 1.734635829925537
epoch: 3 step: 174, loss is 1.8045151233673096
epoch: 3 step: 175, loss is 1.7057240009307861
epoch: 3 step: 176, loss is 1.851115345954895
epoch: 3 step: 177, loss is 1.7931866645812988
epoch: 3 step: 178, loss is 1.6450235843658447
epoch: 3 step: 179, loss is 1.6944745779037476
epoch: 3 step: 180, loss is 1.709342360496521
epoch: 3 step: 181, loss is 1.7119970321655273
epoch: 3 step: 182, loss is 1.7331335544586182
epoch: 3 step: 183, loss is 1.7393572330474854
epoch: 3 step: 184, loss is 1.730353593826294
epoch: 3 step: 185, loss is 1.6934131383895874
epoch: 3 step: 186, loss is 1.6065739393234253
epoch: 3 step: 187, loss is 1.736080527305603
epoch: 3 step: 188, loss is 1.8110605478286743
epoch: 3 step: 189, loss is 1.750079870223999
epoch: 3 step: 190, loss is 1.849204659461975
epoch: 3 step: 191, loss is 1.8364250659942627
epoch: 3 step: 192, loss is 1.8111459016799927
epoch: 3 step: 193, loss is 1.6181663274765015
epoch: 3 step: 194, loss is 1.7094552516937256
epoch: 3 step: 195, loss is 1.7747726440429688
epoch: 3 step: 196, loss is 1.7307028770446777
epoch: 3 step: 197, loss is 1.7216401100158691
epoch: 3 step: 198, loss is 1.7639442682266235
epoch: 3 step: 199, loss is 1.762778878211975
epoch: 3 step: 200, loss is 1.6946556568145752
epoch: 3 step: 201, loss is 1.823284387588501
epoch: 3 step: 202, loss is 1.8446054458618164
epoch: 3 step: 203, loss is 1.7421748638153076
epoch: 3 step: 204, loss is 1.715840220451355
epoch: 3 step: 205, loss is 1.8044579029083252
epoch: 3 step: 206, loss is 1.6762937307357788
epoch: 3 step: 207, loss is 1.7598419189453125
epoch: 3 step: 208, loss is 1.6516340970993042
epoch: 3 step: 209, loss is 1.682855486869812
epoch: 3 step: 210, loss is 1.74543297290802
epoch: 3 step: 211, loss is 1.6141259670257568
epoch: 3 step: 212, loss is 1.7569105625152588
epoch: 3 step: 213, loss is 1.80369234085083
epoch: 3 step: 214, loss is 1.7974480390548706
epoch: 3 step: 215, loss is 1.6719825267791748
epoch: 3 step: 216, loss is 1.7092519998550415
epoch: 3 step: 217, loss is 1.6810848712921143
epoch: 3 step: 218, loss is 1.6612707376480103
epoch: 3 step: 219, loss is 1.676425576210022
epoch: 3 step: 220, loss is 1.682099461555481
epoch: 3 step: 221, loss is 1.6157437562942505
epoch: 3 step: 222, loss is 1.7566237449645996
epoch: 3 step: 223, loss is 1.657416582107544
epoch: 3 step: 224, loss is 1.8284307718276978
epoch: 3 step: 225, loss is 1.6203669309616089
epoch: 3 step: 226, loss is 1.7494653463363647
epoch: 3 step: 227, loss is 1.6093566417694092
epoch: 3 step: 228, loss is 1.7360996007919312
epoch: 3 step: 229, loss is 1.6871085166931152
epoch: 3 step: 230, loss is 1.6777480840682983
epoch: 3 step: 231, loss is 1.6204313039779663
epoch: 3 step: 232, loss is 1.6421411037445068
epoch: 3 step: 233, loss is 1.7940982580184937
epoch: 3 step: 234, loss is 1.8112177848815918
epoch: 3 step: 235, loss is 1.777404546737671
epoch: 3 step: 236, loss is 1.751941442489624
epoch: 3 step: 237, loss is 1.66073477268219
epoch: 3 step: 238, loss is 1.7887825965881348
epoch: 3 step: 239, loss is 1.7797002792358398
epoch: 3 step: 240, loss is 1.7799899578094482
epoch: 3 step: 241, loss is 1.7370054721832275
epoch: 3 step: 242, loss is 1.7495089769363403
epoch: 3 step: 243, loss is 1.6545119285583496
epoch: 3 step: 244, loss is 1.6990618705749512
epoch: 3 step: 245, loss is 1.6874079704284668
epoch: 3 step: 246, loss is 1.7034668922424316
epoch: 3 step: 247, loss is 1.7423933744430542
epoch: 3 step: 248, loss is 1.6429753303527832
epoch: 3 step: 249, loss is 1.7862322330474854
epoch: 3 step: 250, loss is 1.6871455907821655
epoch: 3 step: 251, loss is 1.631443977355957
epoch: 3 step: 252, loss is 1.8158299922943115
epoch: 3 step: 253, loss is 1.6765892505645752
epoch: 3 step: 254, loss is 1.7132534980773926
epoch: 3 step: 255, loss is 1.7466295957565308
epoch: 3 step: 256, loss is 1.7446067333221436
epoch: 3 step: 257, loss is 1.7640235424041748
epoch: 3 step: 258, loss is 1.723985195159912
epoch: 3 step: 259, loss is 1.7540901899337769
epoch: 3 step: 260, loss is 1.7146706581115723
epoch: 3 step: 261, loss is 1.6645374298095703
epoch: 3 step: 262, loss is 1.7577040195465088
epoch: 3 step: 263, loss is 1.66899836063385
epoch: 3 step: 264, loss is 1.7084097862243652
epoch: 3 step: 265, loss is 1.7684085369110107
epoch: 3 step: 266, loss is 1.7932184934616089
epoch: 3 step: 267, loss is 1.8327465057373047
epoch: 3 step: 268, loss is 1.7211740016937256
epoch: 3 step: 269, loss is 1.5827194452285767
epoch: 3 step: 270, loss is 1.69878351688385
epoch: 3 step: 271, loss is 1.643484354019165
epoch: 3 step: 272, loss is 1.6335159540176392
epoch: 3 step: 273, loss is 1.6185886859893799
epoch: 3 step: 274, loss is 1.8152391910552979
epoch: 3 step: 275, loss is 1.7642897367477417
epoch: 3 step: 276, loss is 1.672631025314331
epoch: 3 step: 277, loss is 1.776371955871582
epoch: 3 step: 278, loss is 1.635124683380127
epoch: 3 step: 279, loss is 1.7019803524017334
epoch: 3 step: 280, loss is 1.59541916847229
epoch: 3 step: 281, loss is 1.6703786849975586
epoch: 3 step: 282, loss is 1.6549696922302246
epoch: 3 step: 283, loss is 1.688324213027954
epoch: 3 step: 284, loss is 1.6780511140823364
epoch: 3 step: 285, loss is 1.7514989376068115
epoch: 3 step: 286, loss is 1.7574325799942017
epoch: 3 step: 287, loss is 1.8168704509735107
epoch: 3 step: 288, loss is 1.7487471103668213
epoch: 3 step: 289, loss is 1.5962984561920166
epoch: 3 step: 290, loss is 1.5820927619934082
epoch: 3 step: 291, loss is 1.606921911239624
epoch: 3 step: 292, loss is 1.6841962337493896
epoch: 3 step: 293, loss is 1.754068374633789
epoch: 3 step: 294, loss is 1.6151689291000366
epoch: 3 step: 295, loss is 1.6386445760726929
epoch: 3 step: 296, loss is 1.6499595642089844
epoch: 3 step: 297, loss is 1.722853422164917
epoch: 3 step: 298, loss is 1.7466785907745361
epoch: 3 step: 299, loss is 1.6606475114822388
epoch: 3 step: 300, loss is 1.7273977994918823
epoch: 3 step: 301, loss is 1.7580937147140503
epoch: 3 step: 302, loss is 1.7870055437088013
epoch: 3 step: 303, loss is 1.6448380947113037
epoch: 3 step: 304, loss is 1.7750073671340942
epoch: 3 step: 305, loss is 1.7694441080093384
epoch: 3 step: 306, loss is 1.659514307975769
epoch: 3 step: 307, loss is 1.6545149087905884
epoch: 3 step: 308, loss is 1.6121965646743774
epoch: 3 step: 309, loss is 1.576978087425232
epoch: 3 step: 310, loss is 1.6596425771713257
epoch: 3 step: 311, loss is 1.7404265403747559
epoch: 3 step: 312, loss is 1.6407229900360107
epoch: 3 step: 313, loss is 1.6929740905761719
epoch: 3 step: 314, loss is 1.7453160285949707
epoch: 3 step: 315, loss is 1.6856878995895386
epoch: 3 step: 316, loss is 1.7691365480422974
epoch: 3 step: 317, loss is 1.6508985757827759
epoch: 3 step: 318, loss is 1.707797646522522
epoch: 3 step: 319, loss is 1.6020218133926392
epoch: 3 step: 320, loss is 1.6246623992919922
epoch: 3 step: 321, loss is 1.723819613456726
epoch: 3 step: 322, loss is 1.6295673847198486
epoch: 3 step: 323, loss is 1.7643470764160156
epoch: 3 step: 324, loss is 1.6471282243728638
epoch: 3 step: 325, loss is 1.6642776727676392
epoch: 3 step: 326, loss is 1.7263293266296387
epoch: 3 step: 327, loss is 1.603219747543335
epoch: 3 step: 328, loss is 1.6895296573638916
epoch: 3 step: 329, loss is 1.696844220161438
epoch: 3 step: 330, loss is 1.7557255029678345
epoch: 3 step: 331, loss is 1.6828575134277344
epoch: 3 step: 332, loss is 1.5993280410766602
epoch: 3 step: 333, loss is 1.6977688074111938
epoch: 3 step: 334, loss is 1.6396527290344238
epoch: 3 step: 335, loss is 1.672852635383606
epoch: 3 step: 336, loss is 1.6235517263412476
epoch: 3 step: 337, loss is 1.6519783735275269
epoch: 3 step: 338, loss is 1.6986744403839111
epoch: 3 step: 339, loss is 1.584930419921875
epoch: 3 step: 340, loss is 1.6436800956726074
epoch: 3 step: 341, loss is 1.7935266494750977
epoch: 3 step: 342, loss is 1.653383731842041
epoch: 3 step: 343, loss is 1.7505847215652466
epoch: 3 step: 344, loss is 1.6373629570007324
epoch: 3 step: 345, loss is 1.7726969718933105
epoch: 3 step: 346, loss is 1.75229811668396
epoch: 3 step: 347, loss is 1.623457670211792
epoch: 3 step: 348, loss is 1.6465024948120117
epoch: 3 step: 349, loss is 1.7680888175964355
epoch: 3 step: 350, loss is 1.6510112285614014
epoch: 3 step: 351, loss is 1.5867525339126587
epoch: 3 step: 352, loss is 1.6209733486175537
epoch: 3 step: 353, loss is 1.7414778470993042
epoch: 3 step: 354, loss is 1.611627459526062
epoch: 3 step: 355, loss is 1.6444076299667358
epoch: 3 step: 356, loss is 1.6516218185424805
epoch: 3 step: 357, loss is 1.629439353942871
epoch: 3 step: 358, loss is 1.7086539268493652
epoch: 3 step: 359, loss is 1.7392237186431885
epoch: 3 step: 360, loss is 1.61751389503479
epoch: 3 step: 361, loss is 1.7832938432693481
epoch: 3 step: 362, loss is 1.7391486167907715
epoch: 3 step: 363, loss is 1.6941441297531128
epoch: 3 step: 364, loss is 1.5541343688964844
epoch: 3 step: 365, loss is 1.6111410856246948
epoch: 3 step: 366, loss is 1.7353687286376953
epoch: 3 step: 367, loss is 1.5854017734527588
epoch: 3 step: 368, loss is 1.751041054725647
epoch: 3 step: 369, loss is 1.5958932638168335
epoch: 3 step: 370, loss is 1.7284789085388184
epoch: 3 step: 371, loss is 1.5446407794952393
epoch: 3 step: 372, loss is 1.7027759552001953
epoch: 3 step: 373, loss is 1.7276140451431274
epoch: 3 step: 374, loss is 1.674797773361206
epoch: 3 step: 375, loss is 1.6916526556015015
epoch: 3 step: 376, loss is 1.6954025030136108
epoch: 3 step: 377, loss is 1.7218893766403198
epoch: 3 step: 378, loss is 1.708600640296936
epoch: 3 step: 379, loss is 1.6732304096221924
epoch: 3 step: 380, loss is 1.5969635248184204
epoch: 3 step: 381, loss is 1.661347508430481
epoch: 3 step: 382, loss is 1.622877836227417
epoch: 3 step: 383, loss is 1.7991561889648438
epoch: 3 step: 384, loss is 1.6844978332519531
epoch: 3 step: 385, loss is 1.6307588815689087
epoch: 3 step: 386, loss is 1.7286391258239746
epoch: 3 step: 387, loss is 1.674031376838684
epoch: 3 step: 388, loss is 1.703923225402832
epoch: 3 step: 389, loss is 1.7576723098754883
epoch: 3 step: 390, loss is 1.8129860162734985
Train epoch time: 148262.046 ms, per step time: 380.159 ms
epoch: 4 step: 1, loss is 1.6795077323913574
epoch: 4 step: 2, loss is 1.6547783613204956
epoch: 4 step: 3, loss is 1.7279499769210815
epoch: 4 step: 4, loss is 1.6031444072723389
epoch: 4 step: 5, loss is 1.7750589847564697
epoch: 4 step: 6, loss is 1.6007839441299438
epoch: 4 step: 7, loss is 1.5680683851242065
epoch: 4 step: 8, loss is 1.6678814888000488
epoch: 4 step: 9, loss is 1.6794886589050293
epoch: 4 step: 10, loss is 1.761857032775879
epoch: 4 step: 11, loss is 1.732496738433838
epoch: 4 step: 12, loss is 1.770388126373291
epoch: 4 step: 13, loss is 1.7412445545196533
epoch: 4 step: 14, loss is 1.6036450862884521
epoch: 4 step: 15, loss is 1.5930733680725098
epoch: 4 step: 16, loss is 1.5691652297973633
epoch: 4 step: 17, loss is 1.6956709623336792
epoch: 4 step: 18, loss is 1.7173221111297607
epoch: 4 step: 19, loss is 1.5947833061218262
epoch: 4 step: 20, loss is 1.6035970449447632
epoch: 4 step: 21, loss is 1.7363462448120117
epoch: 4 step: 22, loss is 1.6880065202713013
epoch: 4 step: 23, loss is 1.762622356414795
epoch: 4 step: 24, loss is 1.7345247268676758
epoch: 4 step: 25, loss is 1.5922328233718872
epoch: 4 step: 26, loss is 1.6273783445358276
epoch: 4 step: 27, loss is 1.6727008819580078
epoch: 4 step: 28, loss is 1.6469426155090332
epoch: 4 step: 29, loss is 1.7373484373092651
epoch: 4 step: 30, loss is 1.574809193611145
epoch: 4 step: 31, loss is 1.6229474544525146
epoch: 4 step: 32, loss is 1.7641392946243286
epoch: 4 step: 33, loss is 1.656707525253296
epoch: 4 step: 34, loss is 1.7575713396072388
epoch: 4 step: 35, loss is 1.5988682508468628
epoch: 4 step: 36, loss is 1.7322921752929688
epoch: 4 step: 37, loss is 1.7096195220947266
epoch: 4 step: 38, loss is 1.7116434574127197
epoch: 4 step: 39, loss is 1.6372380256652832
epoch: 4 step: 40, loss is 1.6533660888671875
epoch: 4 step: 41, loss is 1.6443336009979248
epoch: 4 step: 42, loss is 1.610018014907837
epoch: 4 step: 43, loss is 1.7536451816558838
epoch: 4 step: 44, loss is 1.6489262580871582
epoch: 4 step: 45, loss is 1.6798604726791382
epoch: 4 step: 46, loss is 1.7339818477630615
epoch: 4 step: 47, loss is 1.6021127700805664
epoch: 4 step: 48, loss is 1.6183156967163086
epoch: 4 step: 49, loss is 1.662222981452942
epoch: 4 step: 50, loss is 1.5940203666687012
epoch: 4 step: 51, loss is 1.6162567138671875
epoch: 4 step: 52, loss is 1.6631288528442383
epoch: 4 step: 53, loss is 1.6619668006896973
epoch: 4 step: 54, loss is 1.7513457536697388
epoch: 4 step: 55, loss is 1.644848108291626
epoch: 4 step: 56, loss is 1.7595937252044678
epoch: 4 step: 57, loss is 1.721566915512085
epoch: 4 step: 58, loss is 1.7103322744369507
epoch: 4 step: 59, loss is 1.6535199880599976
epoch: 4 step: 60, loss is 1.7302041053771973
epoch: 4 step: 61, loss is 1.6460810899734497
epoch: 4 step: 62, loss is 1.5453379154205322
epoch: 4 step: 63, loss is 1.5611209869384766
epoch: 4 step: 64, loss is 1.7811839580535889
epoch: 4 step: 65, loss is 1.761525273323059
epoch: 4 step: 66, loss is 1.7321281433105469
epoch: 4 step: 67, loss is 1.8064043521881104
epoch: 4 step: 68, loss is 1.4971861839294434
epoch: 4 step: 69, loss is 1.5854825973510742
epoch: 4 step: 70, loss is 1.6636431217193604
epoch: 4 step: 71, loss is 1.7476249933242798
epoch: 4 step: 72, loss is 1.6772658824920654
epoch: 4 step: 73, loss is 1.5931023359298706
epoch: 4 step: 74, loss is 1.6290149688720703
epoch: 4 step: 75, loss is 1.661523699760437
epoch: 4 step: 76, loss is 1.6970641613006592
epoch: 4 step: 77, loss is 1.5608322620391846
epoch: 4 step: 78, loss is 1.5832866430282593
epoch: 4 step: 79, loss is 1.7477389574050903
epoch: 4 step: 80, loss is 1.6249845027923584
epoch: 4 step: 81, loss is 1.644599199295044
epoch: 4 step: 82, loss is 1.7964465618133545
epoch: 4 step: 83, loss is 1.4869182109832764
epoch: 4 step: 84, loss is 1.5899062156677246
epoch: 4 step: 85, loss is 1.7660880088806152
epoch: 4 step: 86, loss is 1.564998745918274
epoch: 4 step: 87, loss is 1.6991653442382812
epoch: 4 step: 88, loss is 1.7094069719314575
epoch: 4 step: 89, loss is 1.6553308963775635
epoch: 4 step: 90, loss is 1.6726868152618408
epoch: 4 step: 91, loss is 1.5828264951705933
epoch: 4 step: 92, loss is 1.6503878831863403
epoch: 4 step: 93, loss is 1.677264928817749
epoch: 4 step: 94, loss is 1.579552173614502
epoch: 4 step: 95, loss is 1.6640183925628662
epoch: 4 step: 96, loss is 1.7751851081848145
epoch: 4 step: 97, loss is 1.5866656303405762
epoch: 4 step: 98, loss is 1.7393475770950317
epoch: 4 step: 99, loss is 1.633880615234375
epoch: 4 step: 100, loss is 1.7455921173095703
epoch: 4 step: 101, loss is 1.6347025632858276
epoch: 4 step: 102, loss is 1.5147490501403809
epoch: 4 step: 103, loss is 1.752202033996582
epoch: 4 step: 104, loss is 1.6515647172927856
epoch: 4 step: 105, loss is 1.632629632949829
epoch: 4 step: 106, loss is 1.6201138496398926
epoch: 4 step: 107, loss is 1.5832629203796387
epoch: 4 step: 108, loss is 1.7153831720352173
epoch: 4 step: 109, loss is 1.6719881296157837
epoch: 4 step: 110, loss is 1.6569435596466064
epoch: 4 step: 111, loss is 1.6852543354034424
epoch: 4 step: 112, loss is 1.5864040851593018
epoch: 4 step: 113, loss is 1.6995000839233398
epoch: 4 step: 114, loss is 1.716445803642273
epoch: 4 step: 115, loss is 1.6923251152038574
epoch: 4 step: 116, loss is 1.672752022743225
epoch: 4 step: 117, loss is 1.6818265914916992
epoch: 4 step: 118, loss is 1.6175833940505981
epoch: 4 step: 119, loss is 1.5711658000946045
epoch: 4 step: 120, loss is 1.5915007591247559
epoch: 4 step: 121, loss is 1.677505612373352
epoch: 4 step: 122, loss is 1.6880251169204712
epoch: 4 step: 123, loss is 1.7808399200439453
epoch: 4 step: 124, loss is 1.6556986570358276
epoch: 4 step: 125, loss is 1.6531155109405518
epoch: 4 step: 126, loss is 1.674384355545044
epoch: 4 step: 127, loss is 1.6800332069396973
epoch: 4 step: 128, loss is 1.758007287979126
epoch: 4 step: 129, loss is 1.6447930335998535
epoch: 4 step: 130, loss is 1.5922729969024658
epoch: 4 step: 131, loss is 1.6862068176269531
epoch: 4 step: 132, loss is 1.6075475215911865
epoch: 4 step: 133, loss is 1.706351399421692
epoch: 4 step: 134, loss is 1.6015820503234863
epoch: 4 step: 135, loss is 1.6301424503326416
epoch: 4 step: 136, loss is 1.681418538093567
epoch: 4 step: 137, loss is 1.7134361267089844
epoch: 4 step: 138, loss is 1.6769747734069824
epoch: 4 step: 139, loss is 1.6933972835540771
epoch: 4 step: 140, loss is 1.7027777433395386
epoch: 4 step: 141, loss is 1.6266554594039917
epoch: 4 step: 142, loss is 1.533184289932251
epoch: 4 step: 143, loss is 1.6575446128845215
epoch: 4 step: 144, loss is 1.6996803283691406
epoch: 4 step: 145, loss is 1.6061861515045166
epoch: 4 step: 146, loss is 1.639892816543579
epoch: 4 step: 147, loss is 1.6984308958053589
epoch: 4 step: 148, loss is 1.6266686916351318
epoch: 4 step: 149, loss is 1.5342903137207031
epoch: 4 step: 150, loss is 1.5115759372711182
epoch: 4 step: 151, loss is 1.6471586227416992
epoch: 4 step: 152, loss is 1.6123220920562744
epoch: 4 step: 153, loss is 1.7298158407211304
epoch: 4 step: 154, loss is 1.7565926313400269
epoch: 4 step: 155, loss is 1.5588290691375732
epoch: 4 step: 156, loss is 1.6925194263458252
epoch: 4 step: 157, loss is 1.6091477870941162
epoch: 4 step: 158, loss is 1.546489953994751
epoch: 4 step: 159, loss is 1.58791983127594
epoch: 4 step: 160, loss is 1.6596553325653076
epoch: 4 step: 161, loss is 1.6580784320831299
epoch: 4 step: 162, loss is 1.6155848503112793
epoch: 4 step: 163, loss is 1.507157564163208
epoch: 4 step: 164, loss is 1.6309540271759033
epoch: 4 step: 165, loss is 1.718853235244751
epoch: 4 step: 166, loss is 1.6585129499435425
epoch: 4 step: 167, loss is 1.690779209136963
epoch: 4 step: 168, loss is 1.8036307096481323
epoch: 4 step: 169, loss is 1.6397377252578735
epoch: 4 step: 170, loss is 1.6578072309494019
epoch: 4 step: 171, loss is 1.6795103549957275
epoch: 4 step: 172, loss is 1.6370961666107178
epoch: 4 step: 173, loss is 1.7581895589828491
epoch: 4 step: 174, loss is 1.7030394077301025
epoch: 4 step: 175, loss is 1.6070623397827148
epoch: 4 step: 176, loss is 1.7415907382965088
epoch: 4 step: 177, loss is 1.6105648279190063
epoch: 4 step: 178, loss is 1.6762864589691162
epoch: 4 step: 179, loss is 1.731890320777893
epoch: 4 step: 180, loss is 1.7985610961914062
epoch: 4 step: 181, loss is 1.672331690788269
epoch: 4 step: 182, loss is 1.5811572074890137
epoch: 4 step: 183, loss is 1.6050021648406982
epoch: 4 step: 184, loss is 1.6705358028411865
epoch: 4 step: 185, loss is 1.7438170909881592
epoch: 4 step: 186, loss is 1.6789922714233398
epoch: 4 step: 187, loss is 1.758981704711914
epoch: 4 step: 188, loss is 1.7176316976547241
epoch: 4 step: 189, loss is 1.7206790447235107
epoch: 4 step: 190, loss is 1.6082700490951538
epoch: 4 step: 191, loss is 1.7374762296676636
epoch: 4 step: 192, loss is 1.617494821548462
epoch: 4 step: 193, loss is 1.689945101737976
epoch: 4 step: 194, loss is 1.8073633909225464
epoch: 4 step: 195, loss is 1.7421574592590332
epoch: 4 step: 196, loss is 1.732161283493042
epoch: 4 step: 197, loss is 1.6360026597976685
epoch: 4 step: 198, loss is 1.6759116649627686
epoch: 4 step: 199, loss is 1.700387954711914
epoch: 4 step: 200, loss is 1.8368773460388184
epoch: 4 step: 201, loss is 1.652057409286499
epoch: 4 step: 202, loss is 1.628521203994751
epoch: 4 step: 203, loss is 1.7547571659088135
epoch: 4 step: 204, loss is 1.654246211051941
epoch: 4 step: 205, loss is 1.6919186115264893
epoch: 4 step: 206, loss is 1.6392292976379395
epoch: 4 step: 207, loss is 1.6958022117614746
epoch: 4 step: 208, loss is 1.7535076141357422
epoch: 4 step: 209, loss is 1.7452285289764404
epoch: 4 step: 210, loss is 1.6547497510910034
epoch: 4 step: 211, loss is 1.719916582107544
epoch: 4 step: 212, loss is 1.7590301036834717
epoch: 4 step: 213, loss is 1.7129522562026978
epoch: 4 step: 214, loss is 1.670893669128418
epoch: 4 step: 215, loss is 1.6782957315444946
epoch: 4 step: 216, loss is 1.5753823518753052
epoch: 4 step: 217, loss is 1.6813194751739502
epoch: 4 step: 218, loss is 1.7322540283203125
epoch: 4 step: 219, loss is 1.5934405326843262
epoch: 4 step: 220, loss is 1.5727678537368774
epoch: 4 step: 221, loss is 1.6348683834075928
epoch: 4 step: 222, loss is 1.8156979084014893
epoch: 4 step: 223, loss is 1.725146770477295
epoch: 4 step: 224, loss is 1.6986759901046753
epoch: 4 step: 225, loss is 1.5316214561462402
epoch: 4 step: 226, loss is 1.7105189561843872
epoch: 4 step: 227, loss is 1.7322063446044922
epoch: 4 step: 228, loss is 1.6162315607070923
epoch: 4 step: 229, loss is 1.6170920133590698
epoch: 4 step: 230, loss is 1.571286678314209
epoch: 4 step: 231, loss is 1.662850022315979
epoch: 4 step: 232, loss is 1.6746301651000977
epoch: 4 step: 233, loss is 1.6915740966796875
epoch: 4 step: 234, loss is 1.6917881965637207
epoch: 4 step: 235, loss is 1.5960277318954468
epoch: 4 step: 236, loss is 1.7773040533065796
epoch: 4 step: 237, loss is 1.6099199056625366
epoch: 4 step: 238, loss is 1.6929254531860352
epoch: 4 step: 239, loss is 1.668748378753662
epoch: 4 step: 240, loss is 1.7303643226623535
epoch: 4 step: 241, loss is 1.697258472442627
epoch: 4 step: 242, loss is 1.63413405418396
epoch: 4 step: 243, loss is 1.7441484928131104
epoch: 4 step: 244, loss is 1.6545597314834595
epoch: 4 step: 245, loss is 1.7197586297988892
epoch: 4 step: 246, loss is 1.7145718336105347
epoch: 4 step: 247, loss is 1.690746784210205
epoch: 4 step: 248, loss is 1.5470335483551025
epoch: 4 step: 249, loss is 1.5790351629257202
epoch: 4 step: 250, loss is 1.73712158203125
epoch: 4 step: 251, loss is 1.617708683013916
epoch: 4 step: 252, loss is 1.6456297636032104
epoch: 4 step: 253, loss is 1.5848875045776367
epoch: 4 step: 254, loss is 1.6635351181030273
epoch: 4 step: 255, loss is 1.709910273551941
epoch: 4 step: 256, loss is 1.5659290552139282
epoch: 4 step: 257, loss is 1.6554642915725708
epoch: 4 step: 258, loss is 1.6749262809753418
epoch: 4 step: 259, loss is 1.7221397161483765
epoch: 4 step: 260, loss is 1.632184386253357
epoch: 4 step: 261, loss is 1.7132208347320557
epoch: 4 step: 262, loss is 1.7299683094024658
epoch: 4 step: 263, loss is 1.5776056051254272
epoch: 4 step: 264, loss is 1.5865495204925537
epoch: 4 step: 265, loss is 1.6605943441390991
epoch: 4 step: 266, loss is 1.689653992652893
epoch: 4 step: 267, loss is 1.715151071548462
epoch: 4 step: 268, loss is 1.6336978673934937
epoch: 4 step: 269, loss is 1.6838366985321045
epoch: 4 step: 270, loss is 1.6349806785583496
epoch: 4 step: 271, loss is 1.6629189252853394
epoch: 4 step: 272, loss is 1.6586644649505615
epoch: 4 step: 273, loss is 1.6890497207641602
epoch: 4 step: 274, loss is 1.7127034664154053
epoch: 4 step: 275, loss is 1.5627331733703613
epoch: 4 step: 276, loss is 1.7165720462799072
epoch: 4 step: 277, loss is 1.6123813390731812
epoch: 4 step: 278, loss is 1.6511244773864746
epoch: 4 step: 279, loss is 1.6871031522750854
epoch: 4 step: 280, loss is 1.7347900867462158
epoch: 4 step: 281, loss is 1.6092252731323242
epoch: 4 step: 282, loss is 1.5915873050689697
epoch: 4 step: 283, loss is 1.8185259103775024
epoch: 4 step: 284, loss is 1.6477389335632324
epoch: 4 step: 285, loss is 1.5396977663040161
epoch: 4 step: 286, loss is 1.6856896877288818
epoch: 4 step: 287, loss is 1.7729679346084595
epoch: 4 step: 288, loss is 1.6553504467010498
epoch: 4 step: 289, loss is 1.6827948093414307
epoch: 4 step: 290, loss is 1.6195402145385742
epoch: 4 step: 291, loss is 1.6322977542877197
epoch: 4 step: 292, loss is 1.6847572326660156
epoch: 4 step: 293, loss is 1.7206950187683105
epoch: 4 step: 294, loss is 1.6763403415679932
epoch: 4 step: 295, loss is 1.6264344453811646
epoch: 4 step: 296, loss is 1.7561956644058228
epoch: 4 step: 297, loss is 1.5715850591659546
epoch: 4 step: 298, loss is 1.7612639665603638
epoch: 4 step: 299, loss is 1.6326185464859009
epoch: 4 step: 300, loss is 1.6466399431228638
epoch: 4 step: 301, loss is 1.6066995859146118
epoch: 4 step: 302, loss is 1.7112300395965576
epoch: 4 step: 303, loss is 1.6105363368988037
epoch: 4 step: 304, loss is 1.530107021331787
epoch: 4 step: 305, loss is 1.6013085842132568
epoch: 4 step: 306, loss is 1.6118366718292236
epoch: 4 step: 307, loss is 1.6645658016204834
epoch: 4 step: 308, loss is 1.6785011291503906
epoch: 4 step: 309, loss is 1.4574995040893555
epoch: 4 step: 310, loss is 1.5936414003372192
epoch: 4 step: 311, loss is 1.5958086252212524
epoch: 4 step: 312, loss is 1.5721819400787354
epoch: 4 step: 313, loss is 1.759324550628662
epoch: 4 step: 314, loss is 1.620968222618103
epoch: 4 step: 315, loss is 1.5704362392425537
epoch: 4 step: 316, loss is 1.5149422883987427
epoch: 4 step: 317, loss is 1.7801661491394043
epoch: 4 step: 318, loss is 1.5842822790145874
epoch: 4 step: 319, loss is 1.5942542552947998
epoch: 4 step: 320, loss is 1.6681196689605713
epoch: 4 step: 321, loss is 1.6195734739303589
epoch: 4 step: 322, loss is 1.7136046886444092
epoch: 4 step: 323, loss is 1.6229065656661987
epoch: 4 step: 324, loss is 1.535475492477417
epoch: 4 step: 325, loss is 1.6061768531799316
epoch: 4 step: 326, loss is 1.5861411094665527
epoch: 4 step: 327, loss is 1.5017728805541992
epoch: 4 step: 328, loss is 1.6840782165527344
epoch: 4 step: 329, loss is 1.5316271781921387
epoch: 4 step: 330, loss is 1.6583795547485352
epoch: 4 step: 331, loss is 1.548046350479126
epoch: 4 step: 332, loss is 1.7077082395553589
epoch: 4 step: 333, loss is 1.6904962062835693
epoch: 4 step: 334, loss is 1.7003014087677002
epoch: 4 step: 335, loss is 1.7024219036102295
epoch: 4 step: 336, loss is 1.6112732887268066
epoch: 4 step: 337, loss is 1.6688255071640015
epoch: 4 step: 338, loss is 1.615283489227295
epoch: 4 step: 339, loss is 1.5577867031097412
epoch: 4 step: 340, loss is 1.51299250125885
epoch: 4 step: 341, loss is 1.6837174892425537
epoch: 4 step: 342, loss is 1.5802252292633057
epoch: 4 step: 343, loss is 1.5930207967758179
epoch: 4 step: 344, loss is 1.6317377090454102
epoch: 4 step: 345, loss is 1.4732263088226318
epoch: 4 step: 346, loss is 1.5824275016784668
epoch: 4 step: 347, loss is 1.607077717781067
epoch: 4 step: 348, loss is 1.68470299243927
epoch: 4 step: 349, loss is 1.6484581232070923
epoch: 4 step: 350, loss is 1.5608768463134766
epoch: 4 step: 351, loss is 1.558051586151123
epoch: 4 step: 352, loss is 1.6759580373764038
epoch: 4 step: 353, loss is 1.7302846908569336
epoch: 4 step: 354, loss is 1.4885876178741455
epoch: 4 step: 355, loss is 1.6168855428695679
epoch: 4 step: 356, loss is 1.610166311264038
epoch: 4 step: 357, loss is 1.5799460411071777
epoch: 4 step: 358, loss is 1.6324753761291504
epoch: 4 step: 359, loss is 1.6962008476257324
epoch: 4 step: 360, loss is 1.6340652704238892
epoch: 4 step: 361, loss is 1.5412421226501465
epoch: 4 step: 362, loss is 1.6910009384155273
epoch: 4 step: 363, loss is 1.6928589344024658
epoch: 4 step: 364, loss is 1.7076610326766968
epoch: 4 step: 365, loss is 1.5773704051971436
epoch: 4 step: 366, loss is 1.648019790649414
epoch: 4 step: 367, loss is 1.6586183309555054
epoch: 4 step: 368, loss is 1.580578088760376
epoch: 4 step: 369, loss is 1.7289409637451172
epoch: 4 step: 370, loss is 1.5810699462890625
epoch: 4 step: 371, loss is 1.6475943326950073
epoch: 4 step: 372, loss is 1.6045632362365723
epoch: 4 step: 373, loss is 1.579331874847412
epoch: 4 step: 374, loss is 1.5883054733276367
epoch: 4 step: 375, loss is 1.6973354816436768
epoch: 4 step: 376, loss is 1.713724136352539
epoch: 4 step: 377, loss is 1.6705690622329712
epoch: 4 step: 378, loss is 1.6421804428100586
epoch: 4 step: 379, loss is 1.5716965198516846
epoch: 4 step: 380, loss is 1.72585129737854
epoch: 4 step: 381, loss is 1.6475350856781006
epoch: 4 step: 382, loss is 1.5757324695587158
epoch: 4 step: 383, loss is 1.504807710647583
epoch: 4 step: 384, loss is 1.6358526945114136
epoch: 4 step: 385, loss is 1.60054349899292
epoch: 4 step: 386, loss is 1.4559749364852905
epoch: 4 step: 387, loss is 1.6298948526382446
epoch: 4 step: 388, loss is 1.6332250833511353
epoch: 4 step: 389, loss is 1.5376715660095215
epoch: 4 step: 390, loss is 1.6604375839233398
Train epoch time: 148192.920 ms, per step time: 379.982 ms
epoch: 5 step: 1, loss is 1.732064962387085
epoch: 5 step: 2, loss is 1.642157793045044
epoch: 5 step: 3, loss is 1.6185276508331299
epoch: 5 step: 4, loss is 1.5479648113250732
epoch: 5 step: 5, loss is 1.6563830375671387
epoch: 5 step: 6, loss is 1.5758249759674072
epoch: 5 step: 7, loss is 1.5886647701263428
epoch: 5 step: 8, loss is 1.592846155166626
epoch: 5 step: 9, loss is 1.5525658130645752
epoch: 5 step: 10, loss is 1.5996041297912598
epoch: 5 step: 11, loss is 1.6251089572906494
epoch: 5 step: 12, loss is 1.5908703804016113
epoch: 5 step: 13, loss is 1.6498301029205322
epoch: 5 step: 14, loss is 1.5614851713180542
epoch: 5 step: 15, loss is 1.576161503791809
epoch: 5 step: 16, loss is 1.6075465679168701
epoch: 5 step: 17, loss is 1.4701895713806152
epoch: 5 step: 18, loss is 1.6431536674499512
epoch: 5 step: 19, loss is 1.6119287014007568
epoch: 5 step: 20, loss is 1.5730373859405518
epoch: 5 step: 21, loss is 1.6200966835021973
epoch: 5 step: 22, loss is 1.548967957496643
epoch: 5 step: 23, loss is 1.5121467113494873
epoch: 5 step: 24, loss is 1.6271919012069702
epoch: 5 step: 25, loss is 1.5801106691360474
epoch: 5 step: 26, loss is 1.4629850387573242
epoch: 5 step: 27, loss is 1.517683744430542
epoch: 5 step: 28, loss is 1.5276094675064087
epoch: 5 step: 29, loss is 1.6179333925247192
epoch: 5 step: 30, loss is 1.6676299571990967
epoch: 5 step: 31, loss is 1.619720220565796
epoch: 5 step: 32, loss is 1.6356887817382812
epoch: 5 step: 33, loss is 1.47053861618042
epoch: 5 step: 34, loss is 1.5344951152801514
epoch: 5 step: 35, loss is 1.5774224996566772
epoch: 5 step: 36, loss is 1.515333652496338
epoch: 5 step: 37, loss is 1.7566314935684204
epoch: 5 step: 38, loss is 1.6011673212051392
epoch: 5 step: 39, loss is 1.618564248085022
epoch: 5 step: 40, loss is 1.534090280532837
epoch: 5 step: 41, loss is 1.605145812034607
epoch: 5 step: 42, loss is 1.609848141670227
epoch: 5 step: 43, loss is 1.573960542678833
epoch: 5 step: 44, loss is 1.5615041255950928
epoch: 5 step: 45, loss is 1.6148253679275513
epoch: 5 step: 46, loss is 1.6173732280731201
epoch: 5 step: 47, loss is 1.5960638523101807
epoch: 5 step: 48, loss is 1.6351659297943115
epoch: 5 step: 49, loss is 1.4730830192565918
epoch: 5 step: 50, loss is 1.5888168811798096
epoch: 5 step: 51, loss is 1.6218546628952026
epoch: 5 step: 52, loss is 1.6378370523452759
epoch: 5 step: 53, loss is 1.4967222213745117
epoch: 5 step: 54, loss is 1.64961576461792
epoch: 5 step: 55, loss is 1.5266637802124023
epoch: 5 step: 56, loss is 1.560529112815857
epoch: 5 step: 57, loss is 1.689986228942871
epoch: 5 step: 58, loss is 1.591117262840271
epoch: 5 step: 59, loss is 1.608198881149292
epoch: 5 step: 60, loss is 1.563807725906372
epoch: 5 step: 61, loss is 1.6101875305175781
epoch: 5 step: 62, loss is 1.7339825630187988
epoch: 5 step: 63, loss is 1.7158645391464233
epoch: 5 step: 64, loss is 1.5506260395050049
epoch: 5 step: 65, loss is 1.602805733680725
epoch: 5 step: 66, loss is 1.6202452182769775
epoch: 5 step: 67, loss is 1.6035434007644653
epoch: 5 step: 68, loss is 1.4786369800567627
epoch: 5 step: 69, loss is 1.6125078201293945
epoch: 5 step: 70, loss is 1.6711918115615845
epoch: 5 step: 71, loss is 1.6742210388183594
epoch: 5 step: 72, loss is 1.7218670845031738
epoch: 5 step: 73, loss is 1.5226973295211792
epoch: 5 step: 74, loss is 1.677172064781189
epoch: 5 step: 75, loss is 1.5475211143493652
epoch: 5 step: 76, loss is 1.7108920812606812
epoch: 5 step: 77, loss is 1.6100984811782837
epoch: 5 step: 78, loss is 1.6316590309143066
epoch: 5 step: 79, loss is 1.5527653694152832
epoch: 5 step: 80, loss is 1.596567153930664
epoch: 5 step: 81, loss is 1.624192714691162
epoch: 5 step: 82, loss is 1.802162766456604
epoch: 5 step: 83, loss is 1.5259840488433838
epoch: 5 step: 84, loss is 1.509135127067566
epoch: 5 step: 85, loss is 1.6101211309432983
epoch: 5 step: 86, loss is 1.72317636013031
epoch: 5 step: 87, loss is 1.4869513511657715
epoch: 5 step: 88, loss is 1.5524060726165771
epoch: 5 step: 89, loss is 1.6202787160873413
epoch: 5 step: 90, loss is 1.7478114366531372
epoch: 5 step: 91, loss is 1.6020302772521973
epoch: 5 step: 92, loss is 1.606710433959961
epoch: 5 step: 93, loss is 1.6935527324676514
epoch: 5 step: 94, loss is 1.6623975038528442
epoch: 5 step: 95, loss is 1.6251224279403687
epoch: 5 step: 96, loss is 1.6424787044525146
epoch: 5 step: 97, loss is 1.6513460874557495
epoch: 5 step: 98, loss is 1.6484696865081787
epoch: 5 step: 99, loss is 1.5932360887527466
epoch: 5 step: 100, loss is 1.6748833656311035
epoch: 5 step: 101, loss is 1.531472086906433
epoch: 5 step: 102, loss is 1.5079199075698853
epoch: 5 step: 103, loss is 1.5829377174377441
epoch: 5 step: 104, loss is 1.7070366144180298
epoch: 5 step: 105, loss is 1.6775567531585693
epoch: 5 step: 106, loss is 1.733814001083374
epoch: 5 step: 107, loss is 1.6535528898239136
epoch: 5 step: 108, loss is 1.6326770782470703
epoch: 5 step: 109, loss is 1.6118979454040527
epoch: 5 step: 110, loss is 1.6682919263839722
epoch: 5 step: 111, loss is 1.6229664087295532
epoch: 5 step: 112, loss is 1.6211090087890625
epoch: 5 step: 113, loss is 1.583876609802246
epoch: 5 step: 114, loss is 1.5698292255401611
epoch: 5 step: 115, loss is 1.45005464553833
epoch: 5 step: 116, loss is 1.685972809791565
epoch: 5 step: 117, loss is 1.730628252029419
epoch: 5 step: 118, loss is 1.620910882949829
epoch: 5 step: 119, loss is 1.7291038036346436
epoch: 5 step: 120, loss is 1.617356300354004
epoch: 5 step: 121, loss is 1.5933917760849
epoch: 5 step: 122, loss is 1.6709065437316895
epoch: 5 step: 123, loss is 1.4122443199157715
epoch: 5 step: 124, loss is 1.5545268058776855
epoch: 5 step: 125, loss is 1.7473828792572021
epoch: 5 step: 126, loss is 1.6225085258483887
epoch: 5 step: 127, loss is 1.7124907970428467
epoch: 5 step: 128, loss is 1.5947970151901245
epoch: 5 step: 129, loss is 1.6763713359832764
epoch: 5 step: 130, loss is 1.7047843933105469
epoch: 5 step: 131, loss is 1.6374223232269287
epoch: 5 step: 132, loss is 1.6349091529846191
epoch: 5 step: 133, loss is 1.5605924129486084
epoch: 5 step: 134, loss is 1.6401188373565674
epoch: 5 step: 135, loss is 1.590761423110962
epoch: 5 step: 136, loss is 1.7277530431747437
epoch: 5 step: 137, loss is 1.6104816198349
epoch: 5 step: 138, loss is 1.6965982913970947
epoch: 5 step: 139, loss is 1.5691813230514526
epoch: 5 step: 140, loss is 1.7318273782730103
epoch: 5 step: 141, loss is 1.5476237535476685
epoch: 5 step: 142, loss is 1.5739657878875732
epoch: 5 step: 143, loss is 1.6846758127212524
epoch: 5 step: 144, loss is 1.5783294439315796
epoch: 5 step: 145, loss is 1.6351780891418457
epoch: 5 step: 146, loss is 1.6586753129959106
epoch: 5 step: 147, loss is 1.4739278554916382
epoch: 5 step: 148, loss is 1.562909722328186
epoch: 5 step: 149, loss is 1.7065060138702393
epoch: 5 step: 150, loss is 1.7247668504714966
epoch: 5 step: 151, loss is 1.612099528312683
epoch: 5 step: 152, loss is 1.6697736978530884
epoch: 5 step: 153, loss is 1.5842936038970947
epoch: 5 step: 154, loss is 1.6145288944244385
epoch: 5 step: 155, loss is 1.64228355884552
epoch: 5 step: 156, loss is 1.6170909404754639
epoch: 5 step: 157, loss is 1.6164065599441528
epoch: 5 step: 158, loss is 1.6436498165130615
epoch: 5 step: 159, loss is 1.6956052780151367
epoch: 5 step: 160, loss is 1.644710898399353
epoch: 5 step: 161, loss is 1.6524138450622559
epoch: 5 step: 162, loss is 1.6974036693572998
epoch: 5 step: 163, loss is 1.6052260398864746
epoch: 5 step: 164, loss is 1.478930950164795
epoch: 5 step: 165, loss is 1.6300386190414429
epoch: 5 step: 166, loss is 1.6599113941192627
epoch: 5 step: 167, loss is 1.642054557800293
epoch: 5 step: 168, loss is 1.743283748626709
epoch: 5 step: 169, loss is 1.4537618160247803
epoch: 5 step: 170, loss is 1.6595747470855713
epoch: 5 step: 171, loss is 1.6030980348587036
epoch: 5 step: 172, loss is 1.6536245346069336
epoch: 5 step: 173, loss is 1.5781941413879395
epoch: 5 step: 174, loss is 1.5839710235595703
epoch: 5 step: 175, loss is 1.610130786895752
epoch: 5 step: 176, loss is 1.5927903652191162
epoch: 5 step: 177, loss is 1.5741008520126343
epoch: 5 step: 178, loss is 1.7687041759490967
epoch: 5 step: 179, loss is 1.6198071241378784
epoch: 5 step: 180, loss is 1.5724008083343506
epoch: 5 step: 181, loss is 1.7192161083221436
epoch: 5 step: 182, loss is 1.6766963005065918
epoch: 5 step: 183, loss is 1.5468106269836426
epoch: 5 step: 184, loss is 1.6617190837860107
epoch: 5 step: 185, loss is 1.654116153717041
epoch: 5 step: 186, loss is 1.6839861869812012
epoch: 5 step: 187, loss is 1.6468552350997925
epoch: 5 step: 188, loss is 1.5491598844528198
epoch: 5 step: 189, loss is 1.6253888607025146
epoch: 5 step: 190, loss is 1.746768832206726
epoch: 5 step: 191, loss is 1.6072258949279785
epoch: 5 step: 192, loss is 1.5063557624816895
epoch: 5 step: 193, loss is 1.6080549955368042
epoch: 5 step: 194, loss is 1.4681304693222046
epoch: 5 step: 195, loss is 1.7305692434310913
epoch: 5 step: 196, loss is 1.567294716835022
epoch: 5 step: 197, loss is 1.6033329963684082
epoch: 5 step: 198, loss is 1.658197283744812
epoch: 5 step: 199, loss is 1.5383046865463257
epoch: 5 step: 200, loss is 1.5731761455535889
epoch: 5 step: 201, loss is 1.5788071155548096
epoch: 5 step: 202, loss is 1.5512884855270386
epoch: 5 step: 203, loss is 1.6297571659088135
epoch: 5 step: 204, loss is 1.5621063709259033
epoch: 5 step: 205, loss is 1.6575044393539429
epoch: 5 step: 206, loss is 1.717639684677124
epoch: 5 step: 207, loss is 1.6056252717971802
epoch: 5 step: 208, loss is 1.5446138381958008
epoch: 5 step: 209, loss is 1.6509729623794556
epoch: 5 step: 210, loss is 1.525032639503479
epoch: 5 step: 211, loss is 1.6058061122894287
epoch: 5 step: 212, loss is 1.6380984783172607
epoch: 5 step: 213, loss is 1.5760167837142944
epoch: 5 step: 214, loss is 1.616842269897461
epoch: 5 step: 215, loss is 1.6996495723724365
epoch: 5 step: 216, loss is 1.5389083623886108
epoch: 5 step: 217, loss is 1.5867993831634521
epoch: 5 step: 218, loss is 1.4464359283447266
epoch: 5 step: 219, loss is 1.6728496551513672
epoch: 5 step: 220, loss is 1.6889090538024902
epoch: 5 step: 221, loss is 1.6454179286956787
epoch: 5 step: 222, loss is 1.5812941789627075
epoch: 5 step: 223, loss is 1.5781618356704712
epoch: 5 step: 224, loss is 1.527899146080017
epoch: 5 step: 225, loss is 1.7462377548217773
epoch: 5 step: 226, loss is 1.6562824249267578
epoch: 5 step: 227, loss is 1.5749261379241943
epoch: 5 step: 228, loss is 1.5579584836959839
epoch: 5 step: 229, loss is 1.605156421661377
epoch: 5 step: 230, loss is 1.5233993530273438
epoch: 5 step: 231, loss is 1.6699270009994507
epoch: 5 step: 232, loss is 1.6647980213165283
epoch: 5 step: 233, loss is 1.548474907875061
epoch: 5 step: 234, loss is 1.4801830053329468
epoch: 5 step: 235, loss is 1.6232141256332397
epoch: 5 step: 236, loss is 1.5874500274658203
epoch: 5 step: 237, loss is 1.476054310798645
epoch: 5 step: 238, loss is 1.724316120147705
epoch: 5 step: 239, loss is 1.603211760520935
epoch: 5 step: 240, loss is 1.5714964866638184
epoch: 5 step: 241, loss is 1.5042778253555298
epoch: 5 step: 242, loss is 1.6531548500061035
epoch: 5 step: 243, loss is 1.6012554168701172
epoch: 5 step: 244, loss is 1.5364093780517578
epoch: 5 step: 245, loss is 1.6243541240692139
epoch: 5 step: 246, loss is 1.4586551189422607
epoch: 5 step: 247, loss is 1.5070651769638062
epoch: 5 step: 248, loss is 1.5355825424194336
epoch: 5 step: 249, loss is 1.547766923904419
epoch: 5 step: 250, loss is 1.611509084701538
epoch: 5 step: 251, loss is 1.6216280460357666
epoch: 5 step: 252, loss is 1.6462599039077759
epoch: 5 step: 253, loss is 1.5645804405212402
epoch: 5 step: 254, loss is 1.7160987854003906
epoch: 5 step: 255, loss is 1.7854816913604736
epoch: 5 step: 256, loss is 1.5315474271774292
epoch: 5 step: 257, loss is 1.604637861251831
epoch: 5 step: 258, loss is 1.5908942222595215
epoch: 5 step: 259, loss is 1.5725393295288086
epoch: 5 step: 260, loss is 1.6587775945663452
epoch: 5 step: 261, loss is 1.6050629615783691
epoch: 5 step: 262, loss is 1.4947434663772583
epoch: 5 step: 263, loss is 1.6716800928115845
epoch: 5 step: 264, loss is 1.580312728881836
epoch: 5 step: 265, loss is 1.6166272163391113
epoch: 5 step: 266, loss is 1.4542531967163086
epoch: 5 step: 267, loss is 1.6172029972076416
epoch: 5 step: 268, loss is 1.608907699584961
epoch: 5 step: 269, loss is 1.6537593603134155
epoch: 5 step: 270, loss is 1.5340898036956787
epoch: 5 step: 271, loss is 1.4851632118225098
epoch: 5 step: 272, loss is 1.7174017429351807
epoch: 5 step: 273, loss is 1.7397809028625488
epoch: 5 step: 274, loss is 1.6069607734680176
epoch: 5 step: 275, loss is 1.8224904537200928
epoch: 5 step: 276, loss is 1.5527863502502441
epoch: 5 step: 277, loss is 1.5776571035385132
epoch: 5 step: 278, loss is 1.6275925636291504
epoch: 5 step: 279, loss is 1.6205875873565674
epoch: 5 step: 280, loss is 1.5229628086090088
epoch: 5 step: 281, loss is 1.5539497137069702
epoch: 5 step: 282, loss is 1.6306214332580566
epoch: 5 step: 283, loss is 1.6712145805358887
epoch: 5 step: 284, loss is 1.501929521560669
epoch: 5 step: 285, loss is 1.6607998609542847
epoch: 5 step: 286, loss is 1.4067232608795166
epoch: 5 step: 287, loss is 1.552455186843872
epoch: 5 step: 288, loss is 1.6655046939849854
epoch: 5 step: 289, loss is 1.6979361772537231
epoch: 5 step: 290, loss is 1.6408698558807373
epoch: 5 step: 291, loss is 1.5328257083892822
epoch: 5 step: 292, loss is 1.6922451257705688
epoch: 5 step: 293, loss is 1.5281833410263062
epoch: 5 step: 294, loss is 1.6375106573104858
epoch: 5 step: 295, loss is 1.584839105606079
epoch: 5 step: 296, loss is 1.6950011253356934
epoch: 5 step: 297, loss is 1.6538549661636353
epoch: 5 step: 298, loss is 1.6639721393585205
epoch: 5 step: 299, loss is 1.5627961158752441
epoch: 5 step: 300, loss is 1.6030569076538086
epoch: 5 step: 301, loss is 1.4860504865646362
epoch: 5 step: 302, loss is 1.588396668434143
epoch: 5 step: 303, loss is 1.4862431287765503
epoch: 5 step: 304, loss is 1.5940279960632324
epoch: 5 step: 305, loss is 1.4918596744537354
epoch: 5 step: 306, loss is 1.5705978870391846
epoch: 5 step: 307, loss is 1.5766692161560059
epoch: 5 step: 308, loss is 1.6006094217300415
epoch: 5 step: 309, loss is 1.6463922262191772
epoch: 5 step: 310, loss is 1.5315070152282715
epoch: 5 step: 311, loss is 1.5608091354370117
epoch: 5 step: 312, loss is 1.6662434339523315
epoch: 5 step: 313, loss is 1.5858619213104248
epoch: 5 step: 314, loss is 1.556136131286621
epoch: 5 step: 315, loss is 1.7142648696899414
epoch: 5 step: 316, loss is 1.5239651203155518
epoch: 5 step: 317, loss is 1.6485443115234375
epoch: 5 step: 318, loss is 1.5584352016448975
epoch: 5 step: 319, loss is 1.5999035835266113
epoch: 5 step: 320, loss is 1.6045048236846924
epoch: 5 step: 321, loss is 1.6749330759048462
epoch: 5 step: 322, loss is 1.5389548540115356
epoch: 5 step: 323, loss is 1.6875205039978027
epoch: 5 step: 324, loss is 1.7266160249710083
epoch: 5 step: 325, loss is 1.505654215812683
epoch: 5 step: 326, loss is 1.5563000440597534
epoch: 5 step: 327, loss is 1.5665528774261475
epoch: 5 step: 328, loss is 1.7478028535842896
epoch: 5 step: 329, loss is 1.533534288406372
epoch: 5 step: 330, loss is 1.5377097129821777
epoch: 5 step: 331, loss is 1.5768988132476807
epoch: 5 step: 332, loss is 1.5625202655792236
epoch: 5 step: 333, loss is 1.4823535680770874
epoch: 5 step: 334, loss is 1.472485899925232
epoch: 5 step: 335, loss is 1.6301203966140747
epoch: 5 step: 336, loss is 1.5757269859313965
epoch: 5 step: 337, loss is 1.5469939708709717
epoch: 5 step: 338, loss is 1.6088483333587646
epoch: 5 step: 339, loss is 1.617684245109558
epoch: 5 step: 340, loss is 1.5175946950912476
epoch: 5 step: 341, loss is 1.6401379108428955
epoch: 5 step: 342, loss is 1.5887030363082886
epoch: 5 step: 343, loss is 1.534440279006958
epoch: 5 step: 344, loss is 1.752575159072876
epoch: 5 step: 345, loss is 1.4764305353164673
epoch: 5 step: 346, loss is 1.5502147674560547
epoch: 5 step: 347, loss is 1.670060634613037
epoch: 5 step: 348, loss is 1.7078402042388916
epoch: 5 step: 349, loss is 1.595656156539917
epoch: 5 step: 350, loss is 1.6205188035964966
epoch: 5 step: 351, loss is 1.6061155796051025
epoch: 5 step: 352, loss is 1.4948123693466187
epoch: 5 step: 353, loss is 1.5117608308792114
epoch: 5 step: 354, loss is 1.6266214847564697
epoch: 5 step: 355, loss is 1.604814887046814
epoch: 5 step: 356, loss is 1.588590383529663
epoch: 5 step: 357, loss is 1.5971734523773193
epoch: 5 step: 358, loss is 1.49770987033844
epoch: 5 step: 359, loss is 1.5792601108551025
epoch: 5 step: 360, loss is 1.5745766162872314
epoch: 5 step: 361, loss is 1.5864574909210205
epoch: 5 step: 362, loss is 1.661876916885376
epoch: 5 step: 363, loss is 1.6659116744995117
epoch: 5 step: 364, loss is 1.5186996459960938
epoch: 5 step: 365, loss is 1.5716300010681152
epoch: 5 step: 366, loss is 1.5549710988998413
epoch: 5 step: 367, loss is 1.5875260829925537
epoch: 5 step: 368, loss is 1.6119790077209473
epoch: 5 step: 369, loss is 1.6168357133865356
epoch: 5 step: 370, loss is 1.651954174041748
epoch: 5 step: 371, loss is 1.5352210998535156
epoch: 5 step: 372, loss is 1.681986689567566
epoch: 5 step: 373, loss is 1.5233029127120972
epoch: 5 step: 374, loss is 1.5772886276245117
epoch: 5 step: 375, loss is 1.5943220853805542
epoch: 5 step: 376, loss is 1.6477620601654053
epoch: 5 step: 377, loss is 1.642315149307251
epoch: 5 step: 378, loss is 1.4545228481292725
epoch: 5 step: 379, loss is 1.6743626594543457
epoch: 5 step: 380, loss is 1.5327942371368408
epoch: 5 step: 381, loss is 1.5334811210632324
epoch: 5 step: 382, loss is 1.5805553197860718
epoch: 5 step: 383, loss is 1.5866146087646484
epoch: 5 step: 384, loss is 1.648356318473816
epoch: 5 step: 385, loss is 1.6265746355056763
epoch: 5 step: 386, loss is 1.4873814582824707
epoch: 5 step: 387, loss is 1.589070439338684
epoch: 5 step: 388, loss is 1.5715038776397705
epoch: 5 step: 389, loss is 1.6626286506652832
epoch: 5 step: 390, loss is 1.5729644298553467
Train epoch time: 138221.297 ms, per step time: 354.414 ms
total time:0h 18m 5s
============== Train Success ==============
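Each epoch in the log runs 390 steps. The reported per-step time is the epoch time divided by that step count. For epoch 4, 148192.920 ms / 390 ≈ 379.982 ms.

The sketch below uses only the Python standard library. It parses lines in this log format to average the loss per epoch and to recompute the per-step time. The excerpt is copied from the log above. `summarize_log` is a hypothetical helper, not part of MindSpore.

```python
import re
from statistics import mean

# Excerpt copied from the training log above
LOG = """\
epoch: 4 step: 390, loss is 1.6604375839233398
Train epoch time: 148192.920 ms, per step time: 379.982 ms
epoch: 5 step: 389, loss is 1.6626286506652832
epoch: 5 step: 390, loss is 1.5729644298553467
Train epoch time: 138221.297 ms, per step time: 354.414 ms
"""

STEP_RE = re.compile(r"epoch: (\d+) step: (\d+), loss is ([\d.]+)")
TIME_RE = re.compile(r"Train epoch time: ([\d.]+) ms, per step time: ([\d.]+) ms")


def summarize_log(log: str, steps_per_epoch: int):
    """Return ({epoch: mean loss}, [(reported, recomputed) per-step time in ms])."""
    losses = {}
    step_times = []
    for line in log.splitlines():
        if m := STEP_RE.match(line):
            losses.setdefault(int(m[1]), []).append(float(m[3]))
        elif m := TIME_RE.match(line):
            # Per-step time should equal epoch time / number of steps in the epoch
            step_times.append((float(m[2]), float(m[1]) / steps_per_epoch))
    return {epoch: mean(vals) for epoch, vals in losses.items()}, step_times


mean_loss, step_times = summarize_log(LOG, steps_per_epoch=390)
for reported, recomputed in step_times:
    print(f"reported {reported:.3f} ms, recomputed {recomputed:.3f} ms")
# reported 379.982 ms, recomputed 379.982 ms
# reported 354.414 ms, recomputed 354.414 ms
```

To average the loss over whole epochs, pass the full captured log instead of the excerpt.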
The trained model is saved in the current directory as shufflenetv1-5_390.ckpt and is used for evaluation.
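MindSpore's ModelCheckpoint callback names checkpoint files in the pattern `<prefix>-<epoch>_<step>.ckpt`. Under that pattern, shufflenetv1-5_390.ckpt holds the weights after step 390 of epoch 5, which is the final step of the training log above.

The helper below is a hypothetical sketch. `parse_ckpt_name` is not a MindSpore API, and the code assumes this naming pattern.

```python
import re

# Hypothetical helper (not part of MindSpore): split "<prefix>-<epoch>_<step>.ckpt"
CKPT_PATTERN = re.compile(r"^(?P<prefix>.+)-(?P<epoch>\d+)_(?P<step>\d+)\.ckpt$")


def parse_ckpt_name(filename: str) -> tuple:
    """Return (prefix, epoch, step) parsed from a checkpoint file name."""
    match = CKPT_PATTERN.match(filename)
    if match is None:
        raise ValueError(f"unexpected checkpoint name: {filename!r}")
    return match["prefix"], int(match["epoch"]), int(match["step"])


prefix, epoch, step = parse_ckpt_name("shufflenetv1-5_390.ckpt")
print(prefix, epoch, step)  # shufflenetv1 5 390
```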
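Model Evaluation
Evaluate the model on the CIFAR-10 test set.
First set the path of the checkpoint to evaluate and load the dataset. Then configure Top-1 and Top-5 accuracy as the evaluation metrics. Finally, evaluate the model with the model.eval() interface.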
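Top-1 accuracy counts a prediction as correct only when the highest-scoring class is the true label. Top-5 accepts a prediction if the true label is anywhere among the five highest scores, so Top-5 is always at least as high as Top-1.

The NumPy sketch below illustrates that definition. `top_k_accuracy` is a hypothetical helper, not the implementation behind MindSpore's Top1CategoricalAccuracy and Top5CategoricalAccuracy. The logits are random toy values, not model output.

```python
import numpy as np


def top_k_accuracy(logits: np.ndarray, labels: np.ndarray, k: int) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    # argsort sorts ascending, so the last k columns hold the indices of the k largest scores
    top_k = np.argsort(logits, axis=1)[:, -k:]
    # A sample is a hit if its label appears in any of those k columns
    hits = np.any(top_k == labels[:, None], axis=1)
    return float(hits.mean())


# Toy data: 4 samples, 10 classes, random scores (not real model output)
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))
labels = np.argmax(logits, axis=1)     # every sample's top-1 guess is correct...
labels[0] = np.argsort(logits[0])[-2]  # ...except sample 0, whose label is its 2nd-best class

print(top_k_accuracy(logits, labels, k=1))  # 0.75
print(top_k_accuracy(logits, labels, k=5))  # 1.0
```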
import time
import mindspore
from mindspore import nn, load_checkpoint, load_param_into_net
from mindspore.train import Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy
# get_dataset and ShuffleNetV1 are defined earlier in this tutorial

def test():
    mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")
    # Load the CIFAR-10 test split with batch size 128
    dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")
    # Rebuild the network and load the trained weights
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
    load_param_into_net(net, param_dict)
    net.set_train(False)
    # Loss function plus Loss, Top-1 and Top-5 accuracy metrics for evaluation
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),
                    'Top_5_Acc': Top5CategoricalAccuracy()}
    model = Model(net, loss_fn=loss, metrics=eval_metrics)
    start_time = time.time()
    res = model.eval(dataset, dataset_sink_mode=False)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-5_390.ckpt" \
        + "', time: " + hour + "h " + minute + "m " + second + "s"
    print(log)
    # Append the result to a log file
    filename = './eval_log.txt'
    with open(filename, 'a') as file_object:
        file_object.write(log + '\n')

if __name__ == '__main__':
    test()
model size is 2.0x
result:{'Loss': 1.6017735569905012, 'Top_1_Acc': 0.49849759615384615, 'Top_5_Acc': 0.9365985576923077}, ckpt:'./shufflenetv1-5_390.ckpt', time: 0h 1m 28s
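Top_1_Acc and Top_5_Acc in the result above use the usual top-k definition. A sample counts as correct when its true label is the single highest-scoring class (Top-1) or is among the five highest-scoring classes (Top-5). The NumPy sketch below illustrates that definition on made-up scores. It is not MindSpore's `Top1CategoricalAccuracy` / `Top5CategoricalAccuracy` implementation, and `top_k_accuracy` is a hypothetical helper. The last lines check that the reported values equal 4977/9984 and 9351/9984. 9984 = 78 * 128 is the sample count you would get if the 10,000 test images were batched at 128 with the last incomplete batch dropped. That is an inference from the numbers, not something this section states.

```python
import numpy as np

def top_k_accuracy(logits, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    # Column indices of the k largest scores in each row (their order does not matter).
    top_k = np.argsort(logits, axis=1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=1)
    return hits.mean()

# Toy example: 4 samples, 4 classes (made-up numbers, not model output).
logits = np.array([[0.1, 2.0, 0.3, 0.5],
                   [1.5, 0.2, 0.9, 0.1],
                   [0.2, 0.4, 0.1, 3.0],
                   [0.3, 0.1, 2.2, 0.0]])
labels = np.array([1, 2, 0, 2])
print(top_k_accuracy(logits, labels, 1))  # 0.5
print(top_k_accuracy(logits, labels, 2))  # 0.75

# Fractions consistent with the evaluation log (9984 = 78 full batches of 128).
print(4977 / 9984, 9351 / 9984)
```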
Model Prediction
Run the model on the CIFAR-10 test set and visualize the predictions.
import numpy as np
import matplotlib.pyplot as plt
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import Model
# ShuffleNetV1 is defined in an earlier cell of this notebook.

# Rebuild the network and load the trained weights
net = ShuffleNetV1(model_size="2.0x", n_class=10)
param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
load_param_into_net(net, param_dict)
net.set_train(False)
model = Model(net)

# Two readers over the same test split in the same order (shuffle=False):
# one is preprocessed for the network, the other keeps the raw images for display
dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="test")
dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="test")
dataset_show = dataset_show.batch(16)
show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()

# Deterministic preprocessing only: the random crop / flip augmentations used in
# training are left out so the predicted images match the displayed ones
image_trans = [
    vision.Resize((224, 224)),
    vision.Rescale(1.0 / 255.0, 0.0),
    vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    vision.HWC2CHW()
]
dataset_predict = dataset_predict.map(image_trans, 'image')
dataset_predict = dataset_predict.batch(16)
class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}

# Show the inference results (predicted class as the title above each image)
plt.figure(figsize=(16, 5))
predict_data = next(dataset_predict.create_dict_iterator())
output = model.predict(predict_data['image'])
pred = np.argmax(output.asnumpy(), axis=1)
for index, image in enumerate(show_images_lst):
    plt.subplot(2, 8, index + 1)
    plt.title('{}'.format(class_dict[pred[index]]))
    plt.imshow(image)
    plt.axis("off")
plt.show()
model size is 2.0x
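The deterministic part of the prediction pipeline is plain arithmetic. `Rescale(1/255, 0)` maps uint8 pixels to [0, 1]. `Normalize(mean, std)` subtracts the per-channel mean and divides by the per-channel std. `HWC2CHW` moves the channel axis to the front. The predicted class is then the argmax of the 10 logits, looked up in the class table. The NumPy sketch below mirrors those steps on a made-up image and made-up logits to show the formulas and shapes. It is an illustration only: it does not call MindSpore, it skips the 224x224 resize, and `preprocess`, `logits_to_names` and `CLASS_NAMES` are hypothetical names.

```python
import numpy as np

# Per-channel statistics copied from the Normalize call above.
MEAN = np.array([0.4914, 0.4822, 0.4465])
STD = np.array([0.2023, 0.1994, 0.2010])
CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

def preprocess(image_hwc_uint8):
    """Rescale to [0, 1], normalize per channel, then reorder HWC -> CHW."""
    x = image_hwc_uint8.astype(np.float32) * (1.0 / 255.0)  # like Rescale(1/255, 0)
    x = (x - MEAN) / STD                                      # like Normalize(mean, std)
    return np.transpose(x, (2, 0, 1))                         # like HWC2CHW

def logits_to_names(logits):
    """Map a (batch, 10) array of scores to class names via argmax."""
    return [CLASS_NAMES[i] for i in np.argmax(logits, axis=1)]

# Made-up 32x32 RGB image where every pixel is 255 (pure white).
image = np.full((32, 32, 3), 255, dtype=np.uint8)
chw = preprocess(image)
print(chw.shape)  # (3, 32, 32)
print(chw[:, 0, 0])  # roughly (1 - MEAN) / STD for each channel

# Made-up logits for two samples.
logits = np.full((2, 10), 0.1)
logits[0, 3] = 5.0  # highest score at index 3
logits[1, 8] = 2.0  # highest score at index 8
print(logits_to_names(logits))  # ['cat', 'ship']
```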
from datetime import datetime
import pytz
# Set the time zone to Beijing time
beijing_tz = pytz.timezone('Asia/Shanghai')
# Get the current time in Beijing time
current_beijing_time = datetime.now(beijing_tz)
# Format the time for output
formatted_time = current_beijing_time.strftime('%Y-%m-%d %H:%M:%S')
print("Current Beijing time:", formatted_time)
print('Username: matpandas 显似')
Current Beijing time: 2024-07-21 15:20:08
Username: matpandas 显似