ABC 107 / ARC 101 D Median of Medianを解いた。

ABC 107 / ARC 101 D Median of Median
D - Median of Medians
python(pypy)で解いた。

解説はこれ
https://img.atcoder.jp/arc101/editorial.pdf

ほとんど解説どおりだけど、勉強のため自分の言葉で解説する。

解説

以下、\lceil x \rceil により、x 以上の最小の整数を表す。

x を長さ M の整数列 b の中央値とする。
すると、次の性質を満たす。

  • x 以上の b の元は \lceil \frac{M}{2} \rceil 以上
  • x はこの性質を満たすものの中で最大

つまり、x = \max \{ y \in \mathbb{Z} | \sharp \{i | b_i \geq y\} \geq \lceil \frac{M}{2} \rceil\}となる。

ここで、\sharp \{i | b_i \geq y\} y に関して、広義単調減少である。

従って、x は2分探索で求めることができる。

与えられた長さ N の整数列を A とする。

y を2分探索における基準値とする。(L+R)//2みたいなやつ。

 0 \leq l < r \leq Nに対して、 m_{l,r}A[l,r]の中央値とする。

すると、上記の性質から、y 以上の要素の個数が\lceil \frac{N(N-1)}{2\cdot 2} \rceil以上であるかにより、

2分探索すればよい。

A[l,r]のうち、x 以上の元を\lceil \frac{r-l}{2\cdot 2} \rceil個以上もつものが\lceil \frac{N(N-1)}{2\cdot 2} \rceil以上かを調べれば良い。

ここで、A の元のうち x 以上のものを 1 に、x より小さいものを -1 に置き換えれば、

A[l,r] のうち、\sharp (\text{1 の個数}) \geq \sharp (\text{-1 の個数})となるものが\lceil \frac{N(N-1)}{2\cdot 2} \rceil以上かを調べれば良い。

個数の関係は和が 0 以上であることと同値なので、

A[l,r]のうち、sum(A[l:r])\geq 0となるものが\lceil \frac{N(N-1)}{2\cdot 2} \rceil以上かを調べれば良い。

 0 \leq i \leq Nに対して、S_i = \sum_{j = 1}^i A[i]とする。
ただし、S[0]=0とする。

以上をまとめると、

 \sharp \{(l,r) \in \{0, 1, \cdots, N\}^2 |\  l < r \text{ and } S_l \leq S_r\} \geq \lceil \frac{N(N-1)}{2\cdot 2} \rceilを調べれば良い。

これは、転倒数という名前らしい。

今回はBITを使って求めた。
次の文献を参考にした。

BITはこれ
http://hos.ac/slides/20140319_bit.pdf
Binary indexed tree(BIT)をやった - kamojirobrothersのブログ


転倒数BITはこれ
BITで転倒数を求める - Qiita

プログラムはこれ。

def add(B,a,n):#リストに値を追加する関数
    x = a
    while x <= n:
        B[x] += 1
        x += x&(-x)

def sums(B,a):#a番目までの和
    x = a
    S = 0
    while x != 0:
        S += B[x]
        x -= x&(-x)
    return S


def invnumber(n, S):# #{(i,j)| i<j and S[i]<=S[j]}
    B = [0]*(n*2 + 1)
    invs = 0
    for i in range(n):
        s = S[i] + n #BITで扱えるようにするために、nを加算した
        invs += sums(B, s) #i<j
        add(B, s, n*2)
    return invs

N = int( input())
A = list( map( int, input().split()))
R = max(A)+1
L = 0
c = (N*(N+1)//2 + 1)//2
while R - L > 1:
    M = (R+L)//2
    S = [0]*(N+1)
    for i in range(1,N+1):
        if A[i-1] >= M:
            S[i] = S[i-1] + 1
        else:
            S[i] = S[i-1] - 1
    if invnumber(N+1,S) >= c:
        L = M
    else:
        R = M
print(L)