
PyTorch練習 03日目

MorvanZhou/PyTorch-Tutorialを参考にMNISTの数字判別MLPを組んだ (元ネタはCNN).

import torch 
import torchvision
import torch.nn as nn
import numpy as np
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F

# --- Hyperparameters -------------------------------------------------------
LR = 0.001        # Adam learning rate
BATCH_SIZE = 50   # mini-batch size (the training loop previously hard-coded 50)
EPOCH = 3         # number of passes over the training set

# Convert PIL images to float tensors scaled to [0, 1].
# Bound to `transform` rather than `transforms` so the torchvision.transforms
# module alias is not shadowed. Full module path used on purpose.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

train_data = torchvision.datasets.MNIST(
    root='./', train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(
    dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# Test set: no transform is applied here; the raw pixel tensor is read
# directly instead. `.data` / `.targets` replace the deprecated
# `test_data` / `test_labels` attributes. Only the first 2000 samples are
# used, each flattened to a 28*28 = 784 vector and scaled to [0, 1] so it
# matches what ToTensor() produces for the training data.
# (No Variable/volatile needed: gradients are disabled at eval time instead.)
test_data = torchvision.datasets.MNIST(root='./', train=False, download=True)
test_x = test_data.data[:2000].view(-1, 28 * 28).float() / 255
test_y = test_data.targets[:2000]

# Three-layer MLP mapping a flattened 28*28 image to 10 class logits:
# 784 -> 100 -> 30 -> 10.
# The ReLU activations are essential: without a nonlinearity between them,
# stacked Linear layers compose into a single linear map.
mlp = torch.nn.Sequential(
    torch.nn.Linear(28 * 28, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 30),
    torch.nn.ReLU(),
    torch.nn.Linear(30, 10),
)

optimizer = torch.optim.Adam(mlp.parameters(), lr=LR)
# CrossEntropyLoss applies log-softmax itself, so the model emits raw logits.
loss_fn = nn.CrossEntropyLoss()

# Train for EPOCH epochs; every 100 steps report the current mini-batch loss
# and the accuracy on the 2000-sample test subset.
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        # Flatten each image to a 784-vector. -1 lets the batch dimension be
        # inferred, so a final batch smaller than BATCH_SIZE still works.
        # (Assumes x is (batch, 1, 28, 28) as produced by ToTensor().)
        b_x = x.view(-1, 28 * 28)
        b_y = y

        output = mlp(b_x)
        loss = loss_fn(output, b_y)

        # Clear stale gradients, backpropagate, and update the weights.
        # Without these three calls the model never learns.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            # Evaluation needs no autograd graph.
            with torch.no_grad():
                test_output = mlp(test_x)
            pred_y = torch.max(test_output, 1)[1]
            # Vectorized count of correct predictions.
            accuracy = (pred_y == test_y).sum().item() / float(test_y.size(0))
            # .item() extracts the Python float; loss.data[0] fails on a
            # 0-dim tensor in current PyTorch.
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(),
                  '| test accuracy: %.2f' % accuracy)
  (layer1): Sequential (
    (0): Linear (784 -> 100)
    (1): ReLU ()
    (2): Linear (100 -> 30)
    (3): ReLU ()
    (4): Linear (30 -> 10)
Epoch:  0 | train loss: 2.3134 | test accuracy: 0.09
Epoch:  0 | train loss: 0.7347 | test accuracy: 0.81
Epoch:  0 | train loss: 0.4593 | test accuracy: 0.85
Epoch:  0 | train loss: 0.4268 | test accuracy: 0.88
Epoch:  0 | train loss: 0.1728 | test accuracy: 0.89
Epoch:  0 | train loss: 0.2370 | test accuracy: 0.89
Epoch:  0 | train loss: 0.0994 | test accuracy: 0.90
Epoch:  0 | train loss: 0.2626 | test accuracy: 0.90
Epoch:  0 | train loss: 0.1483 | test accuracy: 0.91
Epoch:  0 | train loss: 0.2041 | test accuracy: 0.92
Epoch:  0 | train loss: 0.1486 | test accuracy: 0.91
Epoch:  0 | train loss: 0.2538 | test accuracy: 0.92
Epoch:  1 | train loss: 0.0768 | test accuracy: 0.93
Epoch:  1 | train loss: 0.1138 | test accuracy: 0.93
Epoch:  1 | train loss: 0.1675 | test accuracy: 0.93
Epoch:  1 | train loss: 0.0724 | test accuracy: 0.93
Epoch:  1 | train loss: 0.0983 | test accuracy: 0.94
Epoch:  1 | train loss: 0.1681 | test accuracy: 0.93
Epoch:  1 | train loss: 0.1569 | test accuracy: 0.94
Epoch:  1 | train loss: 0.2666 | test accuracy: 0.94
Epoch:  1 | train loss: 0.1030 | test accuracy: 0.93
Epoch:  1 | train loss: 0.1784 | test accuracy: 0.93
Epoch:  1 | train loss: 0.2013 | test accuracy: 0.95
Epoch:  1 | train loss: 0.1681 | test accuracy: 0.95
Epoch:  2 | train loss: 0.0487 | test accuracy: 0.95
Epoch:  2 | train loss: 0.0959 | test accuracy: 0.95
Epoch:  2 | train loss: 0.1366 | test accuracy: 0.95
Epoch:  2 | train loss: 0.1528 | test accuracy: 0.95
Epoch:  2 | train loss: 0.0860 | test accuracy: 0.95
Epoch:  2 | train loss: 0.0218 | test accuracy: 0.95
Epoch:  2 | train loss: 0.1122 | test accuracy: 0.95
Epoch:  2 | train loss: 0.1109 | test accuracy: 0.96
Epoch:  2 | train loss: 0.0879 | test accuracy: 0.96
Epoch:  2 | train loss: 0.1182 | test accuracy: 0.96
Epoch:  2 | train loss: 0.0585 | test accuracy: 0.96
Epoch:  2 | train loss: 0.0579 | test accuracy: 0.95


# Demonstrate reshaping with Tensor.view.
A = torch.arange(0, 5 * 4 * 3 * 2)   # 1-D tensor of length 120
A_4d = A.view(5, 4, 3, 2)            # 4-D tensor of size (5, 4, 3, 2)
A_2d = A.view(5, -1)                 # 2-D tensor of size (5, 24)
# Passing -1 for one dimension lets view infer that size from the others.
print(A_4d.size(), A_2d.size())

# 120 is not divisible by 7, so no size can be inferred for the -1 dimension
# and view raises RuntimeError. Catch it so the demo runs to completion.
try:
    A.view(7, -1)
except RuntimeError as err:
    # The message text varies by PyTorch version; older releases printed e.g.
    # invalid argument 2: size '[7 x -1]' is invalid for input of with 120 elements
    print('RuntimeError:', err)

