#coding:utf-8
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
plt.style.use('ggplot')
fig = plt.figure()
# Create the 3-D axes through the figure. ``Axes3D(fig)`` stopped adding
# itself to the figure in Matplotlib 3.6 (deprecated since 3.4), which leaves
# the shown window empty; ``add_subplot(..., projection='3d')`` works across
# versions (on old releases the ``Axes3D`` import registers the projection).
ax = fig.add_subplot(111, projection='3d')
# Module-level x, y, z containers; the demo code further down rebinds them.
x, y, z = [], [], []
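# Loads real accuracy results from disk; not called by default (the demo below plots a synthetic surface instead).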
def get_data(name_template="sku_11.16_2stage_%s_%s.txt"):
    """Load the accuracy grid produced by the two-stage threshold sweep.

    For every pair (conf, cls_conf) with both values in 0.1, 0.2, ..., 0.9,
    the file ``name_template % (conf, cls_conf)`` is opened and the last
    space-separated token of its last line is parsed as the accuracy.
    Each value is printed as it is read.

    Args:
        name_template: ``%``-style template with two ``%s`` slots, filled with
            the string forms of ``conf`` and ``cls_conf`` (in that order).
            Defaults to the original hard-coded file-name pattern.

    Returns:
        ``(x, y, z)`` NumPy arrays: ``x`` holds the 9 ``conf`` values, ``y``
        the 9 ``cls_conf`` values and ``z`` the 81 accuracies in row-major
        order (``conf`` outer, ``cls_conf`` inner), so
        ``z.reshape(len(x), len(y))[i, j]`` is the accuracy at
        ``(x[i], y[j])``.

    Raises:
        OSError: if a result file cannot be opened.
        IndexError: if a result file is empty.
        ValueError: if the last token of a result file is not a number.
    """
    # Accumulate into local lists rather than the module-level x/y/z: those
    # globals get rebound to NumPy arrays by the demo code (breaking
    # ``.append``), and sharing them made repeated calls pile up z values.
    # Keep the exact ``str(0.1 * k)`` spelling (e.g. '0.30000000000000004');
    # presumably the files on disk were named the same way -- verify before
    # changing it.
    steps = [str(0.1 * k) for k in range(1, 10)]
    x = [float(s) for s in steps]
    y = [float(s) for s in steps]
    z = []
    for conf1 in steps:
        for cls_conf1 in steps:
            name = name_template % (conf1, cls_conf1)
            with open(name, 'r') as f:
                # Only the final line matters; its last token is the accuracy.
                acc = float(f.readlines()[-1].strip().split(" ")[-1])
            print('*'*50)
            print(conf1, cls_conf1, acc)
            z.append(acc)
    return np.array(x), np.array(y), np.array(z)
def fun(x, y):
    """Return the synthetic demo surface ``x**2 + y**2``, element-wise.

    Accepts scalars or array-likes (anything ``np.power`` accepts).
    """
    x_squared = np.power(x, 2)
    y_squared = np.power(y, 2)
    return x_squared + y_squared
# Demo: plot a surface over the (conf, cls_conf) threshold grid.
# Set USE_REAL_DATA to True to plot the accuracies loaded by get_data()
# instead of the synthetic fun(x, y) surface.
USE_REAL_DATA = False
if USE_REAL_DATA:
    x, y, z = get_data()
    # get_data() loops conf outermost, so after reshaping, row i is
    # conf = x[i] and column j is cls_conf = y[j]. np.meshgrid's default 'xy'
    # indexing instead makes row i correspond to y[i], so transpose z to line
    # it up with the grids (this also keeps non-square grids working).
    z = z.reshape(len(x), len(y)).T
    x, y = np.meshgrid(x, y)
    ax.set_title("average accuracy")  # overall title
else:
    x = np.linspace(0.1, 0.9, 9)
    y = np.linspace(0.1, 0.9, 9)
    x, y = np.meshgrid(x, y)
    z = fun(x, y)
print("x:",x)
print("y:",y)
print("z:",z)
ax.plot_surface(x, y, z, cmap=plt.cm.jet)
# Label all three axes; x carries conf values and y carries cls_conf values.
ax.set_xlabel('conf', color='r')
ax.set_ylabel('cls_conf', color='r')
ax.set_zlabel('acc', color='r')
plt.show()