Implementing an Autoencoder

Below, a fully connected autoencoder is trained on MNIST with PyTorch: the encoder compresses each 28x28 digit down to a 2-dimensional code, the decoder reconstructs the image from that code, and reconstructions plus the test loss are plotted as training progresses.
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# MNIST digits dataset
DOWNLOAD_MNIST = True
BATCH_SIZE = 100

train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,                                   # this is training data
    transform=torchvision.transforms.ToTensor(),  # converts a PIL.Image or numpy.ndarray to a
                                                  # torch.FloatTensor of shape (C x H x W), scaled into [0.0, 1.0]
    download=DOWNLOAD_MNIST,                      # download it if you don't have it
)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# pick 2000 test samples to speed up evaluation; the raw test images are scaled
# into [0, 1] by hand, matching what ToTensor() does for the training set
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000] / 255.  # (2000, 28, 28) -> (2000, 1, 28, 28)
test_y = test_data.targets[:2000]
class Encoder(nn.Module):
    """Compress a flattened 28x28 image (784 values) down to a 2-dimensional code."""
    def __init__(self):
        super(Encoder, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(784, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(),
            nn.Linear(1000, 500),
            nn.BatchNorm1d(500),
            nn.ReLU(),
            nn.Linear(500, 250),
            nn.BatchNorm1d(250),
            nn.ReLU(),
            nn.Linear(250, 2),
        )

    def forward(self, x):
        return self.fc(x)
class Decoder(nn.Module):
    """Reconstruct a flattened 28x28 image from a 2-dimensional code."""
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(2, 250),
            nn.BatchNorm1d(250),
            nn.ReLU(),
            nn.Linear(250, 500),
            nn.BatchNorm1d(500),
            nn.ReLU(),
            nn.Linear(500, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(),
            nn.Linear(1000, 784),
        )

    def forward(self, x):
        return self.fc(x)
class AutoEncoder(nn.Module):
    """Encoder and decoder chained together; returns the reconstruction and the code."""
    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.fc1 = Encoder()
        self.fc2 = Decoder()

    def forward(self, x):
        code = self.fc1(x)
        out = self.fc2(code)
        return out, code
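# (Aside, not in the original script: a quick sanity check of the forward-pass
# shapes with a dummy batch, before any training has happened.)
_check = AutoEncoder()
_check.eval()                           # eval mode so BatchNorm1d uses its running stats
with torch.no_grad():
    _out, _code = _check(torch.randn(8, 784))
print(_out.shape, _code.shape)          # torch.Size([8, 784]) torch.Size([8, 2])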
ae = AutoEncoder().cuda()
print(ae)

optimizer = torch.optim.Adam(ae.parameters(), lr=0.001)
loss_fn = nn.MSELoss()
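# (Aside, not in the original script: nn.MSELoss with its default
# reduction='mean' averages the squared error over every element, i.e. it
# equals ((out - x) ** 2).mean() -- a quick self-check:)
_a, _b = torch.randn(4, 784), torch.randn(4, 784)
assert torch.allclose(nn.MSELoss()(_a, _b), ((_a - _b) ** 2).mean())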
test_losses = []
x = test_x.view(-1, 28*28).cuda()

# show the first four test images before training
fig = plt.figure(figsize=[12, 6])
for i in range(4):
    ax = fig.add_subplot(1, 4, i + 1)
    ax.imshow(x[i].cpu().view(28, 28).numpy(), cmap='gray')
fig.suptitle('original')
plt.show()
ls = [0, 1, 2, 3, 5, 10, 20]  # epochs at which to plot reconstructions
for epoch in range(201):
    # evaluate the reconstruction loss on the held-out test images
    ae.eval()
    with torch.no_grad():
        x = test_x.view(-1, 28*28).cuda()
        out, code = ae(x)
        loss = loss_fn(out, x)
    test_losses.append(loss.item())
    if epoch in ls:
        fig = plt.figure(figsize=[12, 6])
        for i in range(4):
            ax = fig.add_subplot(1, 4, i + 1)
            ax.imshow(out[i].cpu().view(28, 28).numpy(), cmap='gray')
        fig.suptitle("epoch: {}, loss: {}".format(epoch, loss.item()))
        plt.show()

    # one pass over the training set
    ae.train()
    for step, (x, y) in enumerate(train_loader):
        x = x.view(-1, 28*28).cuda()
        out, code = ae(x)
        loss = loss_fn(out, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
plt.plot(test_losses)
plt.title('test_losses')
plt.show()
Output of print(ae):
AutoEncoder (
  (fc1): Encoder (
    (fc): Sequential (
      (0): Linear (784 -> 1000)
      (1): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True)
      (2): ReLU ()
      (3): Linear (1000 -> 500)
      (4): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True)
      (5): ReLU ()
      (6): Linear (500 -> 250)
      (7): BatchNorm1d(250, eps=1e-05, momentum=0.1, affine=True)
      (8): ReLU ()
      (9): Linear (250 -> 2)
    )
  )
  (fc2): Decoder (
    (fc): Sequential (
      (0): Linear (2 -> 250)
      (1): BatchNorm1d(250, eps=1e-05, momentum=0.1, affine=True)
      (2): ReLU ()
      (3): Linear (250 -> 500)
      (4): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True)
      (5): ReLU ()
      (6): Linear (500 -> 1000)
      (7): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True)
      (8): ReLU ()
      (9): Linear (1000 -> 784)
    )
  )
)
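Since the bottleneck is only two-dimensional, the learned codes can also be inspected directly as a scatter plot colored by digit label (`test_y`, loaded above but otherwise unused, provides the labels). A minimal sketch, assuming the script above has finished and `ae`, `test_x`, and `test_y` are still in scope:

# scatter-plot the 2-D latent codes of the 2000 test images
ae.eval()
with torch.no_grad():
    _, code = ae(test_x.view(-1, 28*28).cuda())
code = code.cpu().numpy()
plt.figure(figsize=[8, 8])
plt.scatter(code[:, 0], code[:, 1], c=test_y.numpy(), cmap='tab10', s=5)
plt.colorbar(label='digit')
plt.title('2-D latent codes of the test set')
plt.show()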