目标检测基础
边界框
def bbox_to_rect(bbox, color): # 本函数已保存在d2lzh_pytorch中方便以后使用
# 将边界框(左上x, 左上y, 右下x, 右下y)-->((左上x, 左上y), 宽, 高)
return d2l.plt.Rectangle(
锚框
生成多个锚框
假设输入图像高为
h
h
h,宽为
w
w
w。我们分别以图像的每个像素为中心生成不同形状的锚框。设大小为
s
∈
(
0
,
1
]
s\in (0,1]
s∈(0,1]且宽高比为r > 0,那么锚框的宽和高将分别为
w
s
r
ws\sqrt{r}
wsr和
h
s
/
r
hs/\sqrt{r}
hs/r。当中心位置给定时,已知宽和高的锚框是确定的。
我们通常只对包含
s
1
s_1
s1或
r
1
r_1
r1的大小与宽高比的组合感兴趣,即
(
s
1
,
r
1
)
,
(
s
1
,
r
2
)
,
…
,
(
s
1
,
r
m
)
,
(
s
2
,
r
1
)
,
(
s
3
,
r
1
)
,
…
,
(
s
n
,
r
1
)
.
(s_1, r_1), (s_1, r_2), \ldots, (s_1, r_m), (s_2, r_1), (s_3, r_1), \ldots, (s_n, r_1).
(s1,r1),(s1,r2),…,(s1,rm),(s2,r1),(s3,r1),…,(sn,r1).
共生成
w
h
(
n
+
m
−
1
)
wh(n+m-1)
wh(n+m−1)个锚框。
def MultiBoxPrior(feature_map, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5]):
pairs = [] # pair of (size, sqrt(ration))
# 生成n + m -1个框(所有可能的s、r对)
for r in ratios:
pairs.append([sizes[0], math.sqrt(r)])
for s in sizes[1:]:
pairs.append([s, math.sqrt(ratios[0])])
pairs = np.array(pairs)
# 生成相对于坐标中心点的框(x,y,x,y)
ss1 = pairs[:, 0] * pairs[:, 1] # size * sqrt(ration) # wsr^1/2
ss2 = pairs[:, 0] / pairs[:, 1] # size / sqrt(ration) # wh/r^1/2
base_anchors = np.stack([-ss1, -ss2, ss1, ss2], axis=1) / 2
#将坐标点和anchor组合起来生成hw(n+m-1)个框输出
h, w = feature_map.shape[-2:]
shifts_x = np.arange(0, w) / w # 按比例缩放规范到0-1之间
shifts_y = np.arange(0, h) / h # 按比例缩放规范到0-1之间
shift_x, shift_y = np.meshgrid(shifts_x, shifts_y) # 所有可能的中心像素点
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = np.stack((shift_x, shift_y, shift_x, shift_y), axis=1)
anchors = shifts.reshape((-1, 1, 4)) + base_anchors.reshape((1, -1, 4)) # n*(左上x,左上y,右下x,右下y)
return torch.tensor(anchors, dtype=torch.float32).view(1, -1, 4)
def show_bboxes(axes, bboxes, labels=None, colors=None): # 画出指定锚框
def _make_list(obj, default_values=None):
if obj is None:
obj = default_values
elif not isinstance(obj, (list, tuple)):
obj = [obj]
return obj
labels = _make_list(labels)
colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])
for i, bbox in enumerate(bboxes):
color = colors[i % len(colors)]
rect = d2l.bbox_to_rect(bbox.detach().cpu().numpy(), color)
axes.add_patch(rect)
if labels and len(labels) > i:
text_color = 'k' if color == 'w' else 'w'
axes.text(rect.xy[0], rect.xy[1], labels[i],
va='center', ha='center', fontsize=6, color=text_color,
bbox=dict(facecolor=color, lw=0))
交并比
Jaccard系数(Jaccard index)可以衡量两个集合的相似度。给定集合
A
\mathcal{A}
A和
B
\mathcal{B}
B,它们的Jaccard系数即二者交集大小除以二者并集大小:
J
(
A
,
B
)
=
∣
A
∩
B
∣
∣
A
∪
B
∣
.
J(\mathcal{A},\mathcal{B}) = \frac{\left|\mathcal{A} \cap \mathcal{B}\right|}{\left| \mathcal{A} \cup \mathcal{B}\right|}.
J(A,B)=∣A∪B∣∣A∩B∣.
def compute_intersection(set_1, set_2):
# PyTorch auto-broadcasts singleton dimensions
lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0)) # (n1, n2, 2)
upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0)) # (n1, n2, 2)
intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0) # (n1, n2, 2)
return intersection_dims[:, :, 0] * intersection_dims[:, :, 1] # (n1, n2)
def compute_jaccard(set_1, set_2):
intersection = compute_intersection(set_1, set_2) # (n1, n2)
areas_set_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1]) # (n1)
areas_set_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1]) # (n2)
# Find the union
# PyTorch auto-broadcasts singleton dimensions
union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection # (n1, n2)
return intersection / union # (n1, n2)
标注训练集的锚框
为锚框分配与其相似的真实边界框:
标注
A
A
A的偏移量:
(
x
b
−
x
a
w
a
−
μ
x
σ
x
,
y
b
−
y
a
h
a
−
μ
y
σ
y
,
log
w
b
w
a
−
μ
w
σ
w
,
log
h
b
h
a
−
μ
h
σ
h
)
,
\left( \frac{ \frac{x_b - x_a}{w_a} - \mu_x }{\sigma_x}, \frac{ \frac{y_b - y_a}{h_a} - \mu_y }{\sigma_y}, \frac{ \log \frac{w_b}{w_a} - \mu_w }{\sigma_w}, \frac{ \log \frac{h_b}{h_a} - \mu_h }{\sigma_h}\right),
(σxwaxb−xa−μx,σyhayb−ya−μy,σwlogwawb−μw,σhloghahb−μh),
其中常数的默认值为
μ
x
=
μ
y
=
μ
w
=
μ
h
=
0
,
σ
x
=
σ
y
=
0.1
,
σ
w
=
σ
h
=
0.2
\mu_x = \mu_y = \mu_w = \mu_h = 0, \sigma_x=\sigma_y=0.1, \sigma_w=\sigma_h=0.2
μx=μy=μw=μh=0,σx=σy=0.1,σw=σh=0.2
标注类别和偏移量,背景类别设为0,1为狗,2为猫:
def assign_anchor(bb, anchor, jaccard_threshold=0.5):
na = anchor.shape[0] # 锚框数
nb = bb.shape[0] # 真实框数
jaccard = compute_jaccard(anchor, bb).detach().cpu().numpy() # shape: (na, nb)
assigned_idx = np.ones(na) * -1 # 存放标签初始全为-1
# 先为每个bb分配一个anchor(不要求满足jaccard_threshold)
jaccard_cp = jaccard.copy()
for j in range(nb):
i = np.argmax(jaccard_cp[:, j])
assigned_idx[i] = j
jaccard_cp[i, :] = float("-inf") # 赋值为负无穷, 相当于去掉这一行
# 处理还未被分配的anchor, 要求满足jaccard_threshold
for i in range(na):
if assigned_idx[i] == -1:
j = np.argmax(jaccard[i, :])
if jaccard[i, j] >= jaccard_threshold:
assigned_idx[i] = j
return torch.tensor(assigned_idx, dtype=torch.long)
def xy_to_cxcy(xy):
# 将(x_min, y_min, x_max, y_max)形式的anchor转换成(center_x, center_y, w, h)形式的.
return torch.cat([(xy[:, 2:] + xy[:, :2]) / 2, # c_x, c_y
xy[:, 2:] - xy[:, :2]], 1) # w, h
def MultiBoxTarget(anchor, label):
"""
# 按照「9.4.1. 生成多个锚框」所讲的实现, anchor表示成归一化(xmin, ymin, xmax, ymax).
Args:
anchor: torch tensor, 输入的锚框, 一般是通过MultiBoxPrior生成, shape:(1,锚框总数,4)
label: 真实标签, shape为(bn, 每张图片最多的真实锚框数, 5)
第二维中,如果给定图片没有这么多锚框, 可以先用-1填充空白, 最后一维中的元素为[类别标签, 四个坐标值]
Returns:
列表, [bbox_offset, bbox_mask, cls_labels]
bbox_offset: 每个锚框的标注偏移量,形状为(bn,锚框总数*4)
bbox_mask: 形状同bbox_offset, 每个锚框的掩码, 一一对应上面的偏移量, 负类锚框(背景)对应的掩码均为0, 正类锚框的掩码均为1
cls_labels: 每个锚框的标注类别, 其中0表示为背景, 形状为(bn,锚框总数)
"""
assert len(anchor.shape) == 3 and len(label.shape) == 3
bn = label.shape[0] # 多少batch,即多少张图片,一张张处理
def MultiBoxTarget_one(anc, lab, eps=1e-6): # 处理一个batch,即单张图片
"""
MultiBoxTarget函数的辅助函数, 处理batch中的一个
Args:
anc: shape of (锚框总数, 4)
lab: shape of (真实锚框数, 5), 5代表[类别标签, 四个坐标值]
eps: 一个极小值, 防止log0
Returns:
offset: (锚框总数*4, )
bbox_mask: (锚框总数*4, ), 0代表背景, 1代表非背景
cls_labels: (锚框总数, 4), 0代表背景
"""
an = anc.shape[0]
# 变量的意义
assigned_idx = assign_anchor(lab[:, 1:], anc) # (锚框总数, ) 为每个锚框分配真实锚框
print("a: ", assigned_idx.shape)
print(assigned_idx)
bbox_mask = ((assigned_idx >= 0).float().unsqueeze(-1)).repeat(1, 4) # (锚框总数, 4) #这里背景不是0,是-1.所以是取非背景框
print("b: " , bbox_mask.shape)
print(bbox_mask)
cls_labels = torch.zeros(an, dtype=torch.long) # 记录每个锚框分配的标签,0表示背景
assigned_bb = torch.zeros((an, 4), dtype=torch.float32) # 记录每个锚框分配的真实锚框的坐标
for i in range(an):
bb_idx = assigned_idx[i] # 对应的真实框序号
if bb_idx >= 0: # 即非背景
cls_labels[i] = lab[bb_idx, 0].long().item() + 1 # 记录对应真实框类别,注意要加一
assigned_bb[i, :] = lab[bb_idx, 1:]
# 如何计算偏移量
center_anc = xy_to_cxcy(anc) # (center_x, center_y, w, h)
center_assigned_bb = xy_to_cxcy(assigned_bb)
offset_xy = 10.0 * (center_assigned_bb[:, :2] - center_anc[:, :2]) / center_anc[:, 2:]
offset_wh = 5.0 * torch.log(eps + center_assigned_bb[:, 2:] / center_anc[:, 2:])
offset = torch.cat([offset_xy, offset_wh], dim = 1) * bbox_mask # (锚框总数, 4)
return offset.view(-1), bbox_mask.view(-1), cls_labels
# 组合输出
batch_offset = []
batch_mask = []
batch_cls_labels = []
for b in range(bn): # 一张一张图片处理
offset, bbox_mask, cls_labels = MultiBoxTarget_one(anchor[0, :, :], label[b, :, :])
batch_offset.append(offset)
batch_mask.append(bbox_mask)
batch_cls_labels.append(cls_labels)
bbox_offset = torch.stack(batch_offset)
bbox_mask = torch.stack(batch_mask)
cls_labels = torch.stack(batch_cls_labels)
return [bbox_offset, bbox_mask, cls_labels]
输出预测边界框(非极大值抑制)
采用非极大值抑制(non-maximum suppression,NMS),移除相似的预测边界框。
from collections import namedtuple
Pred_BB_Info = namedtuple("Pred_BB_Info", ["index", "class_id", "confidence", "xyxy"])
def non_max_suppression(bb_info_list, nms_threshold = 0.5):
"""
非极大抑制处理预测的边界框
Args:
bb_info_list: Pred_BB_Info的列表, 包含预测类别、置信度等信息
nms_threshold: 阈值
Returns:
output: Pred_BB_Info的列表, 只保留过滤后的边界框信息
"""
output = []
# 先根据置信度从高到低排序
sorted_bb_info_list = sorted(bb_info_list, key = lambda x: x.confidence, reverse=True)
# 循环遍历删除冗余输出
while len(sorted_bb_info_list) != 0:
best = sorted_bb_info_list.pop(0)
output.append(best)
if len(sorted_bb_info_list) == 0:
break
bb_xyxy = []
for bb in sorted_bb_info_list:
bb_xyxy.append(bb.xyxy)
iou = compute_jaccard(torch.tensor([best.xyxy]),
torch.tensor(bb_xyxy))[0] # shape: (len(sorted_bb_info_list), )
n = len(sorted_bb_info_list)
sorted_bb_info_list = [sorted_bb_info_list[i] for i in range(n) if iou[i] <= nms_threshold]
return output
def MultiBoxDetection(cls_prob, loc_pred, anchor, nms_threshold = 0.5):
"""
# 按照「9.4.1. 生成多个锚框」所讲的实现, anchor表示成归一化(xmin, ymin, xmax, ymax).
Args:
cls_prob: 经过softmax后得到的各个锚框的预测概率, shape:(bn, 预测总类别数+1, 锚框个数)
loc_pred: 预测的各个锚框的偏移量, shape:(bn, 锚框个数*4)
anchor: MultiBoxPrior输出的默认锚框, shape: (1, 锚框个数, 4)
nms_threshold: 非极大抑制中的阈值
Returns:
所有锚框的信息, shape: (bn, 锚框个数, 6)
每个锚框信息由[class_id, confidence, xmin, ymin, xmax, ymax]表示
class_id=-1 表示背景或在非极大值抑制中被移除了
"""
assert len(cls_prob.shape) == 3 and len(loc_pred.shape) == 2 and len(anchor.shape) == 3
bn = cls_prob.shape[0]
def MultiBoxDetection_one(c_p, l_p, anc, nms_threshold = 0.5):
"""
MultiBoxDetection的辅助函数, 处理batch中的一个
Args:
c_p: (预测总类别数+1, 锚框个数)
l_p: (锚框个数*4, )
anc: (锚框个数, 4)
nms_threshold: 非极大抑制中的阈值
Return:
output: (锚框个数, 6)
"""
pred_bb_num = c_p.shape[1] # 锚框个数
anc = (anc + l_p.view(pred_bb_num, 4)).detach().cpu().numpy() # 加上偏移量
confidence, class_id = torch.max(c_p, 0) # 最大置信,最大置信对应类
confidence = confidence.detach().cpu().numpy()
class_id = class_id.detach().cpu().numpy()
pred_bb_info = [Pred_BB_Info(
index = i, # 第i个
class_id = class_id[i] - 1, # 正类label从0开始,第i类
confidence = confidence[i], # 置信度
xyxy=[*anc[i]]) # xyxy是个列表 # anc+偏移量
for i in range(pred_bb_num)]
# 正类的index
obj_bb_idx = [bb.index for bb in non_max_suppression(pred_bb_info, nms_threshold)] # 非极大值抑制
output = []
for bb in pred_bb_info:
output.append([
(bb.class_id if bb.index in obj_bb_idx else -1.0),
bb.confidence,
*bb.xyxy
])
return torch.tensor(output) # shape: (锚框个数, 6)
batch_output = []
for b in range(bn):
batch_output.append(MultiBoxDetection_one(cls_prob[b], loc_pred[b], anchor[0], nms_threshold))
return torch.stack(batch_output)
多尺度目标检测
减少锚框个数并不难。一种简单的方法是在输入图像中均匀采样一小部分像素,并以采样的像素为中心生成锚框。此外,在不同尺度下,我们可以生成不同数量和不同大小的锚框。
当使用较小锚框来检测较小目标时,我们可以采样较多的区域;而当使用较大锚框来检测较大目标时,我们可以采样较少的区域。
def display_anchors(fmap_w, fmap_h, s):
# 前两维的取值不影响输出结果(原书这里是(1, 10, fmap_w, fmap_h), 我认为错了)
fmap = torch.zeros((1, 10, fmap_h, fmap_w), dtype=torch.float32)
# 平移所有锚框使均匀分布在图片上
offset_x, offset_y = 1.0/fmap_w, 1.0/fmap_h
anchors = d2l.MultiBoxPrior(fmap, sizes=s, ratios=[1, 2, 0.5]) + \
torch.tensor([offset_x/2, offset_y/2, offset_x/2, offset_y/2]) # 可能锚框loc加上偏移量
bbox_scale = torch.tensor([[w, h, w, h]], dtype=torch.float32) # 缩放至正常尺寸
d2l.show_bboxes(d2l.plt.imshow(img).axes,
anchors[0] * bbox_scale)
display_anchors(fmap_w=4, fmap_h=2, s=[0.15])
图像风格迁移
方法
1.首先,我们初始化合成图像,例如将其初始化成内容图像。该合成图像是样式迁移过程中唯一需要更新的变量,即样式迁移所需迭代的模型参数。
2.然后,我们选择一个预训练的卷积神经网络来抽取图像的特征,其中的模型参数在训练中无须更新。深度卷积神经网络凭借多个层逐级抽取图像的特征。我们可以选择其中某些层的输出作为内容特征或样式特征。以图9.13为例,这里选取的预训练的神经网络含有3个卷积层,其中第二层输出图像的内容特征,而第一层和第三层的输出被作为图像的样式特征。接下来,我们通过正向传播(实线箭头方向)计算样式迁移的损失函数,并通过反向传播(虚线箭头方向)迭代模型参数,即不断更新合成图像。
3.样式迁移常用的损失函数由3部分组成:内容损失(content loss)使合成图像与内容图像在内容特征上接近,样式损失(style loss)令合成图像与样式图像在样式特征上接近,而总变差损失(total variation loss)则有助于减少合成图像中的噪点。
4.最后,当模型训练结束时,我们输出样式迁移的模型参数,即得到最终的合成图像。
预处理和后处理图像
rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])
def preprocess(PIL_img, image_shape):
process = torchvision.transforms.Compose([
torchvision.transforms.Resize(image_shape),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
return process(PIL_img).unsqueeze(dim = 0) # (batch_size, 3, H, W)
def postprocess(img_tensor):
inv_normalize = torchvision.transforms.Normalize(
mean= -rgb_mean / rgb_std,
std= 1/rgb_std)
to_PIL_image = torchvision.transforms.ToPILImage()
return to_PIL_image(inv_normalize(img_tensor[0].cpu()).clamp(0, 1))
抽取特征
# 加载预训练模型
pretrained_net = torchvision.models.vgg19(pretrained=False)
pretrained_net.load_state_dict(torch.load('/home/kesci/input/vgg193427/vgg19-dcbb9e9d.pth'))
net_list = []
for i in range(max(content_layers + style_layers) + 1):
net_list.append(pretrained_net.features[i])
net = torch.nn.Sequential(*net_list) # 提取预训练模型的每一层
def extract_features(X, content_layers, style_layers):
contents = []
styles = []
for i in range(len(net)):
X = net[i](X)
if i in style_layers:
styles.append(X)
if i in content_layers:
contents.append(X)
return contents, styles
def get_contents(image_shape, device):
content_X = preprocess(content_img, image_shape).to(device)
contents_Y, _ = extract_features(content_X, content_layers, style_layers)
return content_X, contents_Y
def get_styles(image_shape, device):
style_X = preprocess(style_img, image_shape).to(device)
_, styles_Y = extract_features(style_X, content_layers, style_layers)
return style_X, styles_Y
定义损失函数
1.内容损失
通过平方误差函数衡量合成图像与内容图像在内容特征上的差异
def content_loss(Y_hat, Y):
return F.mse_loss(Y_hat, Y)
2.样式损失
样式损失为两个格拉姆矩阵的平方误差
假设该输出的样本数为1,通道数为
c
c
c,高和宽分别为
h
h
h和
w
w
w,我们可以把输出变换成
c
c
c行
h
w
hw
hw列的矩阵
X
\boldsymbol{X}
X。矩阵
X
\boldsymbol{X}
X可以看作是由
c
c
c个长度为
h
w
hw
hw的向量
x
1
,
…
,
x
c
\boldsymbol{x}_1, \ldots, \boldsymbol{x}_c
x1,…,xc组成的。其中向量
x
i
\boldsymbol{x}_i
xi代表了通道
i
i
i上的样式特征。这些向量的格拉姆矩阵(Gram matrix)
X
X
⊤
∈
R
c
×
c
\boldsymbol{X}\boldsymbol{X}^\top \in \mathbb{R}^{c \times c}
XX⊤∈Rc×c中
i
i
i行
j
j
j列的元素
x
i
j
x_{ij}
xij即向量
i
i
i与
j
j
j的内积,它表达了通道和通道上样式特征的相关性。我们用这样的格拉姆矩阵表达样式层输出的样式。
需要注意的是,当
h
w
hw
hw的值较大时,格拉姆矩阵中的元素容易出现较大的值。此外,格拉姆矩阵的高和宽皆为通道数
c
c
c。为了让样式损失不受这些值的大小影响,下面定义的gram函数将格拉姆矩阵除以了矩阵中元素的个数,即
c
h
w
chw
chw。
def gram(X):
num_channels, n = X.shape[1], X.shape[2] * X.shape[3]
X = X.view(num_channels, n) # 变换为(c,hw)
return torch.matmul(X, X.t()) / (num_channels * n)
def style_loss(Y_hat, gram_Y):
return F.mse_loss(gram(Y_hat), gram_Y)
3.总变差损失
有时候,我们学到的合成图像里面有大量高频噪点,即有特别亮或者特别暗的颗粒像素。一种常用的降噪方法是总变差降噪(total variation denoising)。假设
x
i
,
j
x_{i,j}
xi,j表示坐标为的像素值
(
i
,
j
)
(i,j)
(i,j),降低总变差损失
∑
i
,
j
∣
x
i
,
j
−
x
i
+
1
,
j
∣
+
∣
x
i
,
j
−
x
i
,
j
+
1
∣
\sum_{i,j} \left|x_{i,j} - x_{i+1,j}\right| + \left|x_{i,j} - x_{i,j+1}\right|
i,j∑∣xi,j−xi+1,j∣+∣xi,j−xi,j+1∣
能够尽可能使邻近的像素值相似。
def tv_loss(Y_hat):
return 0.5 * (F.l1_loss(Y_hat[:, :, 1:, :], Y_hat[:, :, :-1, :]) +
F.l1_loss(Y_hat[:, :, :, 1:], Y_hat[:, :, :, :-1]))
损失函数
样式迁移的损失函数即内容损失、样式损失和总变差损失的加权和。通过调节这些权值超参数,我们可以权衡合成图像在保留内容、迁移样式以及降噪三方面的相对重要性。
content_weight, style_weight, tv_weight = 1, 1e3, 10
def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
# 分别计算内容损失、样式损失和总变差损失
contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
contents_Y_hat, contents_Y)]
styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
styles_Y_hat, styles_Y_gram)]
tv_l = tv_loss(X) * tv_weight
# 对所有损失求和
l = sum(styles_l) + sum(contents_l) + tv_l
return contents_l, styles_l, tv_l, l
创建和初始化合成图像
class GeneratedImage(torch.nn.Module):
def __init__(self, img_shape):
super(GeneratedImage, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(*img_shape))
def forward(self):
return self.weight
def get_inits(X, device, lr, styles_Y):
gen_img = GeneratedImage(X.shape).to(device)
gen_img.weight.data = X.data
optimizer = torch.optim.Adam(gen_img.parameters(), lr=lr)
styles_Y_gram = [gram(Y) for Y in styles_Y]
return gen_img(), styles_Y_gram, optimizer
训练
image_shape = (150, 225)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
style_X, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.01, 500, 200)