【数据结构】树套树——线段树套平衡树

0
8

线段树套平衡树

可能后面还会写其他类型的树套树,这里写的是线段树套平衡树。

这里看一下某谷的模板题https://www.luogu.com.cn/problem/P3380

其实作为蒟蒻,我是没有看出来树套树的,在大佬们的帮助下,我理解了思路。

其实特别好做(才怪,调了一晚上,码量很难受)

基本上就是splay和线段树的模版

对于整个的理解,就是线段数上的每个区间,都有一棵与之对应的平衡树,通过线段树找到相对应的区间后,在splay上进行操作。

其中需要改动和注意以下几个地方

1.root的存储

因为对应的是区间,所以splay中的root,现在改用一个数组来放每一个区间的根。即root[i]表示线段数上的第i个区间的splay树的根。

2.查询区间内第k大的数权值

这个函数我理解了很久,说明一下,

这个我们需要用到二分来实现,我们不能讲询问区间拆成两个区间,因为合并不了答案啊。所以我们依靠二分来实现。

了解这个函数要先细品seg_rank和Splay_rank这两个函数,这里求的ans只是区间内比k小的数的总个数,所以当ans=k-1的时候,其实已经求得答案。

但是二分中,判断条件是ans<k,因此会继续二分。最后l==r时得到的是第k位数+1。强烈建议手推几组数据,因为我也说不清楚。淦~~~

3.最后一点,码量极大(对我这种蒟蒻来说),调试的时候注意眼睛和心态

其余就没什么特别注意的,代码有注释

//这里打的是线段树上加平衡树
//即每一个线段树的节点上都有一棵splay树 
#include<cstdio>
#include<iostream>
#include<cctype>
#include<cstring>
#include<algorithm>
using namespace std;

inline int read(){
    int s=0;bool flag=true;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')flag=false;ch=getchar();}
    while(isdigit(ch)){s=(s<<3)+(s<<1)+ch-'0';ch=getchar();}
    return flag?s:-s;
}

inline void print_ans(int x){
    if(x<0)putchar('-'),x=-x;
    if(x>9)print_ans(x/10);
    putchar(x%10+'0');
}
inline void print(int x){print_ans(x),puts("");}

#define Max(a,b) a=max(a,b) 
#define Min(a,b) a=min(a,b)
const int N=4e6+5;
const int inf=2147483647;
int MAX,n,m,ans,w[N];

//----------上面是基础数据,输入优化等乱七八糟的东西 

int tot,root[N]; 
struct Splay_tree{int son[2],val,size,cnt,father;}T[N];
#define ls(x)     T[x].son[0]
#define rs(x)     T[x].son[1]
#define val(x)     T[x].val
#define sze(x)     T[x].size
#define cnt(x)     T[x].cnt
#define fa(x)    T[x].father
//不为别的,看小括号比看中括号爽

inline void Splay_update(int x){
    sze(x)=(ls(x)?sze(ls(x)):0)+(rs(x)?sze(rs(x)):0)+cnt(x);
}

inline void Splay_clear(int x){ls(x)=rs(x)=fa(x)=cnt(x)=sze(x)=val(x)=0;}

inline void Splay_rotate(int x){
    int y=fa(x),z=fa(y);
    int jud_x=(rs(y)==x),jud_y=(rs(z)==y);
    int w=T[x].son[jud_x^1];
    T[y].son[jud_x]=w;if(w)    fa(w)=y;
    if(z)T[z].son[jud_y]=x;fa(x)=z;
    T[x].son[jud_x^1]=y;fa(y)=x;
    Splay_update(y),Splay_update(x);
}

inline void Splay_splay(int id,int x,int goal=0){
    while(fa(x)!=goal){
        int y=fa(x),z=fa(y);
        if(z!=goal)(ls(z)==y)^(ls(y)==x)?Splay_rotate(x):Splay_rotate(y);
        Splay_rotate(x);
    }
    if(!goal)    root[id]=x;
}

inline int Splay_find(int id,int x){
    int u=root[id];
    while(x){
        if(val(u)==x){Splay_splay(id,u);return u;}
        u=T[u].son[x>val(u)];
    }    
    return 0;
}

inline void Splay_insert(int id,int x){
    int u=root[id];
    if(!root[id]){
        root[id]=u=++tot;
        ls(u)=rs(u)=0,val(u)=x,sze(u)=cnt(u)=1,fa(u)=0;
        return ;
    }
    int pre=0;
    while(true){
        if(val(u)==x){cnt(u)++;Splay_update(pre);break;}
        pre=u,u=T[u].son[x>val(u)];
        if(!u){
            u=++tot;
            T[pre].son[x>val(pre)]=u;
            ls(u)=rs(u)=0,val(u)=x,sze(u)=cnt(u)=1,fa(u)=pre;
            Splay_update(pre);break;
        }
    }
    Splay_splay(id,u);
}

inline int Splay_rank(int id,int k){
    //以线段树上的id节点为根的splay树上寻找权值k的排名 
    int x=root[id],sum=0;
    while(x){
        if(val(x)==k)    return sum+(ls(x) ? sze(ls(x)) : 0);
        else if(val(x)<k){
            sum+=(ls(x) ? sze(ls(x)) : 0)+cnt(x);
            x=rs(x);
        }
        else    x=ls(x);
    }
    return sum;
}

inline int Splay_Getpre(int id,int x){
    int u=root[id];
    while(u){
        if(val(u)<x){Max(ans,val(u));u=rs(u);}
        else u=ls(u);
    }
    return ans;
}

inline int Splay_Getsuf(int id,int x){
    int u=root[id];
    while(u){
        if(val(u)>x){Min(ans,val(u));u=ls(u);}
        else u=rs(u);
    }
    return ans;
}

inline int Splay_pre(int id){int x=ls(root[id]);while(rs(x))x=rs(x);return x;}//根节点的前驱,用于delete操作

inline void Splay_delete(int id,int x){
    int u=Splay_find(id,x);
    if(cnt(u)>1){cnt(u)--;Splay_update(u);return ;}
    if(!ls(u) && !rs(u)){Splay_clear(root[id]);root[id]=0;return ;}
    if(!ls(u)){root[id]=rs(u),fa(rs(u))=0;return ;}
    if(!rs(u)){root[id]=ls(u),fa(ls(u))=0;return ;}
    int Pre=Splay_pre(id),Father=root[id];
    Splay_splay(id,Pre,0);
    rs(root[id])=rs(Father);
    fa(rs(Father))=root[id];
    Splay_clear(Father),Splay_update(root[id]);
} 

//----------上面是关于splay的函数,下面是关于segment tree的函数 

#define lc     (id<<1)
#define rc     (id<<1|1)
#define mid    ((l+r)>>1)

inline void seg_insert(int id,int l,int r,int pos,int k){
    Splay_insert(id,k);
    if(l==r)    return ;
    if(pos<=mid)    seg_insert(lc,l,mid,pos,k);
    else    seg_insert(rc,mid+1,r,pos,k);
}

inline void seg_rank(int id,int l,int r,int L,int R,int k){
    //在整棵线段树上找到该splay的区间
    if(l==L&&r==R){ans+=Splay_rank(id,k);return ;}
    if(R<=mid)    seg_rank(lc,l,mid,L,R,k);
    else if(L>mid) seg_rank(rc,mid+1,r,L,R,k);
    else seg_rank(lc,l,mid,L,mid,k),seg_rank(rc,mid+1,r,mid+1,R,k);
}

inline void seg_modify(int id,int l,int r,int pos,int k){
    //单点修改权值,splay同时更新
    Splay_delete(id,w[pos]);Splay_insert(id,k);
    if(l==r){w[pos]=k;return ;}
    if(pos<=mid)    seg_modify(lc,l,mid,pos,k);
    else    seg_modify(rc,mid+1,r,pos,k); 
}

inline void seg_pre(int id,int l,int r,int L,int R,int k){
    //查询权值为k的前驱
    if(l==L&&r==R){Max(ans,Splay_Getpre(id,k));return ;}
    if(R<=mid) seg_pre(lc,l,mid,L,R,k);
    else if(L>mid)    seg_pre(rc,mid+1,r,L,R,k);
    else seg_pre(lc,l,mid,L,mid,k),seg_pre(rc,mid+1,r,mid+1,R,k);
}

inline void seg_suf(int id,int l,int r,int L,int R,int k){
    //查询权值为k的后驱
    if(l==L&&r==R){Min(ans,Splay_Getsuf(id,k));return ;}
    if(R<=mid) seg_suf(lc,l,mid,L,R,k);
    else if(L>mid)    seg_suf(rc,mid+1,r,L,R,k);
    else seg_suf(lc,l,mid,L,mid,k),seg_suf(rc,mid+1,r,mid+1,R,k);
}

//----------其他情况的函数

inline int search_value(int L,int R,int kth){
    //寻求L,R区间第k大的数,即区间第k大
    int l=0,r=MAX+1,m;
    while(l<r){
        m=(l+r)>>1;ans=0;
        seg_rank(1,1,n,L,R,m);
        if(ans<kth)    l=m+1;
        else r=m;
    }
    return l-1;
}

signed main(){
    n=read(),m=read();MAX=-inf;
    for(int i=1;i<=n;i++){
        w[i]=read();
        seg_insert(1,1,n,i,w[i]);
        Max(MAX,w[i]);
    }
    while(m--){
        int opt=read();
        switch(opt){
            case 1:{//区间查k值的排名 
                int L=read(),R=read(),k=read();ans=0;
                seg_rank(1,1,n,L,R,k);
                print(ans+1);
                break;
            }
            case 2:{//区间查第k大 
                int L=read(),R=read(),kth=read();
                print(search_value(L,R,kth));
                break;
            }
            case 3:{//单点修改 
                int pos=read(),k=read();
                seg_modify(1,1,n,pos,k);
                break;
            }
            case 4:{//k的前驱 
                int L=read(),R=read(),k=read();ans=-inf;
                seg_pre(1,1,n,L,R,k);
                print(ans);
                break;
            }
            case 5:{//k的后驱 
                int L=read(),R=read(),k=read();ans=inf;
                seg_suf(1,1,n,L,R,k);
                print(ans);
                break;
            }
        }
    }
    return 0;
}

特别鸣谢这篇题解及其作者

https://www.luogu.com.cn/blog/Qiu/qian-tan-shu-tao-shu-xian-duan-shu-tao-ping-heng-shu-post

<

发布回复

请输入评论!
请输入你的名字