hihoCoder 403 Forbidden 字典树

时间:2022-04-17 09:12:42

题意:给定个规则,个ip,问这些ip是否能和某个规则匹配,如果有多个规则,则匹配第一个。如果没能匹配成功,则认为是”allow”,否则根据规则决定是”allow”或者”deny”.


思路:字典树,将所有ip全部转换为01串,在字典树上面插入查找。

在求二进制时,注意补前缀0。

坑点:

1.字典树上每一个规则都应该带一个序号,表示这是第几个规则,因为题目要求第一个匹配,那么当你利用ip去寻找匹配的规则时,应该是找到序号最小的。

2.出现deny 0.0.0.0/0,说明所有无法匹配ip的都是deny,因为这些ip和这条规则匹配。

给大家贴几个测试数据:

5 2

deny 0.0.0.0/0

allow 0.0.0.0/0

deny 0.0.0.0/1

deny 0.0.0.0/2

allow 123.234.12.23/3

123.234.12.23

0.234.12.23

答案: NO NO

2 1

deny 1.1.1.1

allow 127.0.0.1

1.1.1.222

答案:YES

AC代码

#include <cstdio>
#include <cmath>
#include <cctype>
#include <bitset>
#include <algorithm>
#include <cstring>
#include <utility>
#include <string>
#include <iostream>
#include <map>
#include <set>
#include <vector>
#include <queue>
#include <stack>
using namespace std;
#pragma comment(linker, "/STACK:1024000000,1024000000")
#define eps 1e-10
#define inf 0x3f3f3f3f
#define PI pair<int, int>
typedef long long LL;
const int maxn = 1e5 + 5;
bool is_zero, tag;
struct node{
    int ok, order;
    node *nex[2];
    node() {
        ok = -1; //从未访问过
        nex[0] = nex[1] = NULL;
    }
}*root;
void init() {
    root = new node();
}

void insert(string &s, int n, int ok, int order) {
    node *p = root, *q;
    for(int i = 0; i < n; ++i) {
        int u = s[i] - '0';
        if(p->nex[u] == NULL) {
            q = new node();
            p->nex[u] = q;
        }
        p = p->nex[u];
        if(i == n-1 && p->ok == -1) {
            p->ok = ok;
            p->order = order;
        }
    }
}

bool search(string &s) {
    node *p = root;
    bool ok = true;
    int order = inf;
    for(int i = 0; i < s.size(); ++i) {
        int u = s[i]-'0';
        if(p->nex[u] == NULL) break;
        else {
            if(p->nex[u]->ok != -1 && p->nex[u]->order < order) {
                order = p->nex[u]->order;
                ok = p->nex[u]->ok;
            }
        }
        p = p->nex[u];
    }
    //printf("%d\n", order);
    if(is_zero && order == inf) return tag;
    return ok;
}

string get_binary(int x) {
    stack<int>sta;
    do{
        sta.push(x%2);
        x /= 2;
    }while(x);
    string ans = "";
    //补前缀0
    for(int i = 0; i < 8 - sta.size(); ++i) {
        ans += '0';
    }
    //得到二进制
    while(!sta.empty()) {
        ans += sta.top() + '0';
        sta.pop();
    }
    return ans;
}

void deal(char *s, int ok, int order) {
    int len = strlen(s);
    int ind = -1;
    for(int i = 0; i < len; ++i) {
        if(s[i] == '/') {
            ind = i;
            break;
        }
    }
    string ip = "";
    if(ind == -1) ind = len;
    for(int i = 0; i < ind;) {
        if(s[i] >= '0' && s[i] <= '9') {
            int num = 0;
            while(i < ind && s[i] >= '0' && s[i] <= '9'){
                num = num * 10 + (s[i] - '0');
                ++i;
            }
            ip += get_binary(num);
        }
        else ++i;
    }
    int n = 32;
    if(ind != -1 && ind != len) {
        int num = 0;
        for(int i = ind+1; i < len; ++i) {
            num = num * 10 + s[i] -'0';
        }
        if(num == 0 && !is_zero) {
            is_zero = 1;
            tag = ok;
        }
        n = num;
    }
    insert(ip, n, ok, order);
}
int main() {
    init();
    char kind[10], ip[40];
    int n, m;
    while(scanf("%d%d", &n, &m) == 2) {
        is_zero = 0;
        for(int i = 0; i < n; ++i) {
            scanf("%s %s", kind, ip);
            int ok = 0;
            if(kind[0] == 'a') ok = 1;
            deal(ip, ok, i);
        }
        for(int i = 0; i < m; ++i) {
            scanf("%s", ip);
            int x1, x2, x3, x4;
            sscanf(ip, "%d.%d.%d.%d", &x1, &x2, &x3, &x4);
            string res = get_binary(x1) + get_binary(x2) + get_binary(x3) + get_binary(x4);
            if(search(res)) printf("YES\n");
            else printf("NO\n");
        }
    }
    return 0;
} 

如有不当之处欢迎指出!