読者です 読者をやめる 読者になる 読者になる

noyのブログ

プログラミングとかゲームとか

Binary Indexed Tree(BIT)を学ぶ(1)

BITとはなんぞや

参考:蟻本159頁

列a_1, a_2, ... , a_nがある。

  • iが与えられたとき、a_1からa_iまでの和を求める。
  • (iとjが与えられたとき、a_iからa_jまでの和を求める。)
  • iとxが与えられたとき、 a_i += x する。


要はある区間の和をO(log n)で求めることができるデータ構造。

POJ1990 MooFest

問題概要は省略。

解法

i番目の牛とj番目の牛の座標の差 * max(i番目の牛の聴力, j番目の牛の聴力)
全てのi、jの組み合わせについて上記の式を計算し、総和を求める。

・愚直解
二重ループで全てのi, jの組み合わせを計算する。O(n^2)。
n <= 20000 なのでTLE。

・BITを使った解法

  1. 聴力の値で入力をソートする。これでどちらの聴力が大きいかを考える必要がなくなる。
  2. i番目の牛に注目し、座標の差の合計を求める。
  3. ans += 聴力 * 座標の差の合計
  4. BITに値を加算する
  5. i++、 2に戻る。

聴力を降順にし順番に処理することで、maxを考えなくてよくなる。

聴力の合計は、以下のようにして求めている。
距離の差の総和 = 区間0~x−1にいる牛の数 * x - 区間0~x−1にいる牛の座標の総和 + 区間x~MAX_Xにいる牛の座標の総和 - 区間x~MAX_Xにいる牛の数 * x

コード

先駆者様のコードをほぼ写経。

<省略>
#define range(i,a,b) for(int i = (a); i < (b); i++)
#define rep(i,b) for(int i = 0; i < (b); i++)
#define all(a) (a).begin(), (a).end()
#define show(x)  cerr << #x << " = " << (x) << endl;
#define debug(x) cerr << #x << " = " << (x) << " (L" << __LINE__ << ")" << " " << __FILE__ << endl;
const int INF = 100000000;
using namespace std;

const int MAX_N = 20010;
const int MAX_X = 20010;

pair<int, int> pr[MAX_N];
int N;
long long dists[MAX_X], cnts[MAX_X]; //BIT

long long sum(int i, long long bit[MAX_X]){
    int s = 0;
    while(i > 0){
        s += bit[i];
        i -= i & -i;
    }
    return s;
}

long long sum(int first, int last, long long bit[MAX_X]){ //first-last間の和
    return sum(last, bit) - sum(first, bit);
}

void add(int i, int x, long long bit[MAX_X]){
    while(i <= MAX_X){
        bit[i] += x;
        i += i & -i;
    }
}


long long solve(){
    sort(pr, pr+ N); //1. 聴力の値でソート
    long long ans = 0;
    rep(i,N){
        int v = pr[i].first, x = pr[i].second;
        long long c1 = sum(0, x - 1, cnts), c2 = sum(x,MAX_X - 1, cnts);

        //2. 座標の差の総和を求める 3. ans += 聴力 * 座標の差の合計
        ans += v * ( (c1 * x - sum(0, x - 1, dists)) + (sum(x, MAX_X - 1, dists) - c2 * x) );
 
        //4. BITに値を加算する
        add(x, 1, cnts); 
        add(x, x, dists);
    }
    return ans;
}

int main(){
    cin >> N;
    rep(i,N){
        int x, v;
        cin >> v >> x;
        pr[i] = make_pair(v, x);
    }
    printf("%lld\n", solve());
    return 0;
}

long long sum(int i, long long bit[MAX_X]){
long long sum(int first, int last, long long bit[MAX_X]){
void add(int i, int x, long long bit[MAX_X]){
上記の関数はBITの実装です。