#include<bits/stdc++.h>
using namespace std;
constexpr int maxn=100010;  // 1e5 + 10
using ll=long long;
// Adjacency list stored as linked edge chains: h[x] is the index of the most
// recently added edge leaving x (-1 = none, main resets it per test), e[k] is the
// target of edge k, ne[k] the next edge with the same source, idx the number of
// edges added so far. Each undirected edge occupies two slots.
int h[maxn],e[maxn<<1],ne[maxn<<1],idx;
// Pushes the directed edge x -> y to the front of x's edge chain.
void add(int x,int y)
{
    e[idx]=y;
    ne[idx]=h[x];
    h[x]=idx;
    ++idx;
}
// Per-vertex data for the current test case: a[] and t[] are read in main,
// f[] and g[] are the DP values filled in by dfs (main zeroes them per test).
ll a[maxn],f[maxn],g[maxn];
int t[maxn];
// Post-order tree DP over the subtree of u; fa is u's parent (0 for the root).
// On return:
//   g[u] = a[u] + sum over children v of (f[v] - a[v]);
//   f[u] = the maximum of
//          - g[u],
//          - g[u] + a[v] for every child v,
//          - g[u] + a[v] + max over the OTHER children w of (a[w] + g[w] - f[w]),
//            for every child v with t[v] == 3 (only effective if u has another child).
// NOTE(review): the problem-level meaning of f, g and of type 3 is not visible here.
// Recursion depth equals the height of the tree (up to n on a path).
void dfs(int u,int fa)
{
g[u]=a[u];
// mx1 >= mx2: the two largest values of a[v] + g[v] - f[v] over the children.
// -1e16-9 is a sentinel for "no such child"; it makes any candidate using it hugely negative.
ll mx1=-1e16-9,mx2=-1e16-9;
for(int i=h[u];~i;i=ne[i]) // ~i is non-zero until i == -1, the end of u's edge chain
{
int v=e[i];
if(v==fa)continue;
dfs(v,u);
g[u]+=f[v]-a[v];
ll tmp=a[v]+g[v]-f[v];
// '>=' pushes an equal value into mx2, so ties between two children keep both.
if(tmp>=mx1)mx2=mx1,mx1=tmp;
else if(tmp>mx2)mx2=tmp;
}
f[u]=g[u];
for(int i=h[u];~i;i=ne[i])
{
int v=e[i];
if(v==fa)continue;
f[u]=max(f[u],g[u]+a[v]);
if(t[v]==3)
{
// Pair v with the best OTHER child: if v itself holds the maximum, use the runner-up.
if(g[v]-f[v]+a[v]==mx1)f[u]=max(f[u],g[u]+mx2+a[v]);
else f[u]=max(f[u],g[u]+mx1+a[v]);
}
}
}
// Reads T test cases. Each case gives n, the values a[1..n], the types t[1..n]
// and n-1 undirected edges; prints f[1] after running the DP from vertex 1.
signed main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int testCount;
    cin >> testCount;
    while (testCount--)
    {
        int n;
        cin >> n;

        // Reset adjacency heads and the DP arrays for this test case.
        idx = 0;
        for (int v = 1; v <= n; ++v)
        {
            h[v] = -1;
        }
        for (int v = 1; v <= n; ++v)
        {
            cin >> a[v];
            f[v] = 0;
            g[v] = 0;
        }
        for (int v = 1; v <= n; ++v)
        {
            cin >> t[v];
        }
        for (int edge = 1; edge < n; ++edge)
        {
            int x, y;
            cin >> x >> y;
            add(x, y);
            add(y, x);
        }

        dfs(1, 0);
        cout << f[1] << '\n';
    }
}
#include<bits/stdc++.h>
using namespace std;
constexpr int maxn=200010;  // 2e5 + 10
using ll=long long;
using pll=pair<ll,ll>;
int w[maxn];  // per-vertex value read in main (used as the height/target by dfs)
// Adjacency list as linked edge chains: h[a] = newest edge leaving a (-1 = none,
// set by main), e[k] = target of edge k, ne[k] = next edge from the same source.
int e[maxn<<1],ne[maxn<<1],h[maxn],idx;
// Pushes the directed edge a -> b to the front of a's edge chain.
void add(int a,int b)
{
    e[idx]=b;
    ne[idx]=h[a];
    h[a]=idx;
    ++idx;
}
int n;
ll res;
// Post-order pass over the subtree of u; fa is u's parent and fa == 0 marks the root.
// Let best >= second be the two largest values returned by u's children (0 if absent).
//  - non-root u: if w[u] > best, adds w[u] - best to res and returns w[u];
//                otherwise returns best.
//  - root u:     adds (w[u] - best) + (w[u] - second) to res and returns best.
// NOTE(review): main roots the tree at a vertex of maximum w, which keeps the root
// contributions non-negative.
int dfs(int u,int fa)
{
    int best=0,second=0;
    for(int edge=h[u];edge!=-1;edge=ne[edge])
    {
        const int child=e[edge];
        if(child==fa) continue;
        const int got=dfs(child,u);
        if(got>=best)
        {
            second=best;
            best=got;
        }
        else if(got>second)
        {
            second=got;
        }
    }
    if(fa==0)
    {
        res+=(w[u]-best)+(w[u]-second);
        return best;
    }
    if(w[u]>best)
    {
        res+=w[u]-best;
        best=w[u];
    }
    return best;
}
// Reads n, the values w[1..n] and the n-1 undirected edges, roots the tree at a
// vertex of maximum w (the first one on ties) and prints the total accumulated
// in res by dfs.
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)scanf("%d",&w[i]);
    memset(h,-1,sizeof(h));
    for(int i=1;i<n;i++)
    {
        int a,b;scanf("%d%d",&a,&b);
        add(a,b);add(b,a);
    }
    // Start the search from vertex 1 instead of using w[0] == 0 as a sentinel:
    // that only worked when some w[i] was positive, otherwise root stayed 0 and
    // dfs ran from a vertex that does not exist.
    int root=1;
    for(int i=2;i<=n;i++)if(w[i]>w[root])root=i;
    dfs(root,0);
    printf("%lld\n",res);
}
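dp[u][sum] += dp[u][i] * dp[v][j]   (v is the child being merged into u; the old dp[u][sum] is kept, which covers "v's subtree not taken")
where i + j = sum.
Because the sum can be negative, every index is shifted by col_siz[id]; subtract that offset again when reading off the answer.
siz is recomputed in every dfs so the merge loops can be bounded by min(siz, col_siz[id]), which lowers the running time. The time-complexity analysis of this tree knapsack is still somewhat confusing to me, though.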
#include<bits/stdc++.h>
using namespace std;
constexpr int maxn=3010;  // 3e3 + 10
using ll=long long;
// dp[u][s + col_siz[id]]: DP table of dfs, indexed with an offset so that
// negative sums s fit (second dimension covers 0 .. 2 * col_siz[id]).
ll dp[maxn][maxn<<1];
constexpr int mod=998244353;
int siz[maxn],col[maxn],n,col_siz[maxn];
// h/e/ne/idx: adjacency list as linked edge chains (h[a] = -1 means no edge, set in main);
// id: colour currently being processed; temp: scratch row for one dfs merge.
int h[maxn],e[maxn<<1],ne[maxn<<1],idx,id,temp[maxn<<1];
// Pushes the directed edge a -> b to the front of a's edge chain.
void add(int a,int b)
{
    e[idx]=b;
    ne[idx]=h[a];
    h[a]=idx++;
}
ll res;
// Tree-knapsack DP for the colour stored in `id`, over the subtree of u
// (fa is u's parent, 0 for the root). Also sets siz[u] to the size of u's subtree.
//
// Every vertex scores +1 if its colour is id and -1 otherwise. With C = col_siz[id],
// dp[u][s + C] holds, modulo `mod`, the number of vertex sets built by the recurrence
// below (u together with, for each child v, either nothing or one of the sets counted
// in dp[v]) whose total score is s. Only -C <= s <= C is stored; the offset C keeps
// the index non-negative.
// Precondition: dp[u][0..2C] is zero on entry (main clears it for every colour).
void dfs(int u,int fa)
{
siz[u]=1;
// Base case: the set {u} alone, score +1 or -1.
if(col[u]==id)
{
dp[u][1+col_siz[id]]=1;
}
else dp[u][col_siz[id]-1]=1;
for(int t=h[u];~t;t=ne[t])
{
int v=e[t];if(v==fa)continue;
dfs(v,u);
// temp[s + C] = sum over i + j = s of dp[u][i + C] * dp[v][j + C]: the sets that also take
// a set from v's subtree. Adding temp to dp[u] keeps the old entries (v's subtree not taken).
for(int i=0;i<=col_siz[id]*2;i++)temp[i]=0;
// A score's absolute value is bounded by the set size, so only |i| <= siz[u] (the part
// merged so far) and |j| <= siz[v] can be non-zero; both are also capped at C.
int m1=min(siz[u],col_siz[id]),m2=min(siz[v],col_siz[id]);
for(int i=-m1;i<=m1;i++)
for(int j=-m2;j<=m2&&i+j<=col_siz[id];j++)
// Combined scores below -C are dropped. NOTE(review): presumably safe because at most C
// vertices have colour id, so such a set can never be extended to a positive total.
if(i+j+col_siz[id]>=0)
{
temp[i+j+col_siz[id]]=(temp[i+j+col_siz[id]]+dp[u][i+col_siz[id]]*dp[v][j+col_siz[id]]%mod)%mod;
}
for(int i=0;i<=col_siz[id]*2;i++)
dp[u][i]=(dp[u][i]+temp[i])%mod;
siz[u]+=siz[v];
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)scanf("%d",&col[i]),col_siz[col[i]]++;
memset(h,-1,sizeof(h));
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
add(a,b),add(b,a);
}
for(int i=1;i<=n;i++)
if(col_siz[i])
{
id=i;
for(int j=1;j<=n;j++)for(int k=0;k<=col_siz[i]*2;k++)dp[j][k]=0;
dfs(1,0);
for(int j=1;j<=n;j++)for(int k=col_siz[i]+1;k<=2*col_siz[i];k++)res=(res+dp[j][k])%mod;
}
printf("%lld\n",res);
}
[CERC2014]Outer space invaders
#include<bits/stdc++.h>
using namespace std;
const int inf=0x3f3f3f3f;
const int BufferSize=1<<16;
// Input buffer for stdin: [head, tail) is the part not consumed yet.
char buffer[BufferSize],*head,*tail;
// Returns the next byte of stdin as an unsigned char value, refilling the buffer
// in BufferSize chunks, or EOF once the input is exhausted. Previously, at EOF it
// returned a stale buffer byte and moved head past tail, so later calls read past
// the valid data.
inline int Getchar() {
    if(head==tail) {
        const size_t got=fread(buffer,1,BufferSize,stdin);
        if(got==0) return EOF;
        head=buffer;
        tail=buffer+got;
    }
    return static_cast<unsigned char>(*head++);
}
// Parses the next (optionally negative) decimal integer from stdin: skips
// non-digit bytes, where any '-' seen makes the result negative, then reads digits.
// Returns 0 if the input ends before a digit is found. Characters are passed to
// isdigit as unsigned char values / EOF, as the function requires.
inline int read() {
    int x=0,f=1;
    int c=Getchar();
    for(;c!=EOF&&!isdigit(c);c=Getchar()) if(c=='-') f=-1;
    for(;c!=EOF&&isdigit(c);c=Getchar()) x=x*10+(c-'0');
    return x*f;
}
// Writes x in decimal to stdout, with no separator. Negative values (including
// INT_MIN) are printed as '-' followed by the magnitude, which is computed in
// unsigned arithmetic so that negation cannot overflow. The previous version
// emitted garbage characters for any negative x.
void print(int x)
{
    unsigned int mag=static_cast<unsigned int>(x);
    if(x<0)
    {
        putchar('-');
        mag=0u-mag;
    }
    char digits[16];
    int len=0;
    do
    {
        digits[len++]=static_cast<char>('0'+mag%10u);
        mag/=10u;
    }while(mag!=0u);
    while(len>0) putchar(digits[--len]);
}
constexpr int maxn=310;     // capacity for the entries point[1..n]
constexpr int maxm=10010;   // 1e4 + 10: raw time stamps index mp directly
// One input entry: interval [a, b] (renumbered to 1..cnt by init) and cost d,
// which main adds once per chosen split point.
struct Node
{
    int a,b,d;
} point[maxn];
// mp: time stamp -> compressed index; dp[l][r]: DP value for compressed range [l, r].
int n,cnt,mp[maxm],dp[maxn<<1][maxn<<1];
// Reads one test case: n entries (a, b, d). Every time stamp that occurs as some
// a or b is renumbered to its rank 1..cnt in increasing order, and the intervals
// are rewritten in those compressed coordinates. Also clears dp and mp.
// Assumes every time stamp lies in [0, maxm).
void init()
{
    n=read(); cnt=0; memset(dp,0,sizeof(dp));memset(mp,0,sizeof(mp));
    for(int i=1;i<=n;i++)
    {
        point[i].a=read();
        point[i].b=read();
        point[i].d=read();
        mp[point[i].a]=1;mp[point[i].b]=1;
    }
    // Start at 0 so that a time stamp of 0 is renumbered too; starting at 1 left
    // mp[0] at its marker value 1, colliding with the first real rank.
    for(int i=0;i<maxm;i++)
        if(mp[i])mp[i]=++cnt;
    for(int i=1;i<=n;i++)point[i].a=mp[point[i].a],point[i].b=mp[point[i].b];
}
int main()
{
int T;
scanf("%d",&T);
while(T--)
{
init();
for(int len=1;len<=cnt;len++)
for(int l=1;l+len-1<=cnt;l++)
{
int r=l+len-1;
int id=-1;
for(int i=1;i<=n;i++)if(point[i].a>=l&&point[i].b<=r&&(id==-1||point[i].d>point[id].d))id=i;
if(id==-1)continue;
dp[l][r]=inf;
for(int k=point[id].a;k<=point[id].b;k++)
dp[l][r]=min(dp[l][r],dp[l][k-1]+point[id].d+dp[k+1][r]);
}
printf("%d\n",dp[1][cnt]);
}
}
A classic interval DP.
#include <bits/stdc++.h>
using namespace std;
constexpr int maxn=5010;  // 5e3 + 10
// dp[l][r]: DP value for the segment a[l..r]; lt[i]: previous index holding the
// same value as a[i] (0 if none); u[x]: last index seen so far with value x;
// a: the input sequence (1-based).
int dp[maxn][maxn],lt[maxn],u[maxn],a[maxn];
const int BufferSize=1<<16;
// Input buffer for stdin: [head, tail) is the part not consumed yet.
char buffer[BufferSize],*head,*tail;
// Returns the next byte of stdin as an unsigned char value, refilling the buffer
// in BufferSize chunks, or EOF once the input is exhausted. Previously, at EOF it
// returned a stale buffer byte and moved head past tail, so later calls read past
// the valid data.
inline int Getchar() {
    if(head==tail) {
        const size_t got=fread(buffer,1,BufferSize,stdin);
        if(got==0) return EOF;
        head=buffer;
        tail=buffer+got;
    }
    return static_cast<unsigned char>(*head++);
}
// Parses the next (optionally negative) decimal integer from stdin: skips
// non-digit bytes, where any '-' seen makes the result negative, then reads digits.
// Returns 0 if the input ends before a digit is found. Characters are passed to
// isdigit as unsigned char values / EOF, as the function requires.
inline int read() {
    int x=0,f=1;
    int c=Getchar();
    for(;c!=EOF&&!isdigit(c);c=Getchar()) if(c=='-') f=-1;
    for(;c!=EOF&&isdigit(c);c=Getchar()) x=x*10+(c-'0');
    return x*f;
}
// Writes x in decimal to stdout, with no separator. Negative values (including
// INT_MIN) are printed as '-' followed by the magnitude, which is computed in
// unsigned arithmetic so that negation cannot overflow. The previous version
// emitted garbage characters for any negative x.
void print(int x)
{
    unsigned int mag=static_cast<unsigned int>(x);
    if(x<0)
    {
        putchar('-');
        mag=0u-mag;
    }
    char digits[16];
    int len=0;
    do
    {
        digits[len++]=static_cast<char>('0'+mag%10u);
        mag/=10u;
    }while(mag!=0u);
    while(len>0) putchar(digits[--len]);
}
// Reads T test cases, each a sequence a[1..n], and prints dp[1][n] where
//   dp[i][i] = 0,
//   dp[l][r] = min(dp[l][r-1] + 1,
//                  dp[l][k] + dp[k+1][r] for every l <= k < r with a[k] == a[r]).
// The candidates k are visited by following the lt[] chain of equal values.
int main()
{
    int T=read();
    while(T--)
    {
        const int n=read();
        for(int i=1;i<=n;++i) a[i]=read();
        memset(u,0,sizeof(u));
        for(int i=1;i<=n;++i)
        {
            lt[i]=u[a[i]];  // previous position with the same value (0 if none)
            u[a[i]]=i;
            dp[i][i]=0;
        }
        for(int len=2;len<=n;++len)
        {
            for(int l=1,r=len;r<=n;++l,++r)
            {
                int best=dp[l][r-1]+1;
                for(int k=lt[r];k>=l;k=lt[k])
                    best=min(best,dp[l][k]+dp[k+1][r]);
                dp[l][r]=best;
            }
        }
        printf("%d\n",dp[1][n]);
    }
}