# 参考文献 (References):
# https://coderzcolumn.com/tutorials/artificial-intelligence/pytorch-lightning-eliminate-training-loops
import pytorch_lightning as pl
import torch

# Print the installed PyTorch Lightning and PyTorch versions.
print("PyTorch Lightning Version : {}".format(pl.__version__))
print("PyTorch Version : {}".format(torch.__version__))
from sklearn import datasets
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
class DigitsDataset(Dataset):
    """Dataset exposing one split of scikit-learn's digits data.

    Every instance loads the digits data and performs the same stratified
    80/20 split (``random_state=123``); ``train_or_test`` only selects which
    half ``__len__`` and ``__getitem__`` expose.

    Args:
        train_or_test: ``"train"`` serves the training split; any other
            value serves the test split.
        feat_transform: Callable applied to each feature row on access.
        target_transform: Callable applied to each label on access.
    """

    def __init__(self, train_or_test="train", feat_transform=torch.tensor, target_transform=torch.tensor):
        self.typ = train_or_test
        self.feat_transform = feat_transform
        self.target_transform = target_transform

        features, labels = datasets.load_digits(return_X_y=True)
        split = train_test_split(
            features,
            labels,
            train_size=0.8,
            stratify=labels,
            random_state=123,
        )
        self.X_train, self.X_test, self.Y_train, self.Y_test = split

    def _active_split(self):
        """Return the ``(features, labels)`` pair selected by ``self.typ``."""
        if self.typ == "train":
            return self.X_train, self.Y_train
        return self.X_test, self.Y_test

    def __len__(self):
        """Return the number of samples in the selected split."""
        _, labels = self._active_split()
        return len(labels)

    def __getitem__(self, idx):
        """Return the transformed ``(features, label)`` pair at ``idx``."""
        features, labels = self._active_split()
        return self.feat_transform(features[idx]), self.target_transform(labels[idx])
from torch.utils.data import DataLoader

# Build both splits and wrap each one in a loader yielding batches of 32.
train_dataset = DigitsDataset("train")
test_dataset = DigitsDataset("test")
train_loader = DataLoader(train_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

# Sanity check: print the feature/label shapes of the first batch of each
# loader, then stop iterating that loader.
for loader in (train_loader, test_loader):
    for X_batch, Y_batch in loader:
        print(X_batch.shape, Y_batch.shape)
        break
from torch import nn
from torch.optim import Adam
class DigitsClassifier(pl.LightningModule):
    """Fully connected network mapping 64 features to 10 softmax outputs.

    Layer widths are 64 -> 16 -> 32 -> 10 with ReLU between the linear
    layers and a softmax over the last dimension at the end.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
            nn.Softmax(dim=-1),  # normalise the 10 outputs over the last dim
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, X_batch):
        """Run ``X_batch`` through the network and return its softmax output."""
        return self.model(X_batch)
# Smoke-test the forward pass on a random batch of 50 samples x 64 features.
classifier = DigitsClassifier()
classifier  # bare expression: notebook-style display of the module
dummy_batch = torch.rand(50, 64)
preds = classifier(dummy_batch)
preds.shape  # bare expression: notebook-style display of the output shape
# 完整例子 (Complete example)
from torch import nn
from torch.optim import Adam
class DigitsClassifier(pl.LightningModule):
    """Digits classifier (64 -> 16 -> 32 -> 10) with Lightning step hooks.

    ``self.model`` emits raw logits. ``nn.CrossEntropyLoss`` applies
    log-softmax internally, so the loss is computed directly on those logits;
    feeding it softmax output would normalise twice and flatten gradients.
    ``forward`` and ``predict_step`` apply softmax on top of the logits and
    therefore still return class probabilities.
    """

    def __init__(self):
        super().__init__()
        # No Softmax inside the Sequential: the loss needs unnormalised scores.
        # Softmax holds no parameters, so the state-dict keys are unaffected.
        self.model = nn.Sequential(
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )
        self.softmax = nn.Softmax(dim=-1)
        self.crossentropy_loss = nn.CrossEntropyLoss()

    def forward(self, X_batch):
        """Return class probabilities (softmax over the last dimension)."""
        return self.softmax(self.model(X_batch))

    def _batch_loss(self, batch):
        """Compute the cross-entropy loss of one ``(features, labels)`` batch."""
        X_batch, Y_batch = batch
        # Cast explicitly: features are converted to float32 and labels to
        # int64 before they reach the linear layers and the loss.
        logits = self.model(X_batch.float())
        return self.crossentropy_loss(logits, Y_batch.long())

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for ``batch``."""
        loss_val = self._batch_loss(batch)
        self.log("Train Loss : ", loss_val)
        return loss_val

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss for ``batch``."""
        loss_val = self._batch_loss(batch)
        self.log("Validation Loss : ", loss_val)
        return loss_val

    def test_step(self, batch, batch_idx):
        """Log and return the test loss for ``batch``."""
        loss_val = self._batch_loss(batch)
        self.log("Test Loss : ", loss_val)
        return loss_val

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return class probabilities for the features of a labelled batch."""
        X_batch, _ = batch  # labels are present in the batch but unused here
        return self(X_batch.float())

    def configure_optimizers(self):
        """Return an Adam optimizer over the network parameters (lr=1e-3)."""
        return Adam(self.model.parameters(), lr=1e-3)
from torch.utils.data import DataLoader

train_dataset = DigitsDataset("train")
test_dataset = DigitsDataset("test")

# Reshuffle the training samples every epoch so batches differ between
# epochs. The evaluation loader keeps dataset order, so predictions line up
# with labels read from it later.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=64, num_workers=0)

classifier = DigitsClassifier()

# For reproducible runs, call pl.seed_everything(42, workers=True) before
# building the Trainer and pass deterministic=True to it.
trainer = pl.Trainer(max_epochs=30, log_every_n_steps=20)

# NOTE(review): the test split is also passed as the validation data here,
# so validation and test metrics come from the same samples.
trainer.fit(classifier, train_loader, test_loader)
trainer.validate(classifier, test_loader)
trainer.test(classifier, test_loader)

# predict() yields one tensor per batch: join them along the batch axis,
# then keep the index of the highest-scoring class for each sample.
preds = trainer.predict(classifier, test_loader)
preds = torch.cat(preds)
preds = preds.argmax(dim=1)
preds[:5]  # bare expression: notebook-style peek at the first predictions
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report

# Collect the ground-truth labels batch by batch. This relies on test_loader
# iterating in the same order as when `preds` was produced.
Y_test = torch.cat([y for _, y in test_loader])
Y_test[:5]  # bare expression: notebook-style peek at the first labels

# scikit-learn metrics take (y_true, y_pred). With the arguments swapped the
# report's precision and recall columns trade places and "support" counts
# predictions instead of true labels.
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, preds)))
print("Classification Report : ")
print(classification_report(Y_test, preds))