1. Set the configuration, including essential hyperparameters; you can use easydict for this part
from easydict import EasyDict as edict
conf = edict()
conf.batch_size = xxx
conf.model_path = xxx
conf.logs = xxx
conf.image_path = xxx
.....
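Filled in with illustrative values (the batch size and paths below are assumptions, not values from the original), a minimal runnable sketch of such a config:

```python
from easydict import EasyDict as edict  # third-party package: pip install easydict

conf = edict()
conf.batch_size = 64                   # illustrative value
conf.model_path = './checkpoints'      # hypothetical paths
conf.logs = './logs'
conf.image_path = './data/images'

# attribute access and dict-style access are interchangeable
print(conf.batch_size, conf['batch_size'])
```

EasyDict's appeal here is purely ergonomic: `conf.batch_size` reads better than `conf['batch_size']` when the config is referenced throughout a training script.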
2. Prepare the data
In this step, you first need to define the transforms, which handle the preprocessing
from torchvision import transforms as trans
train_transform = trans.Compose([
    trans.RandomHorizontalFlip(),
    trans.Resize((112,112)),
    trans.ToTensor(),
    trans.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
])
Use ImageFolder to read the data; ImageFolder is a subclass of torch.utils.data.Dataset.
If you want to define a Dataset by yourself, you have to override '__getitem__()' and '__len__()'
from torchvision.datasets import ImageFolder
data = ImageFolder(data_path,train_transform)
Then use DataLoader to get an iterator over the train dataset, which yields mini-batches:
from torch.utils.data import DataLoader
loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
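For the custom-Dataset case mentioned above, a minimal sketch with the two required methods; random tensors stand in for real images, and the class name and shapes are illustrative:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class RandomImageDataset(Dataset):
    """Hypothetical dataset: random 3x112x112 'images' with integer labels."""
    def __init__(self, num_samples=8, num_classes=2):
        self.data = torch.randn(num_samples, 3, 112, 112)
        self.labels = torch.randint(0, num_classes, (num_samples,))

    def __getitem__(self, idx):
        # return one (sample, label) pair for index idx
        return self.data[idx], self.labels[idx]

    def __len__(self):
        # DataLoader uses this to know how many samples exist
        return len(self.data)

dataset = RandomImageDataset()
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([4, 3, 112, 112])
```

DataLoader batches the per-sample outputs of `__getitem__` automatically, which is why each iteration yields a whole mini-batch.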
3. Build the model
Write a class, which inherits from torch.nn.Module.
You must override the forward function.
The class's __init__ function defines the layers and the forward function defines the structure. Take an example:
import torch
from torch.nn import Module

class net(Module):
    def __init__(self, *hyperparameters):
        super(net, self).__init__()
        self.conv1 = torch.nn.Conv2d(xxx)
        ...

    def forward(self, x):
        out = self.conv1(x)
        ...
        return out
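A concrete, runnable version of the skeleton above; the class name, layer sizes, and class count are made up for illustration:

```python
import torch
from torch import nn

class SmallNet(nn.Module):  # hypothetical example network
    def __init__(self, num_classes=10):
        super().__init__()
        # layers are declared in __init__ so their parameters are registered
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        # forward defines how data flows through the declared layers
        out = torch.relu(self.conv1(x))
        out = self.pool(out).flatten(1)
        return self.fc(out)

model = SmallNet()
out = model(torch.randn(2, 3, 112, 112))  # batch of 2 random 112x112 images
print(out.shape)  # torch.Size([2, 10])
```

Calling `model(x)` invokes `forward` indirectly through `Module.__call__`, which is why you override `forward` but never call it by name.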
4. set learning rate schedule, global_step, optimizer
optimizer = torch.optim.SGD(model.parameters(),lr = 0.001,momentum = 0.9)
step = 0
milestones = [10,30,50] # epochs at which the lr will drop
If you want to apply a different lr to different layers, you can replace 'model.parameters()' with a list of dicts:
torch.optim.SGD([
    {'params': param_1, 'weight_decay': 4e-4},
    {'params': param_2, 'weight_decay': 4e-5},
    {'params': param_3}
],
    lr=0.001,
    momentum=0.9)
If you want to change the lr manually, you can do this:
for param_group in optimizer.param_groups:
    param_group['lr'] /= 10
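Instead of dividing the lr by hand at each milestone, PyTorch's built-in MultiStepLR scheduler applies the same schedule automatically; the placeholder model and epoch count below are illustrative:

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import MultiStepLR

model = nn.Linear(4, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = MultiStepLR(optimizer, milestones=[10, 30, 50], gamma=0.1)

for epoch in range(12):
    # ... one epoch of training, then:
    optimizer.step()       # normally called once per batch
    scheduler.step()       # advance the schedule once per epoch

# after passing milestone 10, the lr has dropped from 1e-3 to 1e-4
print(optimizer.param_groups[0]['lr'])
```

Each time a milestone epoch is passed, the lr is multiplied by `gamma`, which reproduces the `/= 10` pattern above with `gamma=0.1`.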
5. Calculate the loss function
Denote the output of the model as logits and the train data label as label:
loss = torch.nn.CrossEntropyLoss()(logits, label)
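With small hand-made tensors (the values are illustrative), the loss computation looks like this; it is also common to instantiate the criterion once and reuse it rather than constructing it on every call:

```python
import torch

criterion = torch.nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 0.5], [0.1, 1.5]])  # model output: 2 samples, 2 classes
label = torch.tensor([0, 1])                     # ground-truth class indices
loss = criterion(logits, label)
print(loss.item())  # a positive scalar
```

Note that CrossEntropyLoss expects raw, unnormalized logits and integer class indices; it applies log-softmax internally, so do not softmax the output yourself.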
6. Backpropagation
optimizer.zero_grad()  # clear gradients accumulated from the previous step
loss.backward()
loss_value = loss.item()
optimizer.step()
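Putting steps 4-6 together, one full training step can be sketched as follows; the linear model and random mini-batch are stand-ins for a real network and real data:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)                 # stand-in for a real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 4)                   # random mini-batch of 8 samples
label = torch.randint(0, 2, (8,))       # random labels

optimizer.zero_grad()                   # clear gradients from the previous step
logits = model(x)
loss = criterion(logits, label)
loss.backward()                         # compute gradients via backpropagation
loss_value = loss.item()                # scalar value for logging
optimizer.step()                        # update the parameters
```

Calling zero_grad before backward matters because PyTorch accumulates gradients across backward calls; without it, each step would add to the previous step's gradients.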