1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
| import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Training samples for the linear model; the three points lie exactly on
# y = 3 * x + 2.
x_data = np.array((1.0, 2.0, 3.0), dtype=np.float64)
y_data = np.array((5.0, 8.0, 11.0), dtype=np.float64)
def forward(x):
    """Evaluate the linear model ``x * w + b``.

    ``w`` and ``b`` are not parameters: they are looked up in module scope
    at call time, so they must be defined before this function is called.
    When they are arrays, the result follows NumPy broadcasting against
    ``x``.
    """
    scaled = x * w
    return scaled + b
def loss(x, y):
    """Return the squared error between the prediction ``forward(x)`` and ``y``."""
    residual = forward(x) - y
    return residual ** 2
# Candidate parameter values for the weight (W) and the bias (B); both axes
# use the same start / stop / step (stop is exclusive).
W = np.arange(0.0, 4.1, 0.1)
B = np.arange(0.0, 4.1, 0.1)

# 2-D coordinate grids covering every (w, b) combination.  With the default
# meshgrid indexing, w varies along axis 1 and b along axis 0.
w, b = np.meshgrid(W, B)

# Accumulate the squared error of each training sample over the whole grid,
# then divide by the sample count to obtain the mean squared error surface.
loss_sum = np.zeros_like(w)
for x_val, y_val in zip(x_data, y_data):
    loss_sum += loss(x_val, y_val)
mse = loss_sum / len(x_data)
# Open a new figure window containing a single 3-D axes.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Draw the MSE surface over the (w, b) grid; cmap='viridis' selects the colour map.
ax.plot_surface(w, b, mse, cmap='viridis')

# Label the three axes in one call.
ax.set(xlabel='w', ylabel='b', zlabel='MSE')

plt.show()
|