seaborn—绘制热力图

时间:2024-10-26 07:03:49

  heatmap(热力图)是识别预测变量与目标变量相关性的方法,同时,也是发现变量间是否存在多重共线性的好方法。
中文文档

seaborn.heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, annot=None, fmt='.2g', 
annot_kws=None,linewidths=0, linecolor='white', cbar=True, cbar_kws=None, cbar_ax=None, square=False, 
xticklabels='auto', yticklabels='auto',mask=None, ax=None, **kwargs)

这些参数具体什么含义,怎么使用见中文文档
常用参数:

  • data:矩形数据集
    可以强制转换为ndarray格式数据的2维数据集。如果提供了Pandas DataFrame数据,索引/列信息将用于标记列和行。那么如果提供了关系矩阵就可以显示变量之间的相关性。
  • vmin, vmax:浮点型数据,可选参数
    用于锚定色彩映射的值,否则它们是从数据和其他关键字参数推断出来的
  • cmap:matplotlib 颜色条名称或者对象,或者是颜色列表,可选参数
    从数据值到颜色空间的映射。 如果没有提供,默认值将取决于是否设置了“center”
  • center:浮点数,可选参数
    绘制有色数据时将色彩映射居中的值。 如果没有指定,则使用此参数将更改默认的cmap
  • annot:布尔值或者矩形数据,可选参数
    如果为True,则在每个热力图单元格中写入数据值。 如果数组的形状与data相同,则使用它来代替原始数据注释热力图
  • mask:布尔数组或者DataFrame数据,可选参数
    如果为空值,数据将不会显示在mask为True的单元格中。 具有缺失值的单元格将自动被屏蔽
  • cbar:布尔值,可选参数
    描述是否绘制颜色条
  • square:布尔值,可选参数
    如果为True,则将坐标轴方向设置为“equal”,以使每个单元格为方形
  • xticklabels, yticklabels:“auto”,布尔值,类列表值,或者整形数值,可选参数
    如果为True,则绘制数据框的列名称。如果为False,则不绘制列名称。如果是列表,则将这些替代标签绘制为xticklabels。如果是整数,则使用列名称,但仅绘制每个n标签。如果是“auto”,将尝试密集绘制不重叠的标签。
  • fmt:字符串,可选参数
    添加注释时要使用的字符串格式代码
  • annot_kws:字典或者键值对,可选参数
    当annot为True时,的关键字参数
  • ax:matplotlib Axes,可选参数
    绘制图的坐标轴,否则使用当前活动的坐标轴

返回热力图对象

使用1:关系矩阵
这个函数将矩形数据绘制为颜色编码矩阵。所以得先通过pandas中corr()方法获得关系矩阵。

# 设置绘图风格
style.use('ggplot')
sns.set_style('whitegrid')
# 设置画板尺寸
plt.subplots(figsize = (30,20))
 
# 画热力图
# 为上三角矩阵生成掩码
mask = np.zeros_like(train.corr(), dtype=np.bool)
mask[np.triu_indices_from(mask)] = True

sns.heatmap(train.corr(), 
            cmap=sns.diverging_palette(20, 220, n=200), 
            mask = mask, # 数据显示在mask为False的单元格中
            annot=True, # 注入数据
            center = 0,  # 绘制有色数据时将色彩映射居中的值
           )
# Give title. 
plt.title("Heatmap of all the Features", fontsize = 30)

在这里插入图片描述可以利用pandas中的nlargest方法,来显示排序前多少的关系矩阵。

k  = 11 # 关系矩阵中将显示10个特征
cols = num_corrmat.nlargest(k, 'SalePrice')['SalePrice'].index
cm = np.corrcoef(train_data[cols].values.T)
fig,ax = plt.subplots(figsize=(15,10))
sns.set(font_scale=1.25)
hm = sns.heatmap(cm, 
				cbar=True, 
				annot=True, # 注入数字
                 square=True, # 单元格为正方形
                 fmt='.2f',   # 字符串格式代码
                  annot_kws={'size': 10}, # 当annot为True时,的关键字参数,即注入数字的字体大小
                  yticklabels=cols.values,  # 列标签
                  xticklabels=cols.values   # 行标签
                  )
plt.show()

在这里插入图片描述


如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论留言!
在这里插入图片描述