3歩進んで2歩下がる

ゼロつく3/ステップ9 0次元ndarrayについて

0次元のndarrayの扱いについて注意すべき点がある。すぐに忘れてしまいそうなのでここに書いておく。

演算結果がゼロ次元の場合、スカラーに変換される

まず次のコードを見てみる。


x = np.array([1.0])
y = x**2
print(type(x), x.ndim)
print(type(y), y.ndim)
実行結果
<class 'numpy.ndarray'> 1
<class 'numpy.ndarray'> 1

この場合は変なことは起きていない。

問題は次の場合だ。


x = np.array(1.0)
y = x**2
print(type(x), x.ndim)
print(type(y), y.ndim)
実行結果
<class 'numpy.ndarray'> 0
<class 'numpy.float64'> 0

xは0次元のndarray型だがx**2の結果はfloat64型になってしまっている。これはnumpyの仕様らしい。

変数の型をndarrayに統一して書いていたはずなのにいつの間にかfloat型になってるやん!という事が起きないよう以下のようにあらかじめスカラーを検知してndarrayに変換するようにコードを書いておくとエラーを防げる。

#要所を抽出して簡略化して記述しています。

def as_array(x):
   if np.isscalar(x):
      return np.array(x)
   return x

output = Variable(as_array(y))