[hdoj4616] Game

Game

题意

给一颗树,已知每个点有权值和陷阱,你不能往回走,如果走过$c$个陷阱或者无路可走就结束,可以从任一点开始走,问能获得多少权值。

题解

树形dp。

$dp[i][j][0/1]$是在以$i$为终点,经过$j$个陷阱,$1$代表起点有陷阱,$0$代表起点没有陷阱。
转移方程

  • $dp[u][j][0]=max(dp[u][j][0],dp[v][j][0]+val[v])$
  • $dp[u][j][1]=max(dp[u][j][1],dp[v][j][1]+val[v])$ 其中$j>0$因为起点有陷阱。

答案就是枚举每个根节点中子树进入和出去,特判一下链的组成

  • 起点为trap和起点不为trap组成,路线是从trap出发到另一个起点。
  • 起点都不是trap,那么这时候就要求$j_1+j_2<c$,因为路线走到一半就可能到了trap为$c$卡主,不能走了。
  • 两个起点都为trap,那么无所谓只需要$j_1+j_2 \leq c$就可以了。

心路历程:我一开始想的是二维dp转移方程,但是只能对一个根节点求值,因为子节点也有可能走根节点那条路径,没有想到树形dp对于求解问题这么灵活,答案不一定一定是dp数组中的元素,而可以是通过拼接数组中元素来构成答案。而且我想法是从根节点走到叶子节点,而不是从子树节点出发到根节点,还是too young啊。之后我拼接了自己的垃圾二维dp,发现因为在迷宫中你碰到$c$个陷阱之后就不能走了,但是我这个二维数组不能转移方程。我两条链拼接的时候会多加几个节点,因为左右两个链如果加起来为$c$之后,其中一条链到了trap点,就不能再走了。所以才要三维数组。

AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,n) for(int i=(a);i<(n);i++)
#define per(i,a,n) for(int i=(n-1);i>=(a);i--)
#define fi first
#define se second
typedef pair <int,int> pII;
typedef long long ll;
const int INF = 0x3f3f3f3f;
//head
const int maxn = 5e4 + 10;
struct node{
int next,to;
}G[maxn<<1];
int head[maxn],trap[maxn],val[maxn],dp[maxn][4][2];
int cnt,T,n,c,ans;
void add(int u,int v){
G[cnt].to = v;
G[cnt].next = head[u];
head[u] = cnt++;
}
void init(){
cnt = 0;
memset(head,-1,sizeof(int)*(n+2));
memset(trap,0,sizeof(int)*(n+2));
memset(dp,0,sizeof(dp));
ans = 0;
}
void dfs1(int u,int fa){
dp[u][trap[u]][trap[u]] = val[u];
for(int i=head[u];~i;i=G[i].next){
int v = G[i].to;
if(fa == v) continue;
dfs1(v,u);
for(int j=0;j<=c;j++){
for(int k=0;k+j<=c;k++){
ans = max(ans,dp[u][j][1]+dp[v][k][1]);
if(j) ans = max(ans,dp[u][j][1]+dp[v][k][0]);
if(k) ans = max(ans,dp[u][j][0]+dp[v][k][1]);
if(j+k<c) ans = max(ans,dp[u][j][0]+dp[v][k][0]);
}
}
for(int j=0;j+trap[u]<=c;j++){
dp[u][j+trap[u]][0] = max(dp[u][j+trap[u]][0],dp[v][j][0]+val[u]);
if(j) dp[u][j+trap[u]][1] = max(dp[u][j+trap[u]][1],dp[v][j][1]+val[u]);
}
}
}
int main(){
#ifdef LOCAL
freopen("1.in","r",stdin);
#endif
scanf("%d",&T);
while(T--){
scanf("%d%d",&n,&c);
init();
rep(i,0,n) scanf("%d%d",&val[i],&trap[i]);
rep(i,0,n-1){
int u,v;scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
dfs1(0,-1);
printf("%d\n",ans);
}
return 0;
}