3歩進んで2歩下がる

ゼロつく3/付録A インプレース演算

コピーと上書き

ndarrayインスタンスについて累算代入演算子(+=)を使うとオブジェクトid(アドレス)が上書きされる。


>>> import numpy as np
>>> x = np.array(1)
>>> id(x)
3120427571504 #1

>>> x += x 
>>> id(x)
3120427571504 #2 (#1と同じid➡値だけが上書きされた)
>>>print(x)

>>> x = x + x
>>> id(x)
3120405281840 #3 (#1と異なるid➡コピー)

このようにコピー(複製)を行わずにメモリの値を直接上書きする演算をインプレース演算と呼ぶ。

次のようにxをzに代入しておいて累算代入演算子で加算する場合もインプレース演算となる。ここで問題なのはzとxのアドレスが同じなのでxの値まで変化してしまうことだ。

# 累算代入演算子
>>> x = np.array(1)
>>> z = x
>>> id(x)
3120427316496 #1
>>> id(z)
3120427316496 #2 (#1と同じid)

>>> z += x
>>> id(x)
3120427316496 #3 (#1と同じid)
>>> print(x)
2 # xの値まで変わってしまう!
>>> id(z)
3120427316496 #4 (#2と同じid➡値だけが上書きされた)
>>> print(z)
2

次のように、普通の加算(x+x)を使えば別のアドレスに計算結果がコピーされるためxには影響がない。


# 普通の加算
>>> x = np.array(1)
>>> z = x
>>> id(x)
3120436084240
>>> id(z)
3120436084240

>>> z = x + x
>>> id(x)
3120436084240
>>> print(x)
1 # 1のまま!
>>> id(z)
3120405281808
>>> print(z)
2

DeZeroの逆伝播

DeZeroでは同じ変数を繰り返し使えるように逆伝播の記述(Variable class)に修正を行った。もし微分を初めて設定する場合は出力側から伝わる微分をそのまま代入、それ以降に微分が伝わった場合には加算するようにした。(step14)

ここでは上で見たxをzに代入しておいて”+=”を使う場合の注意点と同じことに気を付けなければいけない。すなわち、下のコードの部分で累算代入演算子を使ってしまうと、gxにまで影響が出てしまう。

例えばy=x+xを考えてみると、Addの逆伝播ではAdd.backward(gy) によってそのままgyが返されてgxに格納される( gxs = f.backward(*gys) の部分 )。そのためgxとgyのidが同じになる。累算代入演算子を使ってgxの値まで変えてしまうとgyも変わってしまう。

funcs = [self.creator]
while funcs:
   f = funcs.pop()
   gys = [output.grad for output in f.outputs]
   gxs = f.backward(*gys) # Addではgyがそのままgxに格納される。
   if not isinstance(gxs, tuple):
      gxs = (gxs, )
   for x, gx in zip(f.inputs, gxs):
      if x.grad is None:
         x.grad = gx
      else :
              x.grad = x.grad + gx
              # x.grad += gx 累算代入演算子だとgxの値まで変化してしまう。

class Add(Function):
    def forward(self, x0, x1):
        y = x0 + x1
        return y
    def backward(self, gy):
        return gy, gy #gyがそのまま返される。

余談

次のように別のオブジェクト(下の例ではy)を累算代入演算子で加える場合はxのアドレスがyのアドレスで上書きされるようなことは無い。


>>> x = np.array(1)
>>> y = np.array(2)
>>>id(x)
3120425906064 #1
>>>id(y)
3120425903856 #2

>>>x += y
>>>id(x)
3120425906064 #3 (#1のまま)
>>>id(y)
3120425903856 #4 (#2のまま)