樹狀數組套權值線段樹

靜態區間 k 小值(POJ 2104 K-th Number) 的問題可以用 權值線段樹\(O(n\log n)\) 的時間複雜度內解決。

如果區間變成動態的呢?即,如果還要求支持一種操作:單點修改某一位上的值,又該怎麼辦呢?

例題 二逼平衡樹(樹套樹)
例題 ZOJ 2112 Dynamic Rankings

如果用 線段樹套平衡樹 中所論述的,用線段樹套平衡樹,即對於線段樹的每一個節點,對於其所表示的區間維護一個平衡樹,然後用二分來查找 \(k\) 小值。由於每次查詢操作都要覆蓋多個區間,即有多個節點,但是平衡樹並不能多個值一起查找,所以時間複雜度是 \(O(n\log^3 n)\),並不是最優的。

優化的思路是把二分答案的操作和查詢小於一個值的數的數量兩種操作結合起來,使用 線段樹套動態開點權值線段樹,由於所有線段樹的結構是相同的,可以在多棵樹上同時進行線段樹上二分。

在修改操作進行時,先在線段樹上從上往下跳到被修改的點,刪除所經過的點所指向的動態開點權值線段樹上的原來的值,然後插入新的值,要經過 \(O(\log n)\) 個線段樹上的節點,在動態開點權值線段樹上一次修改操作是 \(O(\log n)\) 的,所以修改操作的時間複雜度為 \(O(\log^2 n)\)

在查詢答案時,先取出該區間覆蓋在線段樹上的所有點,然後用類似於靜態區間 \(k\) 小值的方法,將這些點一起向左兒子或向右兒子跳。如果所有這些點左兒子存儲的值大於等於 \(k\),則往左跳,否則往右跳。由於最多隻能覆蓋 \(O(\log n)\) 個節點,所以最多一次只有這麼多個節點向下跳,時間複雜度為 \(O(\log^2 n)\)

由於線段樹的常數較大,在實現中往往使用常數更小且更方便處理前綴和的 樹狀數組 實現。另外空間複雜度是 \(O(n\log^2 n)\) 的,使用時 注意空間限制

給出一種代碼實現:

實現
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <map>
#include <set>
#define LC o << 1
#define RC o << 1 | 1
using namespace std;
const int maxn = 1000010;
int n, m, a[maxn], u[maxn], x[maxn], l[maxn], r[maxn], k[maxn], cur, cur1, cur2,
    q1[maxn], q2[maxn], v[maxn];
char op[maxn];
set<int> ST;
map<int, int> mp;

struct segment_tree  // 封裝的動態開點權值線段樹
{
  int cur, rt[maxn * 4], sum[maxn * 60], lc[maxn * 60], rc[maxn * 60];

  void build(int& o) { o = ++cur; }

  void print(int o, int l, int r) {
    if (!o) return;
    if (l == r && sum[o]) printf("%d ", l);
    int mid = (l + r) >> 1;
    print(lc[o], l, mid);
    print(rc[o], mid + 1, r);
  }

  void update(int& o, int l, int r, int x, int v) {
    if (!o) o = ++cur;
    sum[o] += v;
    if (l == r) return;
    int mid = (l + r) >> 1;
    if (x <= mid)
      update(lc[o], l, mid, x, v);
    else
      update(rc[o], mid + 1, r, x, v);
  }
} st;

// 樹狀數組實現
int lowbit(int o) { return (o & (-o)); }

void upd(int o, int x, int v) {
  for (; o <= n; o += lowbit(o)) st.update(st.rt[o], 1, n, x, v);
}

void gtv(int o, int* A, int& p) {
  p = 0;
  for (; o; o -= lowbit(o)) A[++p] = st.rt[o];
}

int qry(int l, int r, int k) {
  if (l == r) return l;
  int mid = (l + r) >> 1, siz = 0;
  for (int i = 1; i <= cur1; i++) siz += st.sum[st.lc[q1[i]]];
  for (int i = 1; i <= cur2; i++) siz -= st.sum[st.lc[q2[i]]];
  // printf("j %d %d %d %d\n",cur1,cur2,siz,k);
  if (siz >= k) {
    for (int i = 1; i <= cur1; i++) q1[i] = st.lc[q1[i]];
    for (int i = 1; i <= cur2; i++) q2[i] = st.lc[q2[i]];
    return qry(l, mid, k);
  } else {
    for (int i = 1; i <= cur1; i++) q1[i] = st.rc[q1[i]];
    for (int i = 1; i <= cur2; i++) q2[i] = st.rc[q2[i]];
    return qry(mid + 1, r, k - siz);
  }
}

/* 線段樹實現
void build(int o,int l,int r)
{
    st.build(st.rt[o]);
    if(l==r)return;
    int mid=(l+r)>>1;
    build(LC,l,mid);
    build(RC,mid+1,r);
}
void print(int o,int l,int r)
{
    printf("%d %d:",l,r);
    st.print(st.rt[o],1,n);
    printf("\n");
    if(l==r)return;
    int mid=(l+r)>>1;
    print(LC,l,mid);
    print(RC,mid+1,r);
}
void update(int o,int l,int r,int q,int x,int v)
{
    st.update(st.rt[o],1,n,x,v);
    if(l==r)return;
    int mid=(l+r)>>1;
    if(q<=mid)update(LC,l,mid,q,x,v);
    else update(RC,mid+1,r,q,x,v);
}
void getval(int o,int l,int r,int ql,int qr)
{
    if(l>qr||r<ql)return;
    if(ql<=l&&r<=qr){q[++cur]=st.rt[o];return;}
    int mid=(l+r)>>1;
    getval(LC,l,mid,ql,qr);
    getval(RC,mid+1,r,ql,qr);
}
int query(int l,int r,int k)
{
    if(l==r)return l;
    int mid=(l+r)>>1,siz=0;
    for(int i=1;i<=cur;i++)siz+=st.sum[st.lc[q[i]]];
    if(siz>=k)
    {
        for(int i=1;i<=cur;i++)q[i]=st.lc[q[i]];
        return query(l,mid,k);
    }
    else
    {
        for(int i=1;i<=cur;i++)q[i]=st.rc[q[i]];
        return query(mid+1,r,k-siz);
    }
}
*/

int main() {
  scanf("%d%d", &n, &m);
  for (int i = 1; i <= n; i++) scanf("%d", a + i), ST.insert(a[i]);
  for (int i = 1; i <= m; i++) {
    scanf(" %c", op + i);
    if (op[i] == 'C')
      scanf("%d%d", u + i, x + i), ST.insert(x[i]);
    else
      scanf("%d%d%d", l + i, r + i, k + i);
  }
  for (set<int>::iterator it = ST.begin(); it != ST.end(); it++)
    mp[*it] = ++cur, v[cur] = *it;
  for (int i = 1; i <= n; i++) a[i] = mp[a[i]];
  for (int i = 1; i <= m; i++)
    if (op[i] == 'C') x[i] = mp[x[i]];
  n += m;
  // build(1,1,n);
  for (int i = 1; i <= n; i++) upd(i, a[i], 1);
  // print(1,1,n);
  for (int i = 1; i <= m; i++) {
    if (op[i] == 'C') {
      upd(u[i], a[u[i]], -1);
      upd(u[i], x[i], 1);
      a[u[i]] = x[i];
    } else {
      gtv(r[i], q1, cur1);
      gtv(l[i] - 1, q2, cur2);
      printf("%d\n", v[qry(1, n, k[i])]);
    }
  }
  return 0;
}