MorvanZhou/PyTorch-Tutorialを参考にMNISTの数字判別MLPを組んだ (元ネタはCNN).
import torch
import torchvision
import torch.nn as nn
import numpy as np
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
# Training hyperparameters, bound in a single unpacking assignment.
EPOCH, BATCH_SIZE, LR = (
    3,      # number of passes over the training DataLoader
    50,     # mini-batch size handed to the DataLoader
    0.001,  # learning rate for the Adam optimizer
)
# Convert PIL images to float tensors scaled to [0, 1].
# Named `transform` (singular) so it no longer shadows the
# `torchvision.transforms` module that is imported as `transforms`.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

train_data = torchvision.datasets.MNIST(
    root='./', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# Held-out evaluation slice: the first 2000 test images, flattened to
# 28*28 = 784 features and scaled to [0, 1] to match ToTensor above.
# `.data` / `.targets` replace the deprecated `.test_data` / `.test_labels`;
# slicing before `.float()` avoids converting all 10000 images. No Variable
# wrapper is needed (PyTorch >= 0.4); gradient tracking is disabled at the
# evaluation site instead of via the removed `volatile` flag.
test_data = torchvision.datasets.MNIST(root='./', train=False, download=True)
test_x = test_data.data[:2000].view(-1, 28 * 28).float() / 255
test_y = test_data.targets[:2000]
# Fully-connected classifier: 784 -> 100 -> 30 -> 10 with ReLU between
# the linear layers; the 10 outputs are raw logits for CrossEntropyLoss.
mlp = nn.Sequential(
    nn.Linear(28 * 28, 100), nn.ReLU(),
    nn.Linear(100, 30), nn.ReLU(),
    nn.Linear(30, 10),
)
print(mlp)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=LR)
# Train for EPOCH passes; every 100 steps report the current mini-batch
# loss and accuracy on the held-out (test_x, test_y) slice.
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        # Flatten each image to a 28*28 = 784 vector. Using the batch's own
        # size instead of a hard-coded 50 keeps this correct if BATCH_SIZE
        # changes or the final batch is smaller.
        b_x = x.view(x.size(0), 28 * 28)

        output = mlp(b_x)
        loss = loss_fn(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            # Evaluation only: skip building an autograd graph.
            with torch.no_grad():
                test_output = mlp(test_x)
            pred_y = torch.max(test_output, 1)[1].squeeze()
            # Vectorised mean instead of Python's builtin sum over elements;
            # .item() yields a plain float for formatting.
            accuracy = (pred_y == test_y).float().mean().item()
            # loss.item() replaces loss.data[0], which fails on 0-dim tensors
            # in PyTorch >= 0.4.
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(),
                  '| test accuracy: %.2f' % accuracy)
MLP (
(layer1): Sequential (
(0): Linear (784 -> 100)
(1): ReLU ()
(2): Linear (100 -> 30)
(3): ReLU ()
(4): Linear (30 -> 10)
)
)
Epoch: 0 | train loss: 2.3134 | test accuracy: 0.09
Epoch: 0 | train loss: 0.7347 | test accuracy: 0.81
Epoch: 0 | train loss: 0.4593 | test accuracy: 0.85
Epoch: 0 | train loss: 0.4268 | test accuracy: 0.88
Epoch: 0 | train loss: 0.1728 | test accuracy: 0.89
Epoch: 0 | train loss: 0.2370 | test accuracy: 0.89
Epoch: 0 | train loss: 0.0994 | test accuracy: 0.90
Epoch: 0 | train loss: 0.2626 | test accuracy: 0.90
Epoch: 0 | train loss: 0.1483 | test accuracy: 0.91
Epoch: 0 | train loss: 0.2041 | test accuracy: 0.92
Epoch: 0 | train loss: 0.1486 | test accuracy: 0.91
Epoch: 0 | train loss: 0.2538 | test accuracy: 0.92
Epoch: 1 | train loss: 0.0768 | test accuracy: 0.93
Epoch: 1 | train loss: 0.1138 | test accuracy: 0.93
Epoch: 1 | train loss: 0.1675 | test accuracy: 0.93
Epoch: 1 | train loss: 0.0724 | test accuracy: 0.93
Epoch: 1 | train loss: 0.0983 | test accuracy: 0.94
Epoch: 1 | train loss: 0.1681 | test accuracy: 0.93
Epoch: 1 | train loss: 0.1569 | test accuracy: 0.94
Epoch: 1 | train loss: 0.2666 | test accuracy: 0.94
Epoch: 1 | train loss: 0.1030 | test accuracy: 0.93
Epoch: 1 | train loss: 0.1784 | test accuracy: 0.93
Epoch: 1 | train loss: 0.2013 | test accuracy: 0.95
Epoch: 1 | train loss: 0.1681 | test accuracy: 0.95
Epoch: 2 | train loss: 0.0487 | test accuracy: 0.95
Epoch: 2 | train loss: 0.0959 | test accuracy: 0.95
Epoch: 2 | train loss: 0.1366 | test accuracy: 0.95
Epoch: 2 | train loss: 0.1528 | test accuracy: 0.95
Epoch: 2 | train loss: 0.0860 | test accuracy: 0.95
Epoch: 2 | train loss: 0.0218 | test accuracy: 0.95
Epoch: 2 | train loss: 0.1122 | test accuracy: 0.95
Epoch: 2 | train loss: 0.1109 | test accuracy: 0.96
Epoch: 2 | train loss: 0.0879 | test accuracy: 0.96
Epoch: 2 | train loss: 0.1182 | test accuracy: 0.96
Epoch: 2 | train loss: 0.0585 | test accuracy: 0.96
Epoch: 2 | train loss: 0.0579 | test accuracy: 0.95
DataLoaderの使い方がまだわかっていないので引き続き練習する.
また,numpy.ndarray.reshapeと似た機能をもつtorch.Tensor.viewについていくつかメモしておく.
この変換(28×28の画像を28*28=784次元のベクトルに平坦化する処理)は,torchvision.datasets.MNISTのtransformの段階で行ったほうが効率がいいと思うが,どうやればいいのかわからなかった.
# Notes on torch.Tensor.view (similar to numpy.ndarray.reshape).
A = torch.arange(0, 5 * 4 * 3 * 2)  # 1-D tensor of length 120
print(A.view(5, 4, 3, 2).size())  # 4-D tensor of size (5, 4, 3, 2)
print(A.view(5, -1).size())  # 2-D tensor of size (5, 24)
# Passing -1 for one dimension makes view infer that dimension's size
# from the total element count and the other dimensions.
try:
    # 120 elements cannot be split into 7 equal rows, so -1 is unresolvable.
    A.view(7, -1)
except RuntimeError as err:
    # Caught so the demo keeps running. Message seen on old PyTorch:
    # RuntimeError: invalid argument 2: size '[7 x -1]' is invalid for input of with 120 elements at /pytorch/torch/lib/TH/THStorage.c:37
    print('RuntimeError:', err)
0 件のコメント:
コメントを投稿