栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > C/C++/C#

P3302 [SDOI2013]森林 主席树 + 并查集 + 离散化 + lca + 启发式合并 简短代码(y总代码风)

C/C++/C# 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

P3302 [SDOI2013]森林 主席树 + 并查集 + 离散化 + lca + 启发式合并 简短代码(y总代码风)

题目链接
这道题花费了我两小时 , 我快速写完代码 , debug了一个小时40分钟的故事, 最后一气呵成终于AC,故写题解纪念一下 。

这道题其实是一个树上求第k大树的问题。

首先我们去想想怎么求一维数组[l , r]的第k小树, 这时候我们考虑到这是主席树的模板题。

这是算法提高课的可持久化数据结构中的求第k小数
这是主席树的模板题,我当时写的丑代码。(请不要见笑,也是y总的代码风格)

#include
#include
#include

using namespace std;

const int N = 1e5 + 10 ;

struct Node{
    int l , r ;
    int cnt ; 
} tr[4 * N + N * 17 ];

int n , m ;
int root[N];
int a[N];
vector num ;
int idx ; 

int find(int x){
    return lower_bound( num.begin() , num.end() , x ) - num.begin() ;
}

int build(int l , int r){
   int q = ++idx ;
   if(l == r) return q;
   int mid = l + r >> 1 ;
   tr[q].l = build( l , mid );
   tr[q].r = build( mid + 1 , r );
   return q ;
}

int insert(int p , int l , int r , int x){//这里的 l ,r 不是数组的下标,而是值的范围大小
    int q = ++idx;
    if(l == r){
        tr[q].cnt++;
        return q;
    } 
    tr[q] = tr[p];
    int mid = l + r >> 1; 
    if(x <= mid ) tr[q].l = insert( tr[p].l , l , mid , x );
    else tr[q].r = insert( tr[p].r , mid + 1 , r , x);
    
    tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt ;
    return q;
}

int query(int p , int q , int l , int r , int  k)//这里的 l ,r 不是数组的下标,而是值的范围大小
{
    if(l == r) return r;
    int cnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt;
    int mid = l + r >> 1;
    if( k <= cnt ) return query( tr[p].l , tr[q].l , l , mid , k );
    else return query( tr[p].r , tr[q].r , mid + 1 , r , k - cnt );
}

int main(void){
    scanf("%d%d",&n,&m);
    for(int i = 1 ; i <= n ; i ++ )
    {
        scanf( "%d" , &a[i] );
        num.push_back(a[i]);
    }
    
    sort(num.begin(),num.end());
    num.erase( unique(num.begin(),num.end()) , num.end() ) ;
    
    root[0] = build( 0 , num.size() - 1 );
    for(int i = 1 ; i <= n ; i ++ )
    {
        root[i] = insert(root[i - 1] , 0 , num.size() - 1 , find(a[i]));
    }
    
    while( m -- ){
        int l , r , k ;
        scanf("%d%d%d", &l , &r , &k );
        printf("%dn",num[query(root[l - 1], root[r] , 0 , num.size() - 1 , k )]);
    }
}

现在看我自己写的代码也有点看不懂呢。
不过不影响解决本题 , 本题的思路是,假如每个树(题目给了森林)的根节点为 u , 对于树上的每一个节点, 我们维护该节点到根节点u的路径。

如图所示,求的是8到根节点的路径,如果求u , v之间的路径的第k小值 , 我们就需要 考虑去求u , v(如果求u , v的路径第k小值) ,我们需要去求lca(u , v) 然后用
u r o o t 的 路 径 − v r o o t 的 路 径 − l c a ( u , v ) r o o t 的 路 径 − f a ( l c a ( u , v ) r o o t 的 路 径 u _root 的路径 - v_root的路径 - lca(u , v)_root的路径 - fa(lca(u ,v)_root的路径 ur​oot的路径−vr​oot的路径−lca(u,v)r​oot的路径−fa(lca(u,v)r​oot的路径
然后在主席树做二分求解。
如果左子树的数的数量大于k就在左子树找
反之在右子树找k - left_cnt小的数

最后是合并操作,启发式合并,小的合并到大的,合并操作的细节很多,看一下代码去理解合并操作(L操作)
最后给出我的代码
请忽视我为了debug加的注释(逃)

#include 
#include 
#include 
#include 
#include 
#include 

using namespace std ;

int read()
{
	int res = 0 , flag = 1 ;
	char c = getchar() ;
	while(!isdigit(c))
	{
		if(c == '-') flag = -1 ;
		c = getchar() ;
	}
	while(isdigit(c))
	{
		res = (res << 1) + (res << 3) + (c ^ 48) ;
		c = getchar() ;
	}
	return res * flag ;
}

const int N = 8e4 + 10 , M = 4 * N ;
const int Inf = 2e9 ;

struct Node
{
	int ls, rs ;
	int size ;
} tr[N * 600] ;
int tot ;
int h[N] , e[M] , ne[M] , idx ;
int f[N][20] ;
int depth[N] , p[N] ;
int maxr , root[N] , v[N] ;
int sz[N] ;

void add(int a , int b)
{
	e[idx] = b , ne[idx] = h[a] , h[a] = idx ++ ;
}

void pushup(int u)
{
	tr[u].size = tr[tr[u].ls].size + tr[tr[u].rs].size ;
}

int find(int x)
{
	if(p[x] != x) p[x] = find(p[x]) ;
	return p[x] ; 
}
void update(int old , int &New , int val , int l , int r )
{
	tr[New = ++ tot] = tr[old] ;
	tr[New].size ++ ;
	if(l == r)
		return ;	
	int mid = l + r >> 1 ;
	if(val <= mid) update(tr[old].ls , tr[New].ls , val , l , mid) ;
	else update(tr[old].rs , tr[New].rs , val , mid + 1 , r) ;
	// pushup(New) ;
}

void build(int &u ,int l , int r)
{
	u = ++ tot ;
	if(l == r) return ;
	int mid = l + r >> 1 ;
	build(tr[u].ls , l , mid) ;
	build(tr[u].rs , mid + 1 , r) ;
}

void dfs(int u , int fa)
{
	sz[u] = 1 ;
	update(root[fa] , root[u] , v[u] , 1 , maxr) ;
	if(fa) p[u] = fa ;
	else p[u] = u ;
	depth[u] = depth[fa] + 1 ;
	f[u][0] = fa ;
	for(int k = 1 ; k < 19 ; k ++) 
		f[u][k] = f[f[u][k - 1]][k - 1] ;
	for(int i = h[u] ; ~i ; i = ne[i])
	{
		int j = e[i] ;
		if(j == fa) continue ;
		dfs(j , u) ;
		sz[u] += sz[j] ;
	}
}

int lca(int x , int y)
{
	if(depth[x] < depth[y]) swap(x , y) ;
	for(int k = 18 ; k >= 0 ; k --)
		if(depth[f[x][k]] >= depth[y]) 
			x = f[x][k] ;
	if(x == y) return y ;
	for(int k = 18 ; k >= 0 ; k --)
		if(f[x][k] != f[y][k])
			x = f[x][k] , y = f[y][k] ;
	return f[x][0] ;
}

int query(int x , int y , int fa , int ffa , int l , int r , int k)
{
	int total = tr[x].size + tr[y].size - tr[fa].size - tr[ffa].size ;
	// cout << total << " " << k << " " << l << " " << r << endl ;
	if(k > total) return -1 ;
	if(l == r) return l ; 
	int mid = l + r >> 1 ;
	int left = tr[tr[x].ls].size + tr[tr[y].ls].size - tr[tr[fa].ls].size - tr[tr[ffa].ls].size ;
	if(left >= k) return query(tr[x].ls , tr[y].ls , tr[fa].ls , tr[ffa].ls , l , mid , k) ;
	else return query(tr[x].rs , tr[y].rs , tr[fa].rs , tr[ffa].rs , mid + 1 , r , k - left) ;
}
int main()
{
	int T = read() ;
	T = 1 ;
	while(T --)
	{
		memset(h , -1 , sizeof h) ;
		memset(f , 0 , sizeof f) ;
		idx = 0 , tot = 0 ;
		int n , m , t ;
		n = read() , m = read() , t = read() ;
		for(int i = 1 ; i <= n ; i ++) p[i] = i ;
		vector res ;
		res.push_back(-Inf) ;
		for(int i = 1 ; i <= n ; i ++) v[i] = read() , res.push_back(v[i]) ;
		sort(res.begin() , res.end()) ;
		res.erase(unique(res.begin() , res.end()) , res.end()) ;
		for(int i = 1 ; i <= n ; i ++) v[i] = lower_bound(res.begin() , res.end() , v[i]) - res.begin() ;
		maxr = res.size() - 1 ;
		for(int i = 1 ; i <= m ; i ++)
		{
			int a = read() , b = read() ; 
			add(a , b) , add(b , a) ;
		}
		build(root[0] , 1 , maxr) ;
		for(int i = 1 ; i <= n ; i ++)
			if(!f[i][0]) dfs(i , 0) ;
		int cnt = 0 ; 
		int last = 0 ;
		for(int i = 1 ; i <= t ; i ++)
		{
			// if(++ cnt == 15) return 0 ;
			char op[2] ; 
			scanf("%s" , op) ;
			if(*op == 'Q')
			{
				int x = read() , y = read() , k = read() ;
				x ^= last , y ^= last , k ^= last ; 
				// cout << "query : " << x << " " << y << " " << k << endl ;
				int fa = lca(x , y) ;
				// cout << k << endl ;
				// cout << fa << endl ;
				// cout << tr[root[x]].size << " " << tr[root[y]].size << endl ;
				printf("%dn" , last = res[query(root[x] , root[y] , root[fa] , root[f[fa][0]] , 1 , maxr , k)]) ;	
			}
			else
			{
				int x = read() ^ last , y = read() ^ last ;
				int px = find(x) , py = find(y) ;
				if(sz[px] > sz[py]) swap(x , y) , swap(px , py);
				sz[py] += sz[px] ;
				add(x , y) , add(y , x) ;
				// cout << idx << endl ;
				dfs(x , y) ;
			}
		}
	}
}

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/289765.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号