Monday, November 27, 2017

How to use DataLoader

To use DataLoader, you just feed it a Dataset object.
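Concretely, a Dataset only has to implement __len__ and __getitem__; DataLoader then takes care of batching, shuffling, and multi-process loading on top of that. A minimal sketch (the class and data below are made up for illustration, not from any repo):

import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):  # hypothetical minimal example
    def __init__(self, n=100):
        self.xs = torch.randn(n, 3)      # fake inputs
        self.ys = torch.zeros(n).long()  # fake labels

    def __len__(self):
        return len(self.xs)  # number of samples

    def __getitem__(self, i):
        return self.xs[i], self.ys[i]  # one (input, target) pair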
For a complete real-world example, here is Kaixhin/FCN-semantic-segmentation/data.py:

import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset


# Labels: -1 license plate, 0 unlabeled, 1 ego vehicle, 2 rectification border, 3 out of roi, 4 static, 5 dynamic, 6 ground, 7 road, 8 sidewalk, 9 parking, 10 rail track, 11 building, 12 wall, 13 fence, 14 guard rail, 15 bridge, 16 tunnel, 17 pole, 18 polegroup, 19 traffic light, 20 traffic sign, 21 vegetation, 22 terrain, 23 sky, 24 person, 25 rider, 26 car, 27 truck, 28 bus, 29 caravan, 30 trailer, 31 train, 32 motorcycle, 33 bicycle
num_classes = 20
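# Collapse the 34 raw Cityscapes label ids onto 19 training classes; every ignored id maps to catch-all class 19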
full_to_train = {-1: 19, 0: 19, 1: 19, 2: 19, 3: 19, 4: 19, 5: 19, 6: 19, 7: 0, 8: 1, 9: 19, 10: 19, 11: 2, 12: 3, 13: 4, 14: 19, 15: 19, 16: 19, 17: 5, 18: 19, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 29: 19, 30: 19, 31: 16, 32: 17, 33: 18}
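# Inverse mapping, for turning train ids back into raw Cityscapes ids (e.g. for visualisation)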
train_to_full = {0: 7, 1: 8, 2: 11, 3: 12, 4: 13, 5: 17, 6: 19, 7: 20, 8: 21, 9: 22, 10: 23, 11: 24, 12: 25, 13: 26, 14: 27, 15: 28, 16: 31, 17: 32, 18: 33, 19: 0}
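# RGB palette for the retained raw label ids (the official Cityscapes colours)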
full_to_colour = {0: (0, 0, 0), 7: (128, 64, 128), 8: (244, 35, 232), 11: (70, 70, 70), 12: (102, 102, 156), 13: (190, 153, 153), 17: (153, 153, 153), 19: (250, 170, 30), 20: (220, 220, 0), 21: (107, 142, 35), 22: (152, 251, 152), 23: (70, 130, 180), 24: (220, 20, 60), 25: (255, 0, 0), 26: (0, 0, 142), 27: (0, 0, 70), 28: (0, 60,100), 31: (0, 80, 100), 32: (0, 0, 230), 33: (119, 11, 32)}


class CityscapesDataset(Dataset):
  def __init__(self, split='train', crop=None, flip=False):
    super().__init__()
    self.crop = crop
    self.flip = flip
    self.inputs = []
    self.targets = []

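    # NOTE: relative paths; assumes the extracted Cityscapes archives (leftImg8bit_trainvaltest, gtFine_trainvaltest) sit in the working directory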
    for root, _, filenames in os.walk(os.path.join('leftImg8bit_trainvaltest', 'leftImg8bit', split)):
      for filename in filenames:
        if os.path.splitext(filename)[1] == '.png':
          filename_base = '_'.join(filename.split('_')[:-1])
          target_root = os.path.join('gtFine_trainvaltest', 'gtFine', split, os.path.basename(root))
          self.inputs.append(os.path.join(root, filename_base + '_leftImg8bit.png'))
          self.targets.append(os.path.join(target_root, filename_base + '_gtFine_labelIds.png'))

  def __len__(self):
    return len(self.inputs)

  def __getitem__(self, i):
    # Load images and perform augmentations with PIL
    input, target = Image.open(self.inputs[i]), Image.open(self.targets[i])
    # Random uniform crop
    if self.crop is not None:
      w, h = input.size
      x1, y1 = random.randint(0, w - self.crop), random.randint(0, h - self.crop)
      input, target = input.crop((x1, y1, x1 + self.crop, y1 + self.crop)), target.crop((x1, y1, x1 + self.crop, y1 + self.crop))
    # Random horizontal flip
    if self.flip:
      if random.random() < 0.5:
        input, target = input.transpose(Image.FLIP_LEFT_RIGHT), target.transpose(Image.FLIP_LEFT_RIGHT)

    # Convert to tensors directly from the PIL byte buffers (no numpy round-trip)
    w, h = input.size
    input = torch.ByteTensor(torch.ByteStorage.from_buffer(input.tobytes())).view(h, w, 3).permute(2, 0, 1).float().div(255)
    target = torch.ByteTensor(torch.ByteStorage.from_buffer(target.tobytes())).view(h, w).long()
    # Normalise input with the ImageNet per-channel means and stds
    input[0].add_(-0.485).div_(0.229)
    input[1].add_(-0.456).div_(0.224)
    input[2].add_(-0.406).div_(0.225)
    # Convert to training labels
    remapped_target = target.clone()
    for k, v in full_to_train.items():
      remapped_target[target == k] = v
    # Create one-hot encoding
    target = torch.zeros(num_classes, h, w)
    for c in range(num_classes):
      target[c][remapped_target == c] = 1
    return input, target, remapped_target # Return x, y (one-hot), y (index)
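With the Dataset defined, the DataLoader side really is a one-liner. A minimal sketch (the batch size, crop size, and worker count below are arbitrary choices, not values from the original repo):

from torch.utils.data import DataLoader

# Hypothetical settings; adjust crop/batch size to your GPU memory
train_dataset = CityscapesDataset(split='train', crop=512, flip=True)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                          num_workers=4, pin_memory=True)

for input, target, remapped_target in train_loader:
    # input:           (4, 3, 512, 512) normalised float tensor
    # target:          (4, 20, 512, 512) one-hot labels
    # remapped_target: (4, 512, 512) class indices (LongTensor)
    pass  # forward/backward pass goes here

Because every sample is cropped to the same size, the default collate function can stack them into a batch; with variable-sized images you would need crop (or a custom collate_fn).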
