算法导论15.4-6 最长递增子序列(nlogn)

时间:2022-03-26 09:51:45

         设序列X(n),最长递增子序列长度为m,考虑长度为i的递增子序列,这种序列有多个,最小的末尾元素记为L(i),可以得到 L(1) <= L(2) <= ... <= L(m),这个证明较简单,使用反证法即可。在这个递增的序列中使用十分法查找,则可以实现O(nlogn)的算法。

从左到右扫描序列X(n),L(1) 初始为x(1),再引入一个当前最大长度K,初始为1,K表示目前扫描过的序列包含的最长递增子序列的长度,此时L(1)...L(K)有意义。读入一个数据x,如果L(i)是大于x中最小的,则x也是一个递增子序列的末尾元素,按照L(i)的定义,则L(i) = x,如此可以将数据x替换L(i),以下是函数实现:

#define NO_STRICTLY
static inline int find_pos(int *B, int len, int value)
{
int left = 0, right = len - 1, middle = 0;

while (left <= right) {
middle = (left + right)>>1;
if (B[middle] < value)
left = middle + 1;
#ifdef NO_STRICTLY
else if (B[middle] == value)
left = middle + 1;
#endif
else
right = middle - 1;
}

return left;
}


用构造法证明,由于L(i)是大于x中最小的,则x>=L(i-1),按照L(i-1)的定义,将x添加到L(i-1)的i-1 长度的递增子序列的末尾,就构造了一个长度为i的递增子序列。

形象的来讲,每读入一个数,就尝试将此数放置到一个递增序列的末尾,构造出一个更长的递增序列。

 

详细步骤参考如下的例子:

序列 2 1 3 0 4 1 5 2 7

 

L1

L2

L3

L4

L5

开始

2

 

 

 

 

读入1

1

 

 

 

 

读入3

1

3

 

 

 

读入0

0

3

 

 

 

读入4

0

3

4

 

 

读入1

0

1

4

 

 

读入5

0

1

4

5

 

读入2

0

1

2

5

 

读入7

0

1

2

5

7

 

以上过程有如下几个特征

1.        L的值只会越变越小,这是最自然的,因为L(i)就是长度为i的递增子序列中的最小末尾元素

2.        读入一个数据x,会覆盖某个L,这个L是>=x中最小的,如果x大于所有L,则新生成一个L。读入每个数据,L都是一个递增序列

3.        由于L是递增序列,插入数据时可以使用二分法进行,所以每个输入字符时间复杂度为O(logn),整体时间复杂度为O(nlogn)

4.        L的长度即为序列X(n)的最长递增子序列的长度

5.        从最左下角的L开始,按照“往上、往左”方向就会输出最长递增子序列的内容。优先往上,如果上方数据和当前数据相同,如果上方数据不同则转向左,上图最后输出的递增子序列为“1 3 4 5 7”

代码如下:

static int lis_nlogn_old(int *p, int len, int *inc_seq)
{
int i = 0;
int *L = NULL;
int **seq = NULL;
int pos = 0, curr_len = 0;

seq = malloc(len * sizeof(int *));
seq[0] = malloc(len * sizeof(int) * len);
for (i = 1;i < len;i++) {
seq[i] = seq[0] + len * i;
}

L = malloc((len + 1) * sizeof(int));

L[0] = p[0];
curr_len = 1;
seq[0][0] = p[0];

for (i = 1;i < len;i++) {
pos = find_pos(L, curr_len, p[i]);
L[pos] = p[i];
if (pos > 0) {
memcpy(seq[pos], seq[pos - 1], pos * sizeof(int));
seq[pos][pos] = p[i];
}
else {
seq[0][0] = p[i];
}
if (pos + 1 > curr_len)
curr_len++;
}

free(L); free(seq[0]);free(seq);

memcpy(inc_seq, seq[curr_len - 1], curr_len * sizeof(int));
return curr_len;
}

很遗憾,虽然上述算法可以在O(nlogn)的时间内得到最长子序列的长度,但无法得到整个子序列,原因是每读入一个新数据,就需要将前一个数据保存的L值复制过来,考虑这部分的时间,就会发现整体复杂度为O(n*n)

 

注意整个子序列不能通过L数组来获得,因为某个L[i]会在后续被修改过了,为获取整个子序列,需要保存每个元素的前驱元素,即当前元素所属于的递增子序列的前一个元素,注意到每个元素的前驱元素一旦确定,就不会改变,所以可以根据这个前驱关系确定最长子序列。下面是更新后的代码

static int lis_nlogn(int *p, int len, int *inc_seq)
{
int i = 0, pos = 0, curr_len = 0;
int *L = NULL, *prev = NULL, *M = NULL;

L = malloc(len * sizeof(int));
M = malloc(len * sizeof(int));
prev = malloc(len * sizeof(int));

L[0] = p[0];
M[0] = 0;
prev[0] = -1; /* the prev of the p[0] is NULL */
curr_len = 1;

/* Caculate prev and M */
for (i = 1;i < len;i++) {
pos = find_pos(L, curr_len, p[i]);
L[pos] = p[i];
M[pos] = i;
if (pos > 0)
prev[i] = M[pos - 1];
else
prev[i] = -1;
if (pos + 1 > curr_len)
curr_len++;
}

/* Output increasing sequence */
pos = M[curr_len - 1];
for (i = curr_len - 1;i >= 0 && pos != -1;i--) {
inc_seq[i] = p[pos];
pos = prev[pos];
}

return curr_len;
}