The broadcasting mechanism in numpy

Shape-matching problems between matrices, illustrated with the examples below

Case 1

import numpy as np
x=np.zeros([11,5])
w=np.zeros([5,2])
b=np.zeros([2,1])   
print("x.shape",x.shape)
print("type(x)",type(x))
print("w.shape",w.shape)
print("type(w)",type(w))
print("b.shape",b.shape)
print("type(b)",type(b))
y=np.dot(x,w)+b
print("y.shape",y.shape)

output:

x.shape (11, 5)
type(x) <class 'numpy.ndarray'>
w.shape (5, 2)
type(w) <class 'numpy.ndarray'>
b.shape (2, 1)
type(b) <class 'numpy.ndarray'>
Traceback (most recent call last):
  File "D:/Python_Code/Tensorflow_Learning_Note/test_code/test.py", line 32, in <module>
    y=np.dot(x,w)+b
ValueError: operands could not be broadcast together with shapes (11,2) (2,1)

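Shape mismatch:
np.dot is matrix multiplication, so np.dot(x,w) has shape (11,5)·(5,2) = (11,2), while b has shape (2,1).
Broadcasting compares the two shapes starting from the last dimension: 2 vs 1 is fine (a dimension of 1 can be stretched), but 11 vs 2 are neither equal nor 1, so np.dot(x,w)+b cannot be computed.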
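The rule described above can be sketched in plain Python. `broadcast_shape` is a hypothetical helper written only for illustration, not a NumPy function; it is cross-checked against `np.broadcast_shapes`, which assumes NumPy 1.20 or newer.

```python
import numpy as np

def broadcast_shape(shape_a, shape_b):
    """Hypothetical hand-written sketch of NumPy's broadcasting rule for two shapes."""
    # Align the shapes from the right by left-padding the shorter one with 1s.
    ndim = max(len(shape_a), len(shape_b))
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for da, db in zip(a, b):
        # Two dimensions are compatible if they are equal or one of them is 1.
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"shapes {shape_a} and {shape_b} cannot be broadcast")
    return tuple(result)

# Case 1: np.dot(x, w) has shape (11, 2) and b has shape (2, 1).
try:
    broadcast_shape((11, 2), (2, 1))
except ValueError as e:
    print(e)  # the leading dimensions 11 and 2 clash

# The shapes from the other cases in this article.
print(broadcast_shape((7, 2), (7, 1)))  # (7, 2)
print(broadcast_shape((7, 2), (1, 2)))  # (7, 2)
print(broadcast_shape((7, 2), (2,)))    # (7, 2)

# Cross-check against NumPy itself.
for pair in [((7, 2), (7, 1)), ((7, 2), (1, 2)), ((7, 2), (2,))]:
    assert broadcast_shape(*pair) == np.broadcast_shapes(*pair)
```

Padding with 1s on the left is what lets a rank-1 shape such as (2,) take part in the comparison.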

Case 2

import numpy as np
x=np.zeros([7,5])
w=np.zeros([5,2])
b=np.zeros([7,1])   # !!! note the changed shape of b here

print("x.shape",x.shape)
print("type(x)",type(x))
print("w.shape",w.shape)
print("type(w)",type(w))
print("b.shape",b.shape)
print("type(b)",type(b))

y=np.dot(x,w)+b
print("y.shape",y.shape)

print(b)
print(b.T)

output:

x.shape (7, 5)
type(x) <class 'numpy.ndarray'>
w.shape (5, 2)
type(w) <class 'numpy.ndarray'>
b.shape (7, 1)
type(b) <class 'numpy.ndarray'>
y.shape (7, 2)
[[0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]]
[[0. 0. 0. 0. 0. 0. 0.]]



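The shapes match through broadcasting: (7,2) + (7,1). The first dimensions are both 7, and in the last dimension the 1 of b is stretched to 2, so b's single column is added to both columns of np.dot(x,w).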
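One way to picture Case 2 is that NumPy behaves as if b's column were copied. The sketch below assumes made-up non-zero data (np.arange, np.ones) instead of the article's zeros so the values are visible, and compares the broadcast result with an explicit copy made by np.tile. Broadcasting itself does not actually copy the data.

```python
import numpy as np

# Hypothetical non-zero example data (the article uses np.zeros).
x = np.arange(35, dtype=float).reshape(7, 5)
w = np.ones((5, 2))
b = np.arange(7, dtype=float).reshape(7, 1)   # shape (7, 1): one value per row

y = np.dot(x, w) + b                          # (7, 2) + (7, 1) -> (7, 2)

# Broadcasting acts as if b's single column were copied into both columns.
b_tiled = np.tile(b, (1, 2))                  # explicit copy, shape (7, 2)
assert b_tiled.shape == (7, 2)
assert np.array_equal(y, np.dot(x, w) + b_tiled)
print(y.shape)  # (7, 2)
```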

Case 3

import numpy as np
x=np.zeros([7,5])
w=np.zeros([5,2])
b=np.zeros([1,2])   # !!! note the changed shape of b here

print("x.shape",x.shape)
print("type(x)",type(x))
print("w.shape",w.shape)
print("type(w)",type(w))
print("b.shape",b.shape)
print("type(b)",type(b))

y=np.dot(x,w)+b
print("y.shape",y.shape)

print(b)
print(b.T)

output:

x.shape (7, 5)
type(x) <class 'numpy.ndarray'>
w.shape (5, 2)
type(w) <class 'numpy.ndarray'>
b.shape (1, 2)
type(b) <class 'numpy.ndarray'>
y.shape (7, 2)
[[0. 0.]]
[[0.]
 [0.]]




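The shapes match thanks to broadcasting: (7,2) + (1,2). The last dimensions are both 2, and the 1 in b's first dimension is stretched to 7, so b's single row is added to every row of np.dot(x,w).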
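Likewise for Case 3, the sketch below (again with made-up non-zero values) uses np.broadcast_to to build the repeated rows explicitly. It assumes the layout of this article, where each row of x is one sample and b holds one value per output column.

```python
import numpy as np

# Hypothetical non-zero example data.
x = np.arange(35, dtype=float).reshape(7, 5)
w = np.ones((5, 2))
b = np.array([[10.0, 20.0]])                  # shape (1, 2): one value per column

y = np.dot(x, w) + b                          # (7, 2) + (1, 2) -> (7, 2)

# Broadcasting acts as if b's single row were repeated for all 7 rows.
b_rows = np.broadcast_to(b, (7, 2))           # read-only view, no data copied
assert np.array_equal(y, np.dot(x, w) + b_rows)

# Every row of y - x·w equals b's row.
assert np.allclose(y - np.dot(x, w), b)
print(y.shape)  # (7, 2)
```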

Case 4

import numpy as np
x=np.zeros([7,5])
w=np.zeros([5,2])
b=np.zeros([2])   # !!! note the changed shape of b here

print("x.shape",x.shape)
print("type(x)",type(x))
print("w.shape",w.shape)
print("type(w)",type(w))
print("b.shape",b.shape)
print("type(b)",type(b))

y=np.dot(x,w)+b
print("y.shape",y.shape)

print(b)
print(b.T)

output:

x.shape (7, 5)
type(x) <class 'numpy.ndarray'>
w.shape (5, 2)
type(w) <class 'numpy.ndarray'>
b.shape (2,)
type(b) <class 'numpy.ndarray'>
y.shape (7, 2)
[0. 0.]
[0. 0.]

Note!!!

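Here b has shape (2,): it is a rank-1 array, neither a row vector nor a column vector. Broadcasting still succeeds, but such arrays easily cause bugs that are hard to track down. For example, as the output above shows, b.T is identical to b, because transposing a rank-1 array does nothing.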

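Andrew Ng's deep learning course explicitly recommends writing np.zeros([2,1]) instead of np.zeros([2]), that is, always giving a vector an explicit two-dimensional shape. In the layout used in this article, the explicit shape that works for b is (1,2), as in Case 3; (2,1) would fail for the same reason as in Case 1.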
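The pitfalls of rank-1 arrays can be reproduced in a few lines. The arrays u and v are hypothetical examples chosen to show a silent shape bug; reshape(1, -1) is one common way to turn a rank-1 array into an explicit row vector.

```python
import numpy as np

b = np.zeros([2])
print(b.shape, b.T.shape)   # (2,) (2,): transposing a rank-1 array does nothing
assert b.T.shape == b.shape

# A hard-to-spot bug: mixing a (3,) array with a (3, 1) array silently
# broadcasts to a (3, 3) "outer" result instead of an element-wise one.
u = np.array([1.0, 2.0, 3.0])        # shape (3,)
v = np.array([[1.0], [2.0], [3.0]])  # shape (3, 1)
print((u - v).shape)                 # (3, 3), not (3,) or (3, 1)

# Defensive habit: give vectors an explicit 2-D shape and check it.
b_row = np.zeros([1, 2])             # the explicit form used in Case 3
b_fixed = b.reshape(1, -1)           # converts an existing rank-1 array
assert b_row.shape == b_fixed.shape == (1, 2)
```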

Case 5

import numpy as np
x=np.zeros([7,5])
w=np.zeros([5,2])
b=np.zeros([7])   # !!! note the changed shape of b here

print("x.shape",x.shape)
print("type(x)",type(x))
print("w.shape",w.shape)
print("type(w)",type(w))
print("b.shape",b.shape)
print("type(b)",type(b))

y=np.dot(x,w)+b
print("y.shape",y.shape)

print(b)
print(b.T)

output:

x.shape (7, 5)
type(x) <class 'numpy.ndarray'>
w.shape (5, 2)
type(w) <class 'numpy.ndarray'>
b.shape (7,)
type(b) <class 'numpy.ndarray'>
Traceback (most recent call last):
  File "D:/Python_Code/Tensorflow_Learning_Note/test_code/test.py", line 34, in <module>
    y=np.dot(x,w)+b
ValueError: operands could not be broadcast together with shapes (7,2) (7,)

Note!!!

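Here b has shape (7,), again a rank-1 array that is neither a row vector nor a column vector. It looks similar to the previous case, yet it cannot be broadcast. The reason is that broadcasting aligns shapes from the right, so an array of shape (n,) behaves like a row of shape (1, n): it can only be broadcast when n equals the last dimension (the number of columns) of the other operand. In the previous case 2 matched 2; here 7 is compared with 2, which fails.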
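The right-alignment rule explains why Case 4 works and Case 5 does not. In this sketch, xw is a hypothetical stand-in of zeros for np.dot(x, w); reshaping b to (7, 1) is one way to make it line up with the rows instead of the columns (np.broadcast_shapes assumes NumPy 1.20 or newer).

```python
import numpy as np

xw = np.zeros([7, 2])            # stand-in for np.dot(x, w) from Case 5
b = np.zeros([7])                # rank-1, shape (7,)

# A rank-1 array is aligned with the LAST axis, so (2,) fits (7, 2) as in Case 4.
print(np.broadcast_shapes((7, 2), (2,)))   # (7, 2)

try:
    xw + b                       # compares 2 with 7 -> ValueError
except ValueError as err:
    print("broadcast failed:", err)

# Making b an explicit column vector (7, 1) lines it up with the rows.
y = xw + b.reshape(7, 1)         # same as xw + b[:, np.newaxis]
print(y.shape)                   # (7, 2)
```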

So, from now on, avoid arrays with shapes like (7,) wherever possible.
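In practice, (n,) arrays often appear as the result of reductions such as sum or mean. The sketch below, with hypothetical data, uses the keepdims=True argument of ndarray.sum, which keeps the reduced axis with size 1 so the result stays 2-D and broadcasts as intended.

```python
import numpy as np

x = np.arange(14, dtype=float).reshape(7, 2)   # hypothetical data, no zero row sums

row_sums = x.sum(axis=1)                       # shape (7,): rank-1 again
row_sums_2d = x.sum(axis=1, keepdims=True)     # shape (7, 1): stays 2-D
print(row_sums.shape, row_sums_2d.shape)       # (7,) (7, 1)

# Dividing each row by its own sum: the (7, 1) version broadcasts row-wise.
normalized = x / row_sums_2d
assert np.allclose(normalized.sum(axis=1), 1.0)

# The rank-1 version hits the same (7,2) vs (7,) mismatch as Case 5.
try:
    x / row_sums
except ValueError as err:
    print("broadcast failed:", err)
```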