To test whether a ViT network can handle spatiotemporal prediction, I wrote the following code to verify it.
Experiment 1
input = i
target = i
The idea is to have the ViT network do nothing: the output should simply reproduce the input.
Result: the loss keeps growing.
The code is as follows:
import torch
from tqdm import trange
from vit_model import vit_custom_in21k as create_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_model().to(device)
print('use :', device)

loss_fn = torch.nn.L1Loss().to(device)
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for e in range(1, 100):
    with trange(-100000, 100000, 100) as t:
        for i in t:
            optimizer.zero_grad()
            # Batch of 100 scalars i .. i+99, shaped (100, 1, 1, 1)
            input1 = torch.tensor([[[[j]]] for j in range(i, i + 100)]).float()
            # Identity target: the same values, shaped (100, 1)
            target = input1.view(100, 1)
            outputs = model(input1.to(device))
            loss = loss_fn(outputs, target.to(device))
            loss.backward()
            optimizer.step()
            t.set_postfix(loss=loss.item(), input1=input1[0].item(),
                          target=target[0].item(), outputs=outputs[0].item())
torch.save(model, 'test_model_8m_8d.pht')
The output is shown in the figure below:
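One factor worth checking here (my assumption, not something tested in this post) is input scale: the loop feeds raw values up to ±100000 into the network, and unnormalized inputs of that magnitude commonly destabilize training. A minimal sketch of per-batch normalization that could be applied before the forward pass:

```python
import torch

def normalize_batch(x: torch.Tensor):
    """Scale a batch to zero mean / unit variance and return the stats
    needed to undo the scaling on the model's output."""
    mean, std = x.mean(), x.std()
    return (x - mean) / std, mean, std

# Same input construction as in the training loop above.
i = -100000
input1 = torch.tensor([[[[j]]] for j in range(i, i + 100)]).float()
normed, mean, std = normalize_batch(input1)
# The network now sees values on the order of 1 instead of 1e5;
# predictions are mapped back to the original scale with outputs * std + mean.
```

Whether this alone fixes the diverging loss is untested; it is simply the first thing to rule out before blaming the architecture.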
Experiment 2
A spatiotemporal prediction setup.
The input has shape (1, 10, 10, 10): ten 10*10 tensors are stacked, each filled with a single constant, the constants running from i to i+9.
The target is ten values, from i+5 to i+14 (the input sequence shifted five steps forward).
The loss does not decrease at all; it keeps growing.
The code is as follows:
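The input and target construction described above can be checked in isolation (a standalone sketch that does not depend on vit_model):

```python
import torch

i = 0
# Ten constant 10x10 maps with values i .. i+9, stacked into (1, 10, 10, 10):
data = [torch.full([10, 10], float(j)) for j in range(i, i + 10)]
input1 = torch.stack(data).unsqueeze(0)
# Ten target values i+5 .. i+14, i.e. the constants shifted five steps forward:
target = torch.tensor([float(j) for j in range(i + 5, i + 15)])
print(input1.shape, target.shape)  # torch.Size([1, 10, 10, 10]) torch.Size([10])
```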
import torch
from vit_model import VisionTransformer


def vit_custom_in21k(num_classes: int = 10, has_logits: bool = True):
    """
    Custom ViT configuration, adapted from the ViT-B/16 factory of the
    original paper (https://arxiv.org/abs/2010.11929): a 10x10 "image"
    with 10 input channels, 1x1 patches, and a single attention head.
    """
    model = VisionTransformer(img_size=10,
                              patch_size=1,
                              in_c=10,
                              embed_dim=768,
                              depth=12,
                              num_heads=1,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model
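For reference, most ViT implementations realize the patch embedding as a Conv2d whose kernel and stride equal patch_size; assuming vit_model follows that convention (its internals are not shown here), this configuration turns each (10, 10, 10) input into 100 tokens of dimension 768, plus the class token added inside VisionTransformer. A standalone sketch of the shape arithmetic:

```python
import torch
from torch import nn

# Typical ViT patch embedding: Conv2d with kernel = stride = patch_size.
# With img_size=10, patch_size=1, in_c=10, embed_dim=768 this yields
# 10 * 10 = 100 tokens of dimension 768 per sample.
patch_embed = nn.Conv2d(10, 768, kernel_size=1, stride=1)
x = torch.randn(1, 10, 10, 10)              # (batch, in_c, H, W)
tokens = patch_embed(x).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 100, 768])
```

With 1x1 patches every "pixel" becomes its own token, so the transformer sees 100 identical-valued tokens per constant map; this is a design choice worth questioning for this task.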
def train1():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = vit_custom_in21k().to(device)
    print('use :', device)

    loss_fn = torch.nn.L1Loss().to(device)
    learning_rate = 0.001
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for i in range(0, 10000):
        optimizer.zero_grad()
        # Ten constant 10x10 maps with values i .. i+9, shaped (1, 10, 10, 10)
        data = [torch.full([10, 10], j) for j in range(i, i + 10)]
        input1 = torch.stack(data).unsqueeze(0).float()
        # Target: the ten constants shifted five steps forward, i+5 .. i+14
        target = torch.tensor([j for j in range(i + 5, i + 15)]).float()
        outputs = model(input1.to(device))
        target = target.view(outputs.shape)
        loss = loss_fn(outputs, target.to(device))
        loss.backward()
        optimizer.step()
        print('\r', target[0][0].item(), 'outputs:', outputs[0][0].item(),
              'loss', loss.item(), end='')
    torch.save(model, 'test_model_8m_8d.pht')


if __name__ == '__main__':
    train1()
The experimental results are as follows:
Conclusion
ViT (Vision Transformer) shows weak regression ability in these tests.
At least judging by these results, its predictions are poor even on very simple data.