Python Logistic 回归分类

时间:2021-03-13 23:52:29

Logistic回归可以认为是线性回归的延伸,其作用是对二分类样本进行训练,从而对达到预测新样本分类的目的。
假设有一组已知分类的MxN维样本X,M为样本数,N为特征维度,其相应的已知分类标签为Mx1维矩阵Y。那么Logistic回归的实现思路如下:
(1)用一组权重值W(Nx1)对X的特征进行线性变换,得到变换后的样本X’(Mx1),其目标是使属于不同分类的样本X’存在一个明显的一维边界。
(2)然后再对样本X’进一步做函数变换,从而使处于一维边界两测的值变换到相应的范围之内。
(3)训练过程就是通过改变W尽可能使得到的值位于一维边界两侧,并且与已知分类相符。
(4)对于Logistic回归,就是将原样本的边界变换到x=0这个边界。
下面是Logistic回归的典型代码:



# -*- coding: utf-8 -*-
"""
Created on Wed Nov 09 15:21:48 2016
Logistic回归分类
"""

import numpy  as np


class LogisticRegressionClassifier ( ):
    
     def  __init__ ( self ):
         self._alpha  =  None
    

     #定义一个sigmoid函数
     def _sigmoid ( self , fx ):
         return  1.0/ ( 1 + np. exp (-fx ) )

     #alpha为步长(学习率);maxCycles最大迭代次数
     def _gradDescent ( self , featData , labelData , alpha , maxCycles ):
        dataMat  = np. mat (featData )                       #size: m*n
        labelMat  = np. mat (labelData ). transpose ( )         #size: m*1
        m , n  = np. shape (dataMat )
        weigh  = np. ones ( (n ,  1 ) ) 
         for i  in  range (maxCycles ):
            hx  =  self._sigmoid (dataMat * weigh )
            error  = labelMat - hx        #size:m*1
            weigh  = weigh + alpha * dataMat. transpose ( ) * error #根据误差修改回归系数
         return weigh

     #使用梯度下降方法训练模型,如果使用其它的寻参方法,此处可以做相应修改
     def fit ( self , train_x , train_y , alpha = 0.01 , maxCycles = 100 ):
         return  self._gradDescent (train_x , train_y , alpha , maxCycles )

     #使用学习得到的参数进行分类
     def predict ( self , test_X , test_y , weigh ):
        dataMat  = np. mat (test_X )
        labelMat  = np. mat (test_y ). transpose ( )   #使用transpose()转置
        hx  =  self._sigmoid (dataMat*weigh )   #size:m*1
        m  =  len (hx )
        error  =  0.0
         for i  in  range (m ):
             if  int (hx [i ] )  >  0.5:
                 print  str (i+ 1 )+ '-th sample ' ,  int (labelMat [i ] ) ,  'is classfied as: 1' 
                 if  int (labelMat [i ] )  !=  1:
                    error + =  1.0
                     print  "classify error."
             else:
                 print  str (i+ 1 )+ '-th sample ' ,  int (labelMat [i ] ) ,  'is classfied as: 0' 
                 if  int (labelMat [i ] )  !=  0:
                    error + =  1.0
                     print  "classify error."
        error_rate  = error/m
         print  "error rate is:" ,  "%.4f" %error_rate
         return error_rate