多光谱影像isodata非监督分类

时间:2024-03-17 21:30:45

title: 遥感多光谱影像isodata非监督分类
date: 2018-10-03
categories: 遥感图像处理
tags:
- python
- 图像处理
- gdal

多光谱影像 isodata 非监督分类

ISODATA 全称为 Iterative Self-Organizing Data Analysis Techniques Algorithm(迭代自组织数据分析算法)

特点

ISODATA是一种非监督的分类方法

不需要预先知道聚类的数量

迭代分裂并合并产生新的聚类

用户定义阈值参数

迭代直到达到阈值为止

原理

1)随机放置聚类中心
基于最小距离中心法对像元聚类
2)计算每一个聚类的标准差以及各个类中心间的距离(类间距)

聚类的标准差大于用户定义的阈值,则分裂

类间距小于用户定义的阈值,则合并

3)进行第二次迭代
产生新的聚类中心

4)进行进一步的迭代,直到:
i)平均类内距低于用户定义的阈值,
ii)两次迭代过程类内距的平均变化小于阈值,或
iii)达到最大迭代次数

输入参数

K:初始聚类中心个数;

TN:每一类中允许的最少样本数目(若少于此数,就不能单独成为一类);

TS:类内各特征分量分布的相对标准差上限(大于此数就分裂);

TC:两类中心间的最小距离下限(若小于此数,这两类应合并);

L:在每次迭代中最多可以进行“合并”操作的次数 ;

I:允许的最多迭代次数。

代码实现

基于 GitHub 项目 lucka-me/ISODATA-python 修改。原项目只支持 RGB 三通道图像的分类,而遥感影像分类通常面向多光谱影像

以下是对TM影像的ISODATA分类实现

#!/usr/bin/env python3
# coding: utf-8
import numpy
import math
import random
import gdal
import cv2

inputFilename = "before.img"
outputFilename = "out.tif"

argvK = 8  # initial (expected) number of clusters
argvTN = 20  # minimum number of samples a cluster must hold
argvTS = 1  # per-band standard deviation above which a cluster may be split
argvTC = 0.5  # centre-to-centre distance below which two clusters are merged
argvL = 5  # maximum number of cluster pairs merged in one iteration
argvI = 10  # maximum number of iterations
# Open the file named by inputFilename rather than repeating the literal, so
# that changing the setting above actually changes the input.
dataset = gdal.Open(inputFilename)
if dataset is None:
    # gdal.Open returns None on failure when GDAL exceptions are not enabled;
    # fail here instead of with an AttributeError deep inside main().
    raise RuntimeError("cannot open raster file: {0}".format(inputFilename))


class Pixel:
    """One image pixel: its two spatial indices and its spectral vector.

    ``x`` and ``y`` are stored exactly as supplied by the caller, and
    ``color`` is whatever per-band value sequence the caller passes in.

    ``__slots__`` is declared because one instance is created for every
    image pixel on every iteration; dropping the per-instance ``__dict__``
    keeps memory use down without changing how the attributes are used.
    """
    __slots__ = ("x", "y", "color")

    def __init__(self, initX: int, initY: int, initColor):
        self.x = initX
        self.y = initY
        self.color = initColor

class Cluster:
    """A cluster: its current centre and the pixels assigned to it."""
    def __init__(self, initCenter):
        # Each cluster starts with its own, empty member list.
        self.pixelList = list()
        self.center = initCenter

class ClusterPair:
    """Two cluster indices together with the distance between their centres."""
    def __init__(self, initClusterAIndex: int, initClusterBIndex: int, initDistance):
        self.clusterAIndex, self.clusterBIndex = initClusterAIndex, initClusterBIndex
        self.distance = initDistance
def distanceBetween(colorA, colorB) -> float:
    """Return the Euclidean distance between two spectral vectors.

    Works for any number of bands (the earlier version was fixed to exactly
    seven). Each component is converted with ``int`` before subtracting so
    that unsigned NumPy scalars such as ``uint8`` cannot wrap around.

    If the two vectors differ in length, only the leading components that
    both of them have are compared.
    """
    return math.sqrt(sum((int(a) - int(b)) ** 2 for a, b in zip(colorA, colorB)))
  


def main(dataset, outputfilename,K: int, TN: int, TS: float, TC: float, L: int, I: int):
    """Run ISODATA clustering on a raster and write a colour-coded GeoTIFF.

    Parameters:
        dataset: an opened GDAL dataset; its size, geotransform, projection
            and pixel data are read from it.
        outputfilename: path of the 3-band byte GeoTIFF to create.
        K: number of initial cluster centres to draw (duplicates are dropped,
            so fewer may be used).
        TN: a cluster holding fewer than TN pixels is discarded.
        TS: a cluster whose largest per-band standard deviation exceeds TS
            is a candidate for splitting.
        TC: two clusters whose centres are closer than TC are candidates
            for merging.
        L: maximum number of cluster pairs merged in one iteration.
        I: iteration limit. The loop stops on the first pass whose counter
            exceeds I, so every pixel is re-labelled once more after the
            merge that is forced on pass I.

    Raises:
        RuntimeError: if no cluster is left to assign pixels to.

    Pixels are stored as ``Pixel(column, row, values)``. Cluster centres are
    kept as int64 vectors with one entry per band.
    """
    imgX = dataset.RasterXSize  # number of columns
    imgY = dataset.RasterYSize  # number of rows
    im_geotrans = dataset.GetGeoTransform()  # affine geotransform
    im_proj = dataset.GetProjection()  # projection definition
    # ReadAsArray yields (bands, rows, cols); a single-band raster comes back
    # 2-D, so give it an explicit band axis.
    imgArray = dataset.ReadAsArray(0, 0, imgX, imgY)
    if imgArray.ndim == 2:
        imgArray = imgArray[numpy.newaxis, :, :]
    # Split centres are clipped to the data range so they can never wrap
    # around or leave the value domain of the image.
    valueMin = int(imgArray.min())
    valueMax = int(imgArray.max())

    # Draw the initial centres from random pixels, skipping duplicates.
    clusterList = []
    for _ in range(K):
        randomX = random.randint(0, imgX - 1)
        randomY = random.randint(0, imgY - 1)
        candidate = imgArray[:, randomY, randomX].astype(numpy.int64)
        if any(numpy.array_equal(candidate, cluster.center) for cluster in clusterList):
            continue
        clusterList.append(Cluster(candidate))

    iterationCount = 0
    didAnythingInLastIteration = True
    while True:
        iterationCount += 1
        if not clusterList:
            raise RuntimeError("no clusters left to classify against; check K and TN")

        # Forget the previous assignment.
        for cluster in clusterList:
            cluster.pixelList.clear()
        print("------")
        print("迭代第{0}次".format(iterationCount))

        # Assign every pixel to its nearest centre (first one wins on ties).
        print("分类...", end = '', flush = True)
        for row in range(imgY):
            for col in range(imgX):
                color = imgArray[:, row, col]
                targetClusterIndex = 0
                targetClusterDistance = distanceBetween(color, clusterList[0].center)
                for i in range(1, len(clusterList)):
                    currentDistance = distanceBetween(color, clusterList[i].center)
                    if currentDistance < targetClusterDistance:
                        targetClusterDistance = currentDistance
                        targetClusterIndex = i
                clusterList[targetClusterIndex].pixelList.append(Pixel(col, row, color))
        print(" 结束 ")

        # Drop one under-populated cluster (searching from the end) and
        # re-classify from scratch.
        gotoNextIteration = False
        for i in range(len(clusterList) - 1, -1, -1):
            if len(clusterList[i].pixelList) < TN:
                clusterList.pop(i)
                gotoNextIteration = True
                break
        if gotoNextIteration:
            print("样本个数不满足要求")
            continue
        print("样本个数满足要求")

        # Move each centre to the rounded mean of its members.
        print("重新计算聚类中心...", end = '', flush = True)
        for cluster in clusterList:
            if not cluster.pixelList:
                continue  # only reachable when TN <= 0; keep the old centre
            memberColors = numpy.array([pixel.color for pixel in cluster.pixelList], dtype = numpy.int64)
            newCenter = numpy.rint(memberColors.sum(axis = 0) / len(cluster.pixelList)).astype(numpy.int64)
            # A centre has moved as soon as ANY band differs.
            if not numpy.array_equal(newCenter, cluster.center):
                didAnythingInLastIteration = True
            cluster.center = newCenter
        print("结束")
        if iterationCount > I:
            break
        if not didAnythingInLastIteration:
            print("更多迭代次数是不是必要的")
            break

        # Mean member-to-centre distance, per cluster and over all pixels.
        print("准备合并或分裂...", end = '', flush = True)
        aveDisctanceList = []
        sumDistanceAll = 0.0
        for cluster in clusterList:
            currentSumDistance = 0.0
            for pixel in cluster.pixelList:
                currentSumDistance += distanceBetween(pixel.color, cluster.center)
            memberCount = len(cluster.pixelList)
            aveDisctanceList.append(currentSumDistance / memberCount if memberCount else 0.0)
            sumDistanceAll += currentSumDistance
        aveDistanceAll = float(sumDistanceAll) / (imgX * imgY)
        print(" 结束")

        if (len(clusterList) <= K / 2) or not (iterationCount % 2 == 0 or len(clusterList) >= K * 2):
            # Split step.
            print("开始分裂", end = '', flush = True)
            beforeCount = len(clusterList)
            # Walk backwards: clusters appended during the loop land behind
            # the current index and are therefore never revisited.
            for i in range(len(clusterList) - 1, -1, -1):
                cluster = clusterList[i]
                memberCount = len(cluster.pixelList)
                if memberCount == 0:
                    continue
                memberColors = numpy.array([pixel.color for pixel in cluster.pixelList], dtype = numpy.int64)
                squaredDeviation = (memberColors - cluster.center) ** 2
                currentSD = numpy.sqrt(squaredDeviation.sum(axis = 0) / memberCount)
                # Largest standard deviation over ALL bands.
                maxSD = float(currentSD.max())
                if (maxSD > TS) and ((aveDisctanceList[i] > aveDistanceAll and memberCount > 2 * (TN + 1)) or (len(clusterList) < K / 2)):
                    gamma = 0.5 * maxSD
                    upperCenter = numpy.clip(numpy.rint(cluster.center + gamma), valueMin, valueMax).astype(numpy.int64)
                    lowerCenter = numpy.clip(numpy.rint(cluster.center - gamma), valueMin, valueMax).astype(numpy.int64)
                    clusterList.append(Cluster(upperCenter))
                    clusterList.append(Cluster(lowerCenter))
                    clusterList.pop(i)
            print(" {0} -> {1}".format(beforeCount, len(clusterList)))
        elif (iterationCount % 2 == 0) or (len(clusterList) >= K * 2) or (iterationCount == I):
            # Merge step.
            print("合并:", end = '', flush = True)
            beforeCount = len(clusterList)
            didAnythingInLastIteration = False
            clusterPairList = []
            for i in range(0, len(clusterList)):
                for j in range(0, i):
                    currentDistance = distanceBetween(clusterList[i].center, clusterList[j].center)
                    if currentDistance < TC:
                        clusterPairList.append(ClusterPair(i, j, currentDistance))

            # Closest pairs first; each cluster takes part in at most one
            # merge and at most L pairs are merged.
            clusterPairListSorted = sorted(clusterPairList, key = lambda clusterPair: clusterPair.distance)
            newClusterCenterList = []
            mergedClusterIndexes = set()
            mergedPairCount = 0
            for clusterPair in clusterPairListSorted:
                if mergedPairCount >= L:
                    break
                indexA = clusterPair.clusterAIndex
                indexB = clusterPair.clusterBIndex
                if indexA in mergedClusterIndexes or indexB in mergedClusterIndexes:
                    continue
                weightA = len(clusterList[indexA].pixelList)
                weightB = len(clusterList[indexB].pixelList)
                if weightA + weightB == 0:
                    weightA = weightB = 1  # two empty clusters: plain average
                # Member-count weighted mean, truncated to integers.
                weightedSum = weightA * clusterList[indexA].center + weightB * clusterList[indexB].center
                newClusterCenterList.append((weightedSum / (weightA + weightB)).astype(numpy.int64))
                mergedClusterIndexes.add(indexA)
                mergedClusterIndexes.add(indexB)
                mergedPairCount += 1
            if mergedClusterIndexes:
                didAnythingInLastIteration = True
            # Remove from the back so the remaining indices stay valid.
            for index in sorted(mergedClusterIndexes, reverse = True):
                clusterList.pop(index)
            for center in newClusterCenterList:
                clusterList.append(Cluster(center))
            print(" {0} -> {1}".format(beforeCount, len(clusterList)))

    print("分类结束")
    print("一共分为 {0} 类.".format(len(clusterList)))
    # Label map in raster layout (rows, cols): each pixel holds the index of
    # its cluster, so two clusters can never share a colour by accident.
    labelArray = numpy.zeros((imgY, imgX), dtype = numpy.intp)
    for clusterIndex, cluster in enumerate(clusterList):
        for pixel in cluster.pixelList:
            labelArray[pixel.y, pixel.x] = clusterIndex

    print("对各个类别进行颜色渲染...")
    # One random colour per cluster.
    palette = numpy.array(
        [[random.randint(0, 128), random.randint(0, 255), random.randint(128, 255)] for _ in clusterList],
        dtype = numpy.uint8)

    print("写出分类后专题图")
    driver = gdal.GetDriverByName("GTiff")
    IsoData = driver.Create(outputfilename, imgX, imgY, 3, gdal.GDT_Byte)
    print("设置坐标参数")
    IsoData.SetGeoTransform(im_geotrans)  # copy the geotransform
    print("设置投影信息")
    IsoData.SetProjection(im_proj)  # copy the projection
    for band in range(3):
        # Look up each pixel's colour component; the result is (rows, cols).
        IsoData.GetRasterBand(band + 1).WriteArray(palette[:, band][labelArray])
    IsoData.FlushCache()
    IsoData = None  # release the output dataset so it is written out

    del dataset  # drops only this function's reference to the input
    print("ISODATA非监督分类完成")
   

if __name__ == '__main__':
    # Run the classification with the module-level settings.
    main(
        dataset=dataset,
        outputfilename=outputFilename,
        K=argvK,
        TN=argvTN,
        TS=argvTS,
        TC=argvTC,
        L=argvL,
        I=argvI,
    )

结果

多光谱影像isodata非监督分类