ゼロつく3/ステップ9 0次元ndarrayについて
0次元のndarrayの扱いについて注意すべき点がある。すぐに忘れてしまいそうなのでここに書いておく。
演算結果がゼロ次元の場合、スカラーに変換される
まず次のコードを見てみる。
x = np.array([1.0])
y = x**2
print(type(x), x.ndim)
print(type(y), y.ndim)
実行結果
<class 'numpy.ndarray'> 1
<class 'numpy.ndarray'> 1
この場合は変なことは起きていない。
問題は次の場合だ。
x = np.array(1.0)
y = x**2
print(type(x), x.ndim)
print(type(y), y.ndim)
実行結果
<class 'numpy.ndarray'> 0
<class 'numpy.float64'> 0
xは0次元のndarray型だがx**2の結果はfloat64型になってしまっている。これはnumpyの仕様らしい。
変数の型をndarrayに統一して書いていたはずなのにいつの間にかfloat型になってるやん!という事が起きないよう以下のようにあらかじめスカラーを検知してndarrayに変換するようにコードを書いておくとエラーを防げる。
#要所を抽出して簡略化して記述しています。
def as_array(x):
if np.isscalar(x):
return np.array(x)
return x
output = Variable(as_array(y))