主要算法:BFS或DP
随便试试,可以发现,是哪些数字所组成的xor-sum其实无所谓,主要是组成xor-sum的个数是关键,所以可以用集合的思想
两个xor-sum集合
S
S
S 和
T
T
T 中元素做XOR
运算=
S
S
S 和
T
T
T 的对称差=
S
Δ
T
S\Delta T
SΔT =
(
S
\
T
)
∪
(
T
\
S
)
=
(
S
∪
T
)
\
(
S
∪
T
)
(S\backslash T)\cup(T\backslash S)=(S\cup T)\backslash (S\cup T)
(S\T)∪(T\S)=(S∪T)\(S∪T)
![](https://i-blog.csdnimg.cn/blog_migrate/1aa1ed34a1ad83bc831078b5fbbe0fc3.png)
考虑用f[i]
表示获得大小为i
个元素的xor-sum集合
S
i
S_i
Si ,所需要的最少次数,初始化f[k]=0
转移:令一个新的集合
T
T
T 为当前我们要做的查询集合,则
T
T
T 的大小一定为
k
k
k,集合
T
∩
S
i
T\cap Si
T∩Si 中的元素个数记为x
,则集合
T
T
T 和
S
i
S_i
Si 的对称差
T
Δ
S
i
T\Delta S_i
TΔSi 中的元素个数即为i+k-2*x
,故
f
[
i
+
k
−
2
x
]
=
f
[
i
]
+
1
f[i+k-2x]=f[i]+1
f[i+k−2x]=f[i]+1
本题还要求f[n]
的最小值,所以可以使用BFS寻找最优解,也可dp寻解
由于每个f[i]
只会取最小值,故这里可以使用BFS可以做到直接从最小值延拓,直接更新新的节点,复杂度
O
(
N
2
)
O(N^2)
O(N2),而使用dp则是
O
(
N
3
)
O(N^3)
O(N3)
剩下就是记录每次的 f[i]
是从谁延拓过来的,最后再从 f[n]
倒退回 f[k]
即可,可以用两个set表示现在已有的集合和重复集合,这里方法很多样
每次延拓时要注意:
1.i+k-x<=n
(
T
T
T 和
S
i
S_i
Si 并中的元素个数不能超出n)
2.x<=i
(重复元素个数不能大于原有集合的大小,也就是
T
∩
S
i
⊆
S
i
T\cap S_i\subseteq S_i
T∩Si⊆Si )
#include <bits/stdc++.h>
#define DB double
#define LL long long
//#define int LL
#define Case(t) cout << "Case #" << t << ": "
using namespace std;
const int N = 5e2 + 10;
const int INF = 0x3f3f3f3f;
int n, k;
int prt[N], cover[N], dis[N];
queue<int> q;
set<int> st, pt;
void init() {
cin >> n >> k;
}
int ans = 0;
void ask() {
cout << "? ";
for (auto i : pt) {
cout << i << ' ';
}
cout << '\n';
cout.flush();
int x;
cin >> x;
ans ^= x;
}
void solve() {
memset(dis, 0x3f, sizeof(dis));
dis[k] = 0;
q.push(k);
while (!q.empty()) {
int t = q.front(); q.pop();
for (int i = 0; i < k; i++) {
if (t+k-2*i > 0 && t+k-i <= n && i <= t) {
int j = t+k-2*i;
if (dis[j] > dis[t] + 1) {
dis[j] = dis[t] + 1;
prt[j] = t;
cover[j] = i;
q.push(j);
}
}
}
}
if (dis[n] == INF) {
cout << "-1" << '\n';
return;
}
for (int i = 1; i <= n; i++) st.insert(i);
for (int i = n; i; i = prt[i]) {
int x = cover[i];
int nw = k - x;
pt.clear();
for (set<int>::iterator it = st.begin(); it != st.end(); ) {
pt.insert(*it);
st.erase(it++);
if (pt.size() == nw) break;
}
for (int j = 1; j <= n; j++) {
if (pt.size() == k) break;
if (st.count(j) == 0 && pt.count(j) == 0) {
st.insert(j);
pt.insert(j);
}
}
ask();
}
cout << "! " << ans << '\n';
cout.flush();
}
signed main(){
init();
solve();
return 0;
}