ABC117 D - XXOR
概要
N個の非負整数列A と非負整数Kが与えられる。
X <= K なる整数Xに対し、 f(X) = (X xor A[0]) + (X xor A[1]) + ... + (X xor A[N-1]) という関数を定める。
f(X)の最大値を求めよ。
制約
1 <= X <= 10^5
0 <= K <= 10^12
0 <= A[i] <= 10^12
共通方針
X, K, Aを2進数で考え、各桁(ビット)ごとに見ていく。
全てのA[j]のi桁目を確認し、0が多ければXのi桁目は1となるのがよく、逆に1が多ければXのi桁目は0となるのが良い。
「XとKのビットを上から見ていくとき、初めて両者のビットが異なる桁」をi桁目とすると、XはK以下であるため下記のように構成される。
①i桁目より上位のビットでは、XとKのビットは完全に一致しなければならない
②i桁目ではXのビットが0、Kのビットが1である
③i桁目より下位のビットでは、XのビットはKのビットにかかわらず何でもよい。
下記を参考にしました。
drken1215.hatenablog.com
方針①
「XとKのビットを上から見ていくとき、初めて両者のビットが異なる桁」で全探索する。
解答①
Submission #4187639 - AtCoder Beginner Contest 117
#include <bits/stdc++.h> #define FOR(i,a,b) for(int i=(a);i<(b);i++) #define rep(i,n) FOR(i,0,n) #define RFOR(i,a,b) for(int i=(a)-1;i>=(b);i--) #define rrep(i,n) RFOR(i,n,0) using namespace std; typedef long long ll; typedef unsigned long long ull; int main() { cin.tie(0); ios::sync_with_stdio(false); ll N, K; cin >> N >> K; ll A[N]; rep(i, N) cin >> A[i]; // 各ビットごとの1の個数を調べる ll b[61] = {}; for(int i = 60; i >= 0; i--){ ll y = 1LL << i; int cnt = 0; rep(i, N){ if(A[i] & y) cnt++; } b[i] = cnt; } // for(int i = 60; i >= 0; i--) cout << b[i]; // cout << endl; // Xを探す // XとKとで初めてビットが異なる箇所で場合分け // X <= K なので、初めてビットが異なればその箇所ではXのビットは0、Kのビットは1 // それ以降は何でもよいので貪欲 ll ans = 0; for(int i = 60; i >= -1; i--){ ll t = 0; if(i != -1 && !(K & (1LL << i))) continue; for(int j = 60; j >= 0; j--){ ll y = 1LL << j; if(j > i){ if(K & y) t += (N - b[j]) * y; else t += b[j] * y; }else if(j == i){ t += b[j] * y; // (N - b[j]) * y; }else{ t += max(b[j] * y, (N - b[j]) * y); } } // cout << i << " " << t << endl; ans = max(ans, t); } cout << ans << endl; }
方針②
桁DPする。
dp[i][j] : 桁DP。上からi桁目までのところを使ったときの最大値。j = 1のとき、XはKより真に小さい。
桁DPの考え方は下記を参考にしました。
drken1215.hatenablog.com
解答②
Submission #4187678 - AtCoder Beginner Contest 117
#include <bits/stdc++.h> #define FOR(i,a,b) for(int i=(a);i<(b);i++) #define rep(i,n) FOR(i,0,n) #define RFOR(i,a,b) for(int i=(a)-1;i>=(b);i--) #define rrep(i,n) RFOR(i,n,0) using namespace std; typedef long long ll; typedef unsigned long long ull; int main() { cin.tie(0); ios::sync_with_stdio(false); ll N, K; cin >> N >> K; ll A[N]; rep(i, N) cin >> A[i]; // 各ビットごとの1の個数を調べる ll b[61] = {}; for(int i = 60; i >= 0; i--){ ll y = 1LL << i; int cnt = 0; rep(i, N){ if(A[i] & y) cnt++; } b[i] = cnt; } int MAX_DIGITS = 50; ll dp[MAX_DIGITS+1][2]; // dp[i][j] : 桁DP。上からi桁目までのところを使ったときの最大値 // j = 1のとき、XはKより真に小さい memset(dp, -1, sizeof(dp)); dp[0][0] = 0; for(int i = 0; i < MAX_DIGITS; i++){ ll y = 1LL << (MAX_DIGITS - i - 1); ll point0 = b[MAX_DIGITS - i - 1] * y; ll point1 = (N - b[MAX_DIGITS - i - 1]) * y; if(dp[i][1] != -1){ // 1 -> 1 : なんでもよい // cout << "1 " << i << " " << dp[i][1] << " " << (dp[i][1] + max(point0, point1)) << endl; dp[i+1][1] = max(dp[i+1][1], dp[i][1] + max(point0, point1)); } if(dp[i][0] != -1){ // 0 -> 1 : Kのi桁目は1, Xのi桁目は0 // if(K & y) cout << "2 " << i << " " << dp[i][0] << " " << (dp[i][0] + point0) << endl; if(K & y) dp[i+1][1] = max(dp[i+1][1], dp[i][0] + point0); } if(dp[i][0] != -1){ // 0 -> 0 : Kそのまま // cout << "3 " << i << " " << dp[i][0] << " " << (dp[i][0] + ((K & y) ? point1 : point0)) << endl; dp[i+1][0] = max(dp[i+1][0], dp[i][0] + ((K & y) ? point1 : point0)); } } cout << max(dp[MAX_DIGITS][0], dp[MAX_DIGITS][1]) << endl; }