
时间:2022-07-22 14:55:17

I want to automate threshold selection in hierarchical clustering. Instead of entering the threshold value manually, how can I check in Python whether the number of clusters falls in the range 30 to 50 and, if it does not, have the code adjust the threshold value (for example by 0.1 or 0.2) until it does?

    import pickle
    import re
    import string
    import sys
    # import gensim
    # from gensim import corpora
    from time import time

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import scipy.cluster.hierarchy as sch
    from nltk.corpus import stopwords
    from nltk.stem.wordnet import WordNetLemmatizer
    from scipy.cluster.hierarchy import dendrogram, linkage
    from scipy.spatial.distance import pdist
    from scipy.spatial.distance import squareform
    from sklearn.decomposition import NMF, LatentDirichletAllocation
    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.feature_extraction.text import TfidfVectorizer
    from stop_word_complaints import complaint_stop_words

 # Script fragment: TF-IDF vectorization -> pairwise cosine distances -> average-linkage
 # hierarchical clustering -> figure output -> cluster selection via get_clusters().
 # NOTE(review): n_features, stop, corpus, str_path and the initial t0 are not defined
 # anywhere in this excerpt; they must come from code that is not shown.
 tfidf_vectorizer = TfidfVectorizer(ngram_range=(1, 2), max_df=0.95, min_df=1, token_pattern=r'\b\w+\b',
                                       max_features=n_features, stop_words=list(stop), analyzer='word')
    X = tfidf_vectorizer.fit_transform(corpus).toarray()

    # Indices of the rows of X whose sum is non-zero; only these rows are passed to pdist below.
    non_zero_features = np.where(np.sum(X, axis=1) != 0)[0]
    print("done in %0.3fs." % (time() - t0))
    print("pdist ...")
    t0 = time()
    # Condensed vector of pairwise cosine distances between the retained rows.
    cos_dist = pdist(X[non_zero_features, :], 'cosine')
    print("done in %0.3fs." % (time() - t0))
    # Expand the condensed vector into a square matrix and replace every NaN entry with 1.
    dists = np.asarray(squareform(cos_dist))
    dists[np.isnan(dists)] = 1
    # cos_dist[np.isnan(cos_dist)] = 0
    # dists[np.argwhere(np.isnan(dists))] = 1
    print("linkage ...")
    # Persist the square distance matrix as CSV.
    np.savetxt(str_path + "_dist_1.csv", dists, delimiter=',')
    # pickle.dump(dists, open(str_path + "_dist.p", "wb"))
    t0 = time()
    # NOTE(review): linkage() is given the square matrix here. Per the SciPy docs a 2-D input
    # is interpreted as observation vectors, not as precomputed distances — confirm whether the
    # condensed cos_dist (with its NaNs replaced) was the intended argument.
    linkage_matrix = linkage(dists, "average")
    print("done in %0.3fs." % (time() - t0))
    # Persist the linkage matrix as CSV.
    np.savetxt(str_path + "linkage_matrix.csv", linkage_matrix, delimiter=',')
    # linkage_matrix = np.loadtxt(str_path + "linkage_matrix.csv", delimiter=',')
    # pickle.dump(linkage_matrix, open(str_path + "linkage_matrix.p", "wb"))
    # create figure & 1 axis
    fig, ax = plt.subplots(nrows=1, ncols=1)  # create figure & 1 axis

    # NOTE(review): no dendrogram(...) call appears in this excerpt before the figure is saved;
    # the two commented-out keyword arguments below look like leftovers from such a call.
    plt.title('Hierarchical Clustering Dendrogram')
    plt.xlabel('sample index')
        # leaf_rotation=90.,  # rotates the x axis labels
        # leaf_font_size=3.,  # font size for the x axis labels
    fig.savefig(str_path + 'Agglo_Heirachy_dendo.png')  # save the figure to file

# Use the smallest and largest values in column 2 of the linkage matrix (the merge
# distances, per the SciPy linkage documentation) as the bounds of the threshold search.
min_th = min(linkage_matrix[:,2])
max_th = max(linkage_matrix[:,2])
clusters =  get_clusters(linkage_matrix, min_th, max_th)

1 个解决方案



I finally found a solution: I defined a new function that searches for a threshold value at which the number of clusters falls within the required range.


def get_clusters(linkage_matrix, min_th, max_th, min_clusters=30, max_clusters=50,
                 max_iter=100):
    """Bisect the distance threshold until the flat cluster count is in range.

    The midpoint of ``[min_th, max_th]`` is used as the ``'distance'``
    threshold for ``sch.fcluster``. Too many clusters raises the lower bound
    of the interval, too few lowers the upper bound, and the search repeats.

    Args:
        linkage_matrix: Linkage matrix as returned by ``scipy`` ``linkage``.
        min_th: Lower bound of the threshold search interval.
        max_th: Upper bound of the threshold search interval.
        min_clusters: Smallest acceptable number of clusters (inclusive).
        max_clusters: Largest acceptable number of clusters (inclusive).
        max_iter: Maximum number of bisection steps before giving up.

    Returns:
        The flat cluster label array produced by ``sch.fcluster`` for the
        first threshold whose cluster count lies in
        ``[min_clusters, max_clusters]``.

    Raises:
        ValueError: If ``min_clusters > max_clusters``, or if no suitable
            threshold is found within ``max_iter`` steps (for example when
            there are fewer observations than ``min_clusters``, or when tied
            merge distances make the cluster count jump over the range).
    """
    if min_clusters > max_clusters:
        raise ValueError(
            "min_clusters (%d) must not exceed max_clusters (%d)"
            % (min_clusters, max_clusters)
        )

    # A bounded loop replaces the original `while True`: a float interval
    # cannot be halved meaningfully more than ~64 times, so an unreachable
    # target range would otherwise spin forever.
    for _ in range(max_iter):
        th = min_th + (max_th - min_th) / 2
        clusters = sch.fcluster(linkage_matrix, th, 'distance')
        # Labels start at 1, so the largest label is the cluster count.
        n_clusters = int(clusters.max())
        print("Clusters found: %d" % n_clusters)

        if n_clusters > max_clusters:
            # Too many clusters: the threshold must grow.
            min_th = th
        elif n_clusters < min_clusters:
            # Too few clusters: the threshold must shrink.
            max_th = th
        else:
            return clusters

    raise ValueError(
        "no threshold yielding between %d and %d clusters was found in %d iterations"
        % (min_clusters, max_clusters, max_iter)
    )



