链接: http://codeforces.com/contest/935/problem/E
题意: 现在给你一个字符串保证里边所有的数字都在0到9 范围内,并且每一个?对应着一对括号,一共有n+m 个?,并且其中的n为+ m个为 - ,你要合理的安排 + - 使得式子的值最大。
思路: 每个问号对应着一对括号,那么我们就可以把整个字符串看成一棵二叉树,每个叶子节点为 0 到 9 每个非叶子节点为 ? ,那么就转化成一个 树形dp 的问题了。dp[ i ] [ j ][2] 分别表示节点 i 的加号或者减号为j 的最小最大值。因为这里的+ - 只保证一个<= 100,一开始就给,,。。。 搞错了。
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int > pii;
const int inf =10000000;
struct node
{
int ls,rs;
}tr[10005];
char s[10005];
int id[10005];
ll dp[10005][105][2];
map<pii,int >mp;
stack<int >st;
stack<int >ss;
int n,m;
int len;
int tot;
int op;
void init()
{
tot=0;
for(int i=1;i<=len;i++)
{
if(s[i]>='0'&&s[i]<='9') continue;
else if(s[i]=='?'){
id[i]=++tot;
ss.push(i);
}
else if(s[i]=='(')
{
st.push(i);
}
else if(s[i]==')'){
int l=st.top();
st.pop();
mp[pii(l,i)]=ss.top();
ss.pop();
}
}
int cnt=0;
for(int i=1;i<=len;i++){
if(s[i]==')'||s[i]=='(') continue;
cnt++;
}
for(int i=0;i<=cnt+2;i++) tr[i].ls=tr[i].rs=-1;
for(int i=0;i<=cnt+3;i++){
for(int j=0;j<=101;j++){
dp[i][j][0]=inf;
dp[i][j][1]=-inf;
}
}
}
int tree(int l,int r)
{
if(l==r){
++tot;
dp[tot][0][0]=s[l]-'0';
dp[tot][0][1]=s[l]-'0';
return tot;
}
int pos=mp[pii(l,r)];
int idd=id[pos];
int ls=tree(l+1,pos-1);
int rs=tree(pos+1,r-1);
tr[idd].ls=ls;
tr[idd].rs=rs;
return idd;
}
void dfs(int id)
{
if(tr[id].ls==-1) return ;
int ls=tr[id].ls;
int rs=tr[id].rs;
dfs(ls); dfs(rs);
for(int i=0;i<=100;i++){
for(int j=0;j<=100;j++){
if(i+j>100) break;
if(op==0)
{
dp[id][i+j][0]=min(dp[id][i+j][0],dp[ls][i][0]-dp[rs][j][1]);
dp[id][i+j][1]=max(dp[id][i+j][1],dp[ls][i][1]-dp[rs][j][0]);
if(i+j+1>n) continue;
dp[id][i+j+1][0]=min(dp[id][i+j+1][0],dp[ls][i][0]+dp[rs][j][0]);
dp[id][i+j+1][1]=max(dp[id][i+j+1][1],dp[ls][i][1]+dp[rs][j][1]);
}
else{
dp[id][i+j][0]=min(dp[id][i+j][0],dp[ls][i][0]+dp[rs][j][0]);
dp[id][i+j][1]=max(dp[id][i+j][1],dp[ls][i][1]+dp[rs][j][1]);
if(i+j+1>n) continue;
dp[id][i+j+1][0]=min(dp[id][i+j+1][0],dp[ls][i][0]-dp[rs][j][1]);
dp[id][i+j+1][1]=max(dp[id][i+j+1][1],dp[ls][i][1]-dp[rs][j][0]);
}
}
}
}
int main()
{
scanf("%s",s+1);
scanf("%d %d",&n,&m);
if(n>100) op=1;
len=strlen(s+1);
init();
int rt=tree(1,len);
dfs(rt);
ll ans;
if(op==0) ans=dp[rt][n][1];
else ans=dp[rt][m][1];
printf("%lld\n",ans);
return 0;
}
/*
((1?(5?7))?((6?2)?7))
2 3
*/