import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader
from tf_utils import load_dataset
# Load the SIGNS dataset via the project-local tf_utils helper.
# NOTE(review): assumes X_*_orig are image arrays of shape (m, ...) and
# Y_*_orig are integer label arrays of shape (1, m) -- confirm in tf_utils.
X_train_orig, Y_train_orig, X_test_orig, Y_test_orig, classes = load_dataset()
# Flatten each example into a column: reshape(m, -1).T yields (features, m).
# Dividing by 255 scales values into [0, 1] (presumably 8-bit pixels).
X_train_flatten = X_train_orig.reshape(X_train_orig.shape[0], -1).T / 255
X_test_flatten = X_test_orig.reshape(X_test_orig.shape[0], -1).T / 255
# Convert to torch tensors; inputs must be float32 for nn.Linear layers.
X_train = torch.from_numpy(X_train_flatten).to(torch.float)
Y_train = torch.from_numpy(Y_train_orig)
X_test = torch.from_numpy(X_test_flatten).to(torch.float)
Y_test = torch.from_numpy(Y_test_orig)
class MyDataset(Dataset):
    """Dataset over column-major tensors: X is (features, m), Y is (1, m).

    Sample ``idx`` is the idx-th column of each tensor, so ``__getitem__``
    returns ``(X[:, idx], Y[:, idx])`` — a feature vector and a 1-element
    label tensor.
    """

    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        # Sample count = number of columns in the (1, m) label tensor.
        return len(self.Y[0])

    def __getitem__(self, idx):
        # One column of features paired with one column of labels.
        return self.X[:, idx], self.Y[:, idx]
# Wrap the tensors in Dataset/DataLoader objects for mini-batch iteration.
training_data = MyDataset(X_train, Y_train)
test_data = MyDataset(X_test, Y_test)
# Shuffle only during training; it decorrelates mini-batches between epochs.
train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
# Fix: evaluation order does not affect the aggregated loss/accuracy, so
# shuffling the test loader was wasted work and made runs harder to compare.
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False)
class Model(nn.Module):
    """Fully-connected classifier: N_in -> h1 -> h2 -> N_out with ReLU.

    ``forward`` returns raw logits (no softmax); the softmax is expected to
    be folded into the loss (e.g. nn.CrossEntropyLoss).
    """

    def __init__(self, N_in, h1, h2, N_out):
        super().__init__()
        layers = [
            nn.Linear(N_in, h1),
            nn.ReLU(),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Linear(h2, N_out),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        # (batch, N_in) -> (batch, N_out) logits.
        return self.linear_relu_stack(x)
# Network sizes: 12288 inputs (presumably 64*64*3 flattened images -- TODO
# confirm against load_dataset), two hidden layers, 6 output classes.
N_in, h1, h2, N_out = 12288, 32, 12, 6
model = Model(N_in, h1, h2, N_out)
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: yields ``(X, y)`` with X of shape (batch, features) and
            y of shape (batch, 1) holding integer class labels.
        model: module mapping X to (batch, n_classes) logits.
        loss_fn: loss taking ``(logits, labels)``, e.g. nn.CrossEntropyLoss.
        optimizer: optimizer constructed over ``model.parameters()``.
    """
    # Ensure train-mode behavior (matters if dropout/batchnorm are added).
    model.train()
    # Fix: dropped the unused `size` local and the unused `batch` index.
    for X, y in dataloader:
        # Forward pass; labels squeezed from (batch, 1) to (batch,) as
        # required by CrossEntropyLoss.
        pred = model(X)
        loss = loss_fn(pred, y.squeeze(dim=1))
        # Standard backprop step: clear grads, backprop, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def test_loop(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y.squeeze(dim=1)).item()
correct += (pred.argmax(1) == y.squeeze(dim=1)).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
# Hyperparameters: 1000 passes over the data, Adam at lr=1e-3.
epochs = 1000
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# CrossEntropyLoss consumes raw logits plus integer class labels, so the
# model deliberately has no final softmax layer.
loss_fn = nn.CrossEntropyLoss()
# Driver: one training pass followed by one evaluation pass per epoch.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")