mini notes

競技プログラミングの解法メモを残していきます。

ABC117 D - XXOR

D - XXOR

概要

N個の非負整数列A と非負整数Kが与えられる。
X <= K なる整数Xに対し、 f(X) = (X xor A[0]) + (X xor A[1]) + ... + (X xor A[N-1]) という関数を定める。
f(X)の最大値を求めよ。

制約

1 <= X <= 10^5
0 <= K <= 10^12
0 <= A[i] <= 10^12

共通方針

X, K, Aを2進数で考え、各桁(ビット)ごとに見ていく。
全てのA[j]のi桁目を確認し、0が多ければXのi桁目は1となるのがよく、逆に1が多ければXのi桁目は0となるのが良い。

「XとKのビットを上から見ていくとき、初めて両者のビットが異なる桁」をi桁目とすると、XはK以下であるため下記のように構成される。
①i桁目より上位のビットでは、XとKのビットは完全に一致しなければならない
②i桁目ではXのビットが0、Kのビットが1である
③i桁目より下位のビットでは、XのビットはKのビットにかかわらず何でもよい。

下記を参考にしました。
drken1215.hatenablog.com

方針①

「XとKのビットを上から見ていくとき、初めて両者のビットが異なる桁」で全探索する。

解答①

Submission #4187639 - AtCoder Beginner Contest 117

#include <bits/stdc++.h>
 
#define FOR(i,a,b) for(int i=(a);i<(b);i++)
#define rep(i,n) FOR(i,0,n)
#define RFOR(i,a,b) for(int i=(a)-1;i>=(b);i--)
#define rrep(i,n) RFOR(i,n,0)
 
using namespace std;
 
typedef long long ll;
typedef unsigned long long ull;
 
int main()
{
	cin.tie(0);
	ios::sync_with_stdio(false);
 
	ll N, K;
	cin >> N >> K;
 
	ll A[N];
	rep(i, N) cin >> A[i];
 
	// 各ビットごとの1の個数を調べる
	ll b[61] = {};
	for(int i = 60; i >= 0; i--){
		ll y = 1LL << i;
		int cnt = 0;
		rep(i, N){
			if(A[i] & y) cnt++;
		}
		b[i] = cnt;
	}
 
	// for(int i = 60; i >= 0; i--) cout << b[i];
	// cout << endl;
 
	// Xを探す
	// XとKとで初めてビットが異なる箇所で場合分け
	// X <= K なので、初めてビットが異なればその箇所ではXのビットは0、Kのビットは1
	// それ以降は何でもよいので貪欲
	ll ans = 0;
	for(int i = 60; i >= -1; i--){
		ll t = 0;
		if(i != -1 && !(K & (1LL << i))) continue;
		for(int j = 60; j >= 0; j--){
			ll y = 1LL << j;
			if(j > i){
				if(K & y) t += (N - b[j]) * y;
				else t += b[j] * y;
			}else if(j == i){
				t += b[j] * y; // (N - b[j]) * y;
			}else{
				t += max(b[j] * y, (N - b[j]) * y);
			}
		}
		// cout << i << " " << t << endl;
		ans = max(ans, t);
	}
 
	cout << ans << endl;
}

方針②

桁DPする。
dp[i][j] : 桁DP。上からi桁目までのところを使ったときの最大値。j = 1のとき、XはKより真に小さい。
桁DPの考え方は下記を参考にしました。
drken1215.hatenablog.com

解答②

Submission #4187678 - AtCoder Beginner Contest 117

#include <bits/stdc++.h>
 
#define FOR(i,a,b) for(int i=(a);i<(b);i++)
#define rep(i,n) FOR(i,0,n)
#define RFOR(i,a,b) for(int i=(a)-1;i>=(b);i--)
#define rrep(i,n) RFOR(i,n,0)
 
using namespace std;
 
typedef long long ll;
typedef unsigned long long ull;
 
int main()
{
	cin.tie(0);
	ios::sync_with_stdio(false);
 
	ll N, K;
	cin >> N >> K;
 
	ll A[N];
	rep(i, N) cin >> A[i];
 
	// 各ビットごとの1の個数を調べる
	ll b[61] = {};
	for(int i = 60; i >= 0; i--){
		ll y = 1LL << i;
		int cnt = 0;
		rep(i, N){
			if(A[i] & y) cnt++;
		}
		b[i] = cnt;
	}
 
	int MAX_DIGITS = 50;
	ll dp[MAX_DIGITS+1][2];
	// dp[i][j] : 桁DP。上からi桁目までのところを使ったときの最大値
	//            j = 1のとき、XはKより真に小さい
	memset(dp, -1, sizeof(dp));
	dp[0][0] = 0;
 
	for(int i = 0; i < MAX_DIGITS; i++){
		ll y = 1LL << (MAX_DIGITS - i - 1);
		ll point0 = b[MAX_DIGITS - i - 1] * y;
		ll point1 = (N - b[MAX_DIGITS - i - 1]) * y;
 
		if(dp[i][1] != -1){
			// 1 -> 1 : なんでもよい
			// cout << "1 " << i << " " << dp[i][1] << " " << (dp[i][1] + max(point0, point1)) <<  endl;
			dp[i+1][1] = max(dp[i+1][1], dp[i][1] + max(point0, point1));
		}
 
		if(dp[i][0] != -1){
			// 0 -> 1 : Kのi桁目は1, Xのi桁目は0
			// if(K & y) cout << "2 " << i << " " << dp[i][0] << " " << (dp[i][0] + point0) <<  endl;
			if(K & y) dp[i+1][1] = max(dp[i+1][1], dp[i][0] + point0);
		}
 
		if(dp[i][0] != -1){
			// 0 -> 0 : Kそのまま
			// cout << "3 " << i << " " << dp[i][0] << " " << (dp[i][0] + ((K & y) ? point1 : point0)) <<  endl;
			dp[i+1][0] = max(dp[i+1][0], dp[i][0] + ((K & y) ? point1 : point0));
		}
	}
 
	cout << max(dp[MAX_DIGITS][0], dp[MAX_DIGITS][1]) << endl;
}