Loading [MathJax]/jax/output/HTML-CSS/jax.js

2017年12月4日月曜日

Loss functionのあれこれ

semantic segmentationのloss functionで完全に迷子になったので復習.

回帰用

  1. torch.nn.L1Loss()
    loss(x,y)=1nni=1|xiyi|
    これで学習すると重みがsparseになりやすいことが知られている.
  2. torch.nn.MSELoss
    loss(x,y)=1ni|xiyi|2

分類用

C個のクラスに分類する.

  1. torch.nn.CrossEntropyLoss
    NLLとsoftmaxを合成したloss.
    minibatch Xとして,その元xC次元のベクトルで,i成分がクラスi{0,...,C1}に分類されるスコアであるとする.スコアはsoftmaxによって確率に変換される.
    classxが分類されるべきクラス,pclassxclassに分類される確率とすると,
    loss(x,class)=logexp(x[class])jexp(x[j])=logpclass=x[class]+log(jexp(x[j]))
    pytorchでは,入力は
    input: (N,C)のtensor. j行はxj.
    target: (N)のtensor, (x1のクラス,...,xNのクラス)

  2. torch.nn.NLLLoss
    Cross Entropy Lossとほとんど同じ. softmaxを噛ませるか噛ませないか.
    loss(x,class)=x[class]

  3. torch.nn.PoissonNLLLoss

  4. torch.nn.NLLLoss2d
    NLLLossの画像版で,inputのピクセルごとにNLLLossを計算する.
    input: (N,C,H,W)のtensor. とりあえずmini-batchの次元Nは無視するとして,cCi,j成分に対応する要素をxci,jとすると,
    xci,jがinputの(i,j)ピクセルがクラスcに属するスコアであって,class(ij)をそれが属すべき真のクラスとすると
    loss(x,class)=i,jxclass(ij)
    targetは(N,H,W)というtensorで,(n,h,w)成分は,
    targetn,h,w={1    (n=class(hw))0(otherwise)

  5. torch.nn.KLDivLoss
    KL-divergenceによるloss. inputは確率分布だから,総和は1になる.
    loss(x,target)=1n(targeti(log(targeti)xi))

  6. torch.nn.BCELoss, binary cross entropy criterion
    loss(o,t)=1ni{t[i]log(o[i])+(1t[i]log(1o[i])})
    不安定なので,BCEWithLogitsLossが提案されている.

  7. BCEWithLogitsLoss
    loss(o,t)=1ni{t[i]log(sigmoid(o[i]))+(1t[i])log(1sigmoid(o[i]))}
    auto-encoderに使われるらしい. 0t[i]1が必ず成立するようにする.

Semantic segmentationでは複雑なloss functionを自分で書いて実装することになる・・・