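Broadcasting in NumPy
Broadcasting matters whenever the dimensions of two arrays have to match. The examples below walk through the common cases.
Case 1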
import numpy as np
x=np.zeros([11,5])
w=np.zeros([5,2])
b=np.zeros([2,1])
print("x.shape",x.shape)
print("type(x)",type(x))
print("w.shape",w.shape)
print("type(w)",type(w))
print("b.shape",b.shape)
print("type(b)",type(b))
# raises ValueError: shapes (11, 2) and (2, 1) cannot be broadcast
y=np.dot(x,w)+b
print("y.shape",y.shape)
output:
x.shape (11, 5)
type(x) <class 'numpy.ndarray'>
w.shape (5, 2)
type(w) <class 'numpy.ndarray'>
b.shape (2, 1)
type(b) <class 'numpy.ndarray'>
Traceback (most recent call last):
  File "D:/Python_Code/Tensorflow_Learning_Note/test_code/test.py", line 32, in <module>
    y=np.dot(x,w)+b
ValueError: operands could not be broadcast together with shapes (11,2) (2,1)
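Dimension matching:
np.dot is matrix multiplication, so np.dot(x,w) has shape (11,5)(5,2)=(11,2), while b has shape (2,1).
Broadcasting compares shapes starting from the last axis: 2 against 1 is fine, but 11 against 2 is neither equal nor 1, so np.dot(x,w)+b cannot be computed.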
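To make the rule concrete, here is a minimal sketch that checks two shapes by hand. The helper name can_broadcast is hypothetical (it is not a NumPy function); it only mirrors the trailing-axis comparison described above.

```python
from itertools import zip_longest


def can_broadcast(shape_a, shape_b):
    """Hypothetical helper: True if the two shapes are broadcast-compatible."""
    # Walk both shapes from the last axis backwards; a missing axis counts as 1.
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        # Two axis lengths are compatible if they are equal or one of them is 1.
        if a != b and a != 1 and b != 1:
            return False
    return True


print(can_broadcast((11, 2), (2, 1)))  # False
print(can_broadcast((7, 2), (7, 1)))   # True
print(can_broadcast((7, 2), (1, 2)))   # True
print(can_broadcast((7, 2), (2,)))     # True
print(can_broadcast((7, 2), (7,)))     # False
```

The five calls use the same shape pairs as the five cases in this post, so the printed results can be compared with the outputs of each case.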
Case 2
import numpy as np
x=np.zeros([7,5])
w=np.zeros([5,2])
b=np.zeros([7,1]) # note the changed shape here
print("x.shape",x.shape)
print("type(x)",type(x))
print("w.shape",w.shape)
print("type(w)",type(w))
print("b.shape",b.shape)
print("type(b)",type(b))
y=np.dot(x,w)+b
print("y.shape",y.shape)
print(b)
print(b.T)
output:
x.shape (7, 5)
type(x) <class 'numpy.ndarray'>
w.shape (5, 2)
type(w) <class 'numpy.ndarray'>
b.shape (7, 1)
type(b) <class 'numpy.ndarray'>
y.shape (7, 2)
[[0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]]
[[0. 0. 0. 0. 0. 0. 0.]]
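The dimensions match through broadcasting: np.dot(x,w) has shape (7,2) and b has shape (7,1), so the single column of b is stretched across both columns and every row of the product gets its own bias value.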
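With an all-zero b the stretching is invisible, so here is a small sketch with non-zero values. The names xw and y_explicit are made up for illustration; xw stands in for np.dot(x,w).

```python
import numpy as np

xw = np.zeros((7, 2))                  # stand-in for np.dot(x, w)
b = np.arange(7.0).reshape(7, 1)       # column vector: one bias per row

y = xw + b                             # b is stretched across the 2 columns
print(y.shape)                         # (7, 2)
print(y[3])                            # [3. 3.]

# The same result, with the stretching written out explicitly.
y_explicit = xw + np.broadcast_to(b, (7, 2))
print(np.array_equal(y, y_explicit))   # True
```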
Case 3
import numpy as np
x=np.zeros([7,5])
w=np.zeros([5,2])
b=np.zeros([1,2]) # note the changed shape here
print("x.shape",x.shape)
print("type(x)",type(x))
print("w.shape",w.shape)
print("type(w)",type(w))
print("b.shape",b.shape)
print("type(b)",type(b))
y=np.dot(x,w)+b
print("y.shape",y.shape)
print(b)
print(b.T)
output:
x.shape (7, 5)
type(x) <class 'numpy.ndarray'>
w.shape (5, 2)
type(w) <class 'numpy.ndarray'>
b.shape (1, 2)
type(b) <class 'numpy.ndarray'>
y.shape (7, 2)
[[0. 0.]]
[[0.]
 [0.]]
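The dimensions match because of broadcasting: b has shape (1,2), so its single row is stretched along the first axis and the same row is added to each of the 7 rows of np.dot(x,w).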
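As a quick check of that reading, the sketch below compares the broadcast sum with a version where the row is copied by hand using np.tile. The values 10 and 20 and the names xw and tiled are illustrative only.

```python
import numpy as np

xw = np.zeros((7, 2))              # stand-in for np.dot(x, w)
b = np.array([[10.0, 20.0]])       # row vector: one bias per column

y = xw + b                         # b is stretched down the 7 rows

# Copy the row 7 times by hand and compare with the broadcast result.
tiled = np.tile(b, (7, 1))
print(tiled.shape)                 # (7, 2)
print(np.array_equal(y, tiled))    # True
```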
Case 4
import numpy as np
x=np.zeros([7,5])
w=np.zeros([5,2])
b=np.zeros([2]) # note the changed shape here
print("x.shape",x.shape)
print("type(x)",type(x))
print("w.shape",w.shape)
print("type(w)",type(w))
print("b.shape",b.shape)
print("type(b)",type(b))
y=np.dot(x,w)+b
print("y.shape",y.shape)
print(b)
print(b.T)
output:
x.shape (7, 5)
type(x) <class 'numpy.ndarray'>
w.shape (5, 2)
type(w) <class 'numpy.ndarray'>
b.shape (2,)
type(b) <class 'numpy.ndarray'>
y.shape (7, 2)
[0. 0.]
[0. 0.]
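Note!
Here b has shape (2,): it is a one-dimensional array, neither a row vector nor a column vector, which is why b and b.T print identically. Broadcasting succeeds, but arrays like this easily lead to bugs that are hard to debug.
Andrew Ng's deep learning videos explicitly recommend using np.zeros([2,1]) instead of np.zeros([2]). For this particular sum, though, the explicit two-dimensional shape that broadcasts against (7,2) is (1,2); a (2,1) array raises the ValueError seen in the first example.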
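A short sketch of why one-dimensional arrays are easy to trip over: transposing them does nothing, and np.dot silently returns a scalar where an outer product might have been intended. The variable names here are illustrative.

```python
import numpy as np

a = np.zeros(5)                    # 1-D array, shape (5,)
print(a.shape, a.T.shape)          # (5,) (5,)  -> .T has no effect

inner = np.dot(a, a.T)             # inner product of two 1-D arrays: a scalar
print(np.shape(inner))             # ()

col = a.reshape(5, 1)              # explicit column vector
outer = np.dot(col, col.T)         # (5, 1) times (1, 5)
print(outer.shape)                 # (5, 5)
```

With an explicit (5,1) shape, the transpose is a real (1,5) row and the product has the shape one would expect from the matrix notation.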
Case 5
import numpy as np
x=np.zeros([7,5])
w=np.zeros([5,2])
b=np.zeros([7]) # note the changed shape here
print("x.shape",x.shape)
print("type(x)",type(x))
print("w.shape",w.shape)
print("type(w)",type(w))
print("b.shape",b.shape)
print("type(b)",type(b))
# raises ValueError: shapes (7, 2) and (7,) cannot be broadcast
y=np.dot(x,w)+b
print("y.shape",y.shape)
print(b)
print(b.T)
output:
x.shape (7, 5)
type(x) <class 'numpy.ndarray'>
w.shape (5, 2)
type(w) <class 'numpy.ndarray'>
b.shape (7,)
type(b) <class 'numpy.ndarray'>
Traceback (most recent call last):
  File "D:/Python_Code/Tensorflow_Learning_Note/test_code/test.py", line 34, in <module>
    y=np.dot(x,w)+b
ValueError: operands could not be broadcast together with shapes (7,2) (7,)
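Note!
Here b has shape (7,), again neither a row vector nor a column vector. It looks similar to the previous case, yet it cannot be broadcast. A one-dimensional array is lined up against the last (column) axis, so it broadcasts only when its length equals the number of columns, as in the previous case.
So, going forward, avoid arrays with shapes like (7,) wherever possible.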
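If a (7,) array is already in hand, one possible fix is to give it an explicit second axis before adding. This is only a sketch; the names xw, b_col and failed are made up for illustration.

```python
import numpy as np

xw = np.zeros((7, 2))              # stand-in for np.dot(x, w)
b = np.arange(7.0)                 # 1-D array, shape (7,)

failed = False
try:
    xw + b                         # length 7 is compared with the 2 columns
except ValueError as err:
    failed = True
    print("1-D b fails:", err)

b_col = b.reshape(-1, 1)           # explicit column vector, shape (7, 1)
y = xw + b_col                     # now one bias per row
print(b_col.shape, y.shape)        # (7, 1) (7, 2)
```

Writing b[:, np.newaxis] instead of b.reshape(-1, 1) gives the same (7,1) shape; which one reads better is a matter of taste.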