# jiehun_DEMO
#
# 时间 (time): 2024-10-21 20:22:14
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt


# Define the convolutional autoencoder model.
def build_autoencoder(conv_filters=48, learning_rate=0.001, patch_size=5, num_bands=330,
                      num_endmembers=16, softmax_scale=3.5):
    """Build and compile a convolutional autoencoder for hyperspectral patches.

    Encoder: a 3x3 convolution with ``conv_filters`` channels, then a 1x1
    convolution with ``num_endmembers`` channels; each is followed by batch
    normalization and ReLU. The encoder output is multiplied by
    ``softmax_scale`` and passed through a softmax over the last (channel)
    axis, so every pixel's ``num_endmembers`` values are non-negative and sum
    to 1 (the "abundance" layer). Decoder: one linear 3x3 convolution back to
    ``num_bands`` channels.

    Args:
        conv_filters: Number of filters in the first convolution.
        learning_rate: Learning rate of the Adam optimizer.
        patch_size: Spatial height and width of the input patches.
        num_bands: Number of spectral bands (input and output channels).
        num_endmembers: Channel count of the abundance layer.
        softmax_scale: Factor applied before the softmax; larger values make
            the per-pixel abundance distribution sharper.

    Returns:
        A compiled Keras ``Model`` mapping
        ``(patch_size, patch_size, num_bands)`` inputs to reconstructions of
        the same shape, with MSE loss.
    """
    input_layer = layers.Input(shape=(patch_size, patch_size, num_bands))

    # Encoder stage 1: spatial-spectral features.
    x = layers.Conv2D(conv_filters, (3, 3), padding='same')(input_layer)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Encoder stage 2: 1x1 projection to one channel per endmember.
    x = layers.Conv2D(num_endmembers, (1, 1), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Scaled softmax over channels -> per-pixel abundances summing to 1.
    abundance_layer = layers.Lambda(
        lambda t: tf.nn.softmax(t * softmax_scale),
        output_shape=lambda input_shape: input_shape,
    )(x)

    # Linear decoder: reconstruct the spectrum from the abundances.
    decoded = layers.Conv2D(num_bands, (3, 3), activation='linear', padding='same')(abundance_layer)

    autoencoder = models.Model(inputs=input_layer, outputs=decoded)
    autoencoder.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), loss='mse')
    return autoencoder


# Data generator function.
def preprocess_data_generator(hsi_data, patch_size=5):
    """Yield non-overlapping square patches of ``hsi_data``, one at a time.

    Patches are taken on a ``patch_size`` grid in row-major order. Trailing
    rows and columns that cannot fill a whole patch are skipped. Each yielded
    array is float32 with shape ``(1, patch_size, patch_size, B)``, where
    ``B = hsi_data.shape[2]``. A tqdm bar advances once per row of patches
    as the consumer pulls them.
    """
    rows, cols, _bands = hsi_data.shape
    last_top = rows - patch_size
    last_left = cols - patch_size
    row_starts = tqdm(range(0, last_top + 1, patch_size), desc="Generating patches")
    for top in row_starts:
        for left in range(0, last_left + 1, patch_size):
            window = hsi_data[top:top + patch_size, left:left + patch_size, :]
            # Add a leading batch axis of size 1.
            yield window.astype(np.float32)[np.newaxis, ...]


# Hyperparameters, set by hand.
conv_filters = 64  # number of convolution filters
learning_rate = 0.0001  # Adam learning rate
patch_size = 5  # spatial size of the square training patches
batch_size = 32  # patches per training step
num_epochs = 10  # passes over all patches

# Load the hyperspectral cube (H, W, bands). Random mock data stands in here;
# replace it with the real image. Sampling float32 directly needs ~7.6 GB at
# this size, versus ~15.2 GB for np.random.rand's float64.
hsi_data = np.random.default_rng().random((2432, 2372, 330), dtype=np.float32)

# Reset Keras' global state *before* building the model that will be trained.
tf.keras.backend.clear_session()

# Build the convolutional autoencoder and print its structure.
autoencoder = build_autoencoder(conv_filters=conv_filters, learning_rate=learning_rate,
                                patch_size=patch_size, num_bands=hsi_data.shape[2])
autoencoder.summary()

# Number of non-overlapping patches the generator produces, and the number of
# batch_size-sized steps needed to see each of them once per epoch.
num_patches = (hsi_data.shape[0] // patch_size) * (hsi_data.shape[1] // patch_size)
steps_per_epoch = -(-num_patches // batch_size)  # ceil division; last batch may be smaller

# Train the model.
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    # Fresh generator each epoch so every epoch revisits every patch.
    data_generator = preprocess_data_generator(hsi_data, patch_size=patch_size)
    for step in tqdm(range(steps_per_epoch)):
        n_in_batch = min(batch_size, num_patches - step * batch_size)
        # The generator yields one (1, p, p, bands) patch at a time; stack them into a batch.
        x_batch = np.concatenate([next(data_generator) for _ in range(n_in_batch)], axis=0)
        autoencoder.train_on_batch(x_batch, x_batch)


# Extract abundance maps.
def get_abundance_maps(model, hsi_data, patch_size=5, batch_size=256):
    """Compute the abundance-layer output for the whole image as one cube.

    The image is tiled into non-overlapping ``patch_size`` x ``patch_size``
    patches (the same grid ``preprocess_data_generator`` uses). Trailing rows
    and columns that do not fill a whole patch are dropped. Each patch goes
    through the encoder part of ``model``, i.e. everything up to the input of
    its final layer, and the outputs are stitched back into place.

    Assumes ``model`` is a functional Keras model whose last layer is the
    decoder, so the tensor feeding that layer is the abundance map. This holds
    for the model built by ``build_autoencoder``.

    Args:
        model: Trained functional Keras autoencoder.
        hsi_data: Array of shape ``(H, W, B)``.
        patch_size: Patch height and width used by the model.
        batch_size: Batch size passed to ``predict``.

    Returns:
        float32 array of shape ``(H // patch_size * patch_size,
        W // patch_size * patch_size, K)``, where ``K`` is the channel count
        of the abundance layer.
    """
    H, W, B = hsi_data.shape
    n_rows, n_cols = H // patch_size, W // patch_size
    crop_w = n_cols * patch_size

    # Sub-model that stops at the tensor feeding the decoder (the abundances).
    abundance_tensor = model.layers[-1].input
    encoder = tf.keras.Model(inputs=model.input, outputs=abundance_tensor)
    num_channels = abundance_tensor.shape[-1]

    abundance_maps = np.empty((n_rows * patch_size, crop_w, num_channels), dtype=np.float32)
    if n_rows == 0 or n_cols == 0:
        return abundance_maps

    for i in tqdm(range(0, H - patch_size + 1, patch_size), desc="Extracting abundance maps"):
        strip = hsi_data[i:i + patch_size, :crop_w, :].astype(np.float32)
        # (p, n_cols*p, B) -> (n_cols, p, p, B): one patch per column block.
        patches = strip.reshape(patch_size, n_cols, patch_size, B).transpose(1, 0, 2, 3)
        pred = encoder.predict(patches, batch_size=batch_size, verbose=0)
        # (n_cols, p, p, K) -> (p, n_cols*p, K): inverse of the tiling above.
        abundance_maps[i:i + patch_size] = pred.transpose(1, 0, 2, 3).reshape(
            patch_size, crop_w, num_channels)

    return abundance_maps


# Run the trained autoencoder over the image to obtain the abundance maps.
abundance_maps = get_abundance_maps(model=autoencoder, hsi_data=hsi_data)


# Display abundance maps.
def display_abundance_maps(abundance_maps, num_bands):
    """Plot the first ``num_bands`` channels of an abundance cube in a grid.

    Args:
        abundance_maps: Array of shape ``(H, W, C)``; channel ``i`` is drawn
            as one image.
        num_bands: How many leading channels to plot. Values above ``C`` are
            clamped to ``C``; nothing is drawn if the result is not positive.

    Raises:
        ValueError: If ``abundance_maps`` is not 3-D.
    """
    if abundance_maps.ndim != 3:
        raise ValueError(f"expected an (H, W, C) abundance cube, got shape {abundance_maps.shape}")
    num_shown = min(num_bands, abundance_maps.shape[-1])
    if num_shown <= 0:
        return

    # Size the grid to the panel count (at most 5 per row) so any count fits.
    n_cols = min(num_shown, 5)
    n_rows = (num_shown + n_cols - 1) // n_cols
    plt.figure(figsize=(3 * n_cols, 3 * n_rows))
    for i in range(num_shown):
        plt.subplot(n_rows, n_cols, i + 1)
        # Select channel i; abundance_maps[i] would select image row i instead.
        plt.imshow(abundance_maps[..., i], cmap='jet')
        plt.axis('off')
        plt.title(f'Band {i + 1}')
    plt.tight_layout()
    plt.show()


# Show only the first 10 abundance maps.
display_abundance_maps(abundance_maps, num_bands=10)


# Cluster analysis.
def cluster_abundance_maps(abundance_maps, num_clusters=5, random_state=None):
    """Cluster the pixels of an abundance cube with k-means.

    Each pixel's length-``C`` abundance vector is one sample.

    Args:
        abundance_maps: Array of shape ``(H, W, C)``.
        num_clusters: Number of k-means clusters.
        random_state: Forwarded to ``KMeans``; pass an int for reproducible
            labels. The default ``None`` keeps the unseeded behavior.

    Returns:
        Array of shape ``(H, W)`` holding each pixel's cluster index.

    Raises:
        ValueError: If ``abundance_maps`` is not 3-D.
    """
    if abundance_maps.ndim != 3:
        raise ValueError(f"expected an (H, W, C) abundance cube, got shape {abundance_maps.shape}")
    height, width, channels = abundance_maps.shape
    samples = abundance_maps.reshape(-1, channels)
    kmeans = KMeans(n_clusters=num_clusters, init='k-means++', random_state=random_state)

    print("Clustering abundance maps...")
    kmeans.fit(samples)
    return kmeans.labels_.reshape(height, width)


# Cluster lithology: group pixels by their abundance vectors.
num_clusters = 5
cluster_labels = cluster_abundance_maps(abundance_maps, num_clusters=num_clusters)

# Visualize the clustering result on a single axes.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(cluster_labels, cmap='jet')
ax.set_title('Clustered Lithology Map')
ax.axis('off')
plt.show()

# 相关文章 (Related articles)