题意:给定一个长度为N的序列,现在要求给出一个最长的序列满足序列中的元素严格上升并且相邻两个数字的下标间隔要严格大于d。
分析:
1.线段树
由于给定的元素的取值范围为0-10^5,因此维护一棵线段树,其中[l, r]的信息表示处理完前k个数时,序列最大元素落在[l, r]区间最长上升子序列的长度。从前往后处理给定的数组,处理到第 i 号元素时,更新第 i - d 号元素,这样就能够保证最长上升的序列间隔大于d,更新是需要更新到叶子节点的,但这里更新是单点更新,每次更新的位置是该元素的值,信息就是到该点的最长上升长度。
其实仔细分析可以发现这个解法其实是经典的O(n^2)的算法的改进,那个算法需要遍历之前的更新信息比较相对大小,因此也不能简单的维护前缀最值,而线段树由于节点是值信息,查询的时候就不要去检验之前大小关系,加之线段树有能够动态区间求解各种信息,时间复杂度就这么被降下来了。当然如果取值范围较大,只要N不大还能够离散化。
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <algorithm>
#define lch (p<<1)
#define rch (p<<1|1)
using namespace std; const int N = ;
int n, d;
int seq[N];
int alen[N]; struct Node {
int l, r;
int len;
}e[N*]; void build(int p, int l, int r) {
e[p].l = l, e[p].r = r, e[p].len = ;
if (l != r) {
int mid = (l + r) >> ;
build(lch, l, mid), build(rch, mid+, r);
}
} void push_up(int p) {
e[p].len = max(e[lch].len, e[rch].len);
} void modify(int p, int x, int val) {
if (e[p].l == e[p].r) e[p].len = max(e[p].len, val);
else {
int mid = (e[p].l + e[p].r) >> ;
if (x <= mid) modify(lch, x, val);
else modify(rch, x, val);
push_up(p);
}
} int query(int p, int l, int r) {
if (e[p].l == l && e[p].r == r) return e[p].len;
else {
int mid = (e[p].l + e[p].r) >> ;
if (r <= mid) return query(lch, l, r);
else if (l > mid) return query(rch, l, r);
else return max(query(lch, l, mid), query(rch, mid+, r));
}
} int main() {
while (scanf("%d %d", &n, &d) != EOF) {
build(, , ); // 建立0-10^5的线段树
int ret = ;
for (int i = ; i <= n; ++i) {
scanf("%d", &seq[i]);
if (seq[i] > ) ret = max(ret, alen[i]=query(, , seq[i]-)+);
else ret = max(ret, alen[i] = );
if (i-d>=) modify(, seq[i-d], alen[i-d]);
}
printf("%d\n", ret);
}
return ;
}
2.经典O(nlogn)LIS变种
经典的算法在数组中保留都是下标节点比当前点小的节点,因为从前往后处理也因为经典的算法其实处理的是间隔d=0的特殊情况,那么稍微进行一下推广,当我们处理完第 i 个元素只是把第 i - d 号元素放到数组中,放入的位置就是以前求出来的最长上升子序列长度,当然放入的时候要比较一下是否需要替换。
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std; const int N = ;
int n, d;
int seq[N];
int alen[N]; void solve() {
vector<int>vt;
vector<int>::iterator it;
int ret = ;
for (int i = ; i <= n; ++i) {
it = lower_bound(vt.begin(), vt.end(), seq[i]);
if (it == vt.end()) alen[i] = vt.size()+;
else alen[i] = it-vt.begin()+;
if (i-d >= ) {
if (vt.size() == alen[i-d]-) vt.push_back(seq[i-d]);
else if (vt[alen[i-d]-] > seq[i-d]) vt[alen[i-d]-] = seq[i-d];
}
ret = max(ret, alen[i]);
}
printf("%d\n", ret);
} int main() {
while (scanf("%d %d", &n, &d) != EOF) {
for (int i = ; i <= n; ++i) {
scanf("%d", &seq[i]);
}
solve();
}
return ;
}