试求出一个字符串每一个长度为偶数的后缀在原字符串中出现的次数。
对于 100 % 100% 100% 的数据, ∣ S ∣ ≤ 200000 |S| ≤ 200000 ∣S∣≤200000。
解题思路比较简单。
对这个字符串建 AC 自动机,然后建上 fail 树。
那么一个长度为偶数的前缀在原字符串中出现的次数就是这个前缀在 Trie 上的结束节点在 fail 上的子树和。
也可以优化,意义是一样的。
AC CODE考场代码。
#includeusing namespace std; #define int long long #define _ 500000 int ans; char str[_]; int cnt, tr[_][27], tag[_], fail[_]; void insert(char *s) { int p = 0; int len = strlen(s + 1); for(int i = 1; i <= len; ++i) { int v = s[i] - 'a'; // cout << v << endl; if(!tr[p][v]) tr[p][v] = ++cnt; p = tr[p][v]; if(i % 2 == 0) tag[p] = 1; } } void getfail() { queue q; for(int i = 0; i < 26; ++i) { if(tr[0][i]) { fail[tr[0][i]] = 0; q.push(tr[0][i]); } } while(!q.empty()) { int u = q.front(); q.pop(); for(int i = 0; i < 26; ++i) { if(tr[u][i]) { fail[tr[u][i]] = tr[fail[u]][i]; q.push(tr[u][i]); } else { tr[u][i] = tr[fail[u]][i]; } } tag[u] += tag[fail[u]]; } } signed main() { scanf("%s", str + 1); insert(str); getfail(); for(int i = 0; i <= cnt; ++i) ans += tag[i]; printf("%lldn", ans); return 0; }
便于理解的代码。
#includeusing namespace std; #define int long long #define _ 500000 int ans; char str[_]; int cnt, tr[_][27], tag[_], fail[_]; void insert(char *s) { int p = 0; int len = strlen(s + 1); for(int i = 1; i <= len; ++i) { int v = s[i] - 'a'; // cout << v << endl; if(!tr[p][v]) tr[p][v] = ++cnt; p = tr[p][v]; if(i % 2 == 0) tag[p] = 1; } } int tot, head[_], to[_ << 1], nxt[_ << 1]; void add(int u, int v) { to[++tot] = v; nxt[tot] = head[u]; head[u] = tot; } void getfail() { queue q; for(int i = 0; i < 26; ++i) { if(tr[0][i]) { fail[tr[0][i]] = 0; q.push(tr[0][i]); } } while(!q.empty()) { int u = q.front(); q.pop(); for(int i = 0; i < 26; ++i) { if(tr[u][i]) { fail[tr[u][i]] = tr[fail[u]][i]; q.push(tr[u][i]); } else { tr[u][i] = tr[fail[u]][i]; } } // tag[u] += tag[fail[u]]; } for(int i = 1; i <= cnt; ++i) add(fail[i], i); } int siz[_]; void query(char *s) { int p = 0; int len = strlen(s + 1); for(int i = 1; i <= len; ++i) { int v = s[i] - 'a'; p = tr[p][v]; if(i % 2 == 0) { ans += siz[p]; } } } void dfs(int u, int fa) { siz[u] = 1; for(int i = head[u]; i; i = nxt[i]) { int v = to[i]; if(v == fa) continue; dfs(v, u); siz[u] += siz[v]; } } signed main() { scanf("%s", str + 1); insert(str); getfail(); dfs(0, -1); query(str); printf("%lldn", ans); return 0; }



