Python学习笔记3-绘图

时间:2022-12-27 07:29:27

Roc曲线

def plot_objs_roc(fname,title=u'ROC曲线'):
    print fname
    pred=pd.read_csv(fname)
    #print pred.head(5)
    y_pred=pred['preictal']
    objs_y_pred=get_objects_y(pred)
    #print objs_y_pred['Dog_2'].head(5)
    total_auc=roc_auc_score(Y_True,y_pred)
    print 'Total AUC: %s'%(total_auc)
    plt.figure(figsize=(10,7))
    #plot roc
    for obj in OBJS:
        #通过roc_curve()函数,求出fpr和tpr,以及阈值 
        fpr, tpr, thresholds = roc_curve(Y_Objs[obj],objs_y_pred[obj])
        obj_auc = auc(fpr, tpr)
        #画图,只需要plt.plot(fpr,tpr),变量roc_auc只是记录auc的值,通过auc()函数能计算出来,size=15,fontproperties=zhfont
        plt.plot(fpr, tpr, lw=2, label='%s(auc = %0.3f)'% (obj,obj_auc))
    #画对角线
    plt.plot([0, 1], [0, 1],'--',color=(0.6, 0.6, 0.6), label='Random')
    fpr, tpr, thresholds = roc_curve(Y_True,y_pred)
    #ROC曲线下的面积total
    obj_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr,label='Total(auc = %0.3f)'% (total_auc),lw=4)
    plt.xlim([-0.05, 1.05])  
    plt.ylim([-0.05, 1.05])  
    plt.xlabel(u'假阳性率',size=13,fontproperties=zhfont)  
    plt.ylabel(u'真阳性率',size=13,fontproperties=zhfont)  
    plt.title(title,size=17,fontproperties=zhfont)  
    plt.legend(loc="lower right")  
    plt.show()

特异性曲线

def plot_Specificity(datasets,title=u'特异性曲线'):
    data_names=datasets.keys()
    print data_names
    plt.figure(figsize=(10,7))
    #plot roc
    t_list = np.arange(0.0, 1.0, 0.01)
    for name in data_names:
        sn, sp = [], []
        for t in t_list:
            y_class=Y_True.copy()
            y_class[datasets[name]>=t]=1
            cm=confusion_matrix(Y_True,y_class)
            sn_t = 1.0 * cm[1, 1] / (cm[1, 1] + cm[1, 0])
            sp_t = 1.0 * cm[0, 0] / (cm[0, 0] + cm[0, 1])
            sn.append(sn_t)
            sp.append(sp_t)
        plt.plot(t_list,sp, label=name, lw=2)
    plt.xlim([-0.05, 1.05])  
    plt.ylim([-0.05, 1.05])  
    plt.xlabel(u'阈值',size=13,fontproperties=zhfont)  
    plt.ylabel(u'特异性',size=13,fontproperties=zhfont)  
    plt.title(title,size=17,fontproperties=zhfont)  
    plt.legend(loc="lower right")  
    plt.show()