复杂状态的动态规划

本文为该书的笔记:刘汝佳. 算法竞赛入门经典.第2版[M]. 清华大学出版社, 2014.

最优配对问题

空间里有 n 个点 P_0, P_1, … , P_{n-1} ,你的任务是把它们配成 n/2 对( n 是偶数),使得每个点恰好在一个点对中。所有点对中两点的距离之和应尽量小。 n ≤ 20, |xi|,|yi|,|zi|≤ 10000

  1. 每个点都需要配对
  2. 把问题看做多阶段决策问题:先确定 P_0 和谁配对,接下来是 P_1 ……最后是 P_{n-1}
  3. d(i) 表示前 i 个点的最小距离和
  4. 考虑第 i 点的决策,假设点 i 和点 j 配对( j<i
  5. 那么接下来的问题应是“把前 i-1 个点中除了 j 之外的其他点两两配对”
  6. “前 i-1 个点中除了 j 之外的其他点”显然无法用任何一个 d 值来刻画
  7. 状态无法转移,所以状态和状态的指标函数需要改变

  1. 针对状态无法转移,常见的方法是增加维度,更细致地描述状态。
  2. 既然之前提到“除了某种元素之外”,不妨把元素当做状态的一部分,设 d(i,S) 表示前 i 个点中,位于集合 S 中的元素两两配对的最小距离和,则状态转移方程为:
d(i,S)=min \left \{|P_iP_j|+d(i-1,S-\{i\}-\{j\})|j \in S \right\}
  1. 如何表示子集 S ,因为它要作为数组 d 中的第二维下标,所以要用整数表示集合({0,1,2,...,n-1}的任意子集)。 在《算法竞赛入门经典(第二版)》第七章第二节讲述了子集生成的算法:
    • 增量构造法
    • 位向量法
    • 二进制法(从右往左第 i 位表示元素 i 是否在集合 S 中)
  2. 其实, S 中最大的元素就是 i ,所以状态中 i 不需要保存,可以直接用 d(S) 表示“把 S 中两元素两两配对的最小距离和”。则状态转移方程为:
d(S)=min \left \{|P_iP_j|+d(i-1,S-\{i\}-\{j\})|j \in S ,i=max\{S\}\right\}

状态有 2^n (即每个元素都有是否存在于子集中两种状态)个,每个状态有 O(n) 种转移方式,所以总时间复杂度为 O(n2^n)


对于另外一种状态转移方式:

d(S)=min \left \{|P_iP_j|+d(i-1,S-\{i\}-\{j\})|i,j \in S \right\}

对于该种方式,每个状态有 O(n^2) 种转移方式,所以总时间复杂度为 O(n^22^n) 。是因为即使匹配成了 n/2 对,匹配的顺序有 (n/2)! 种。其实本题与结果的排序无关。
S 中的最大元素:使用循环判断,平均判断次数为 2
原因:
该题目等价于在{1,2,...,n}判断每个子集中最小的元素。
最小元素为 k 的子集有 $$2^{n-k}$$ 个,而其需要的判断次数为 k ,所以总的循环次数为: $$s=\sum_{k=1}^{n} k*2^{n-k}$$ 即如下列表格所示

2^{n-1}
2^{n-2}2^{n-2}
............
2^1............
11...111

这是一个 n*n 的表格,我们首先对表格的每列相加,第 u 列的和为 \sum_{k=0}^{u} 2^k,由等比数列求和公式可得第 u 列的和为 2^{u}-1,再使用等比数列求和公式将每一列的和相加可得 2^{n+1}-n-2 ,状态有 2^n 个,所以平均判断次数仅为 2
完整程序(书中给的是递推计算,本文给出记忆化搜索程序):

#define LOCAL
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <math.h>
#include <vector>
#include <map>

using namespace std;
const int maxn = 20 + 2;
const int INF = 1 << 30;
int n;
int data[maxn][3];
double d[1 << 20 + 2];
int vis[1 << 20 + 2];
int S;
double dist(int a, int b)
{
    double s1 = (double)(data[a][0] - data[b][0]) * (data[a][0] - data[b][0]);
    double s2 = (double)(data[a][1] - data[b][1]) * (data[a][1] - data[b][1]);
    double s3 = (double)(data[a][2] - data[b][2]) * (data[a][2] - data[b][2]);
    double s4 = sqrt(s1 + s2 + s3);
//    cout << a << " " << b << " " << s4 << endl;
    return s4;
    ;
}
double dp(int s)
{
    //当子集 $s$ 为空的时候直接返回 $0$ ,否则返回值会为 $INF * 1.0$ ,或者可以在初期设置 $vis[0]=1,d[0]=0
    if (s == 0)
    {
        return 0;
    }
    double &ans = d[s];
    if (vis[s])
    {
        return ans;
    }
    int i;
    for (i = n - 1; i >= 0; i--)
    {
        if (s & (1 << i))
            break;
    }
    ans = INF * 1.0;
    for (int j = 0; j < i; j++)
    {
        if (s & (1 << j))
            ans = min(ans, dist(i, j) + dp(s ^ (1 << i) ^ (1 << j)));
    }
    vis[s] = 1;
//    cout << hex << s << " " << ans << endl;
    return ans;
}
int main()
{
#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif // LOCAL
    while (cin >> n && n)
    {
        memset(data, -1, sizeof(data));
        memset(vis, 0, sizeof(vis));
        for (int i = 0; i < n; i++)
        {
            cin >> data[i][0] >> data[i][1] >> data[i][2];
        }
        S = (1 << n) - 1;
        cout << dp(S) << endl;
    }
    return 0;
}
复制代码

样例:

输入:
18
1 1 1
4 7 8
-1 -9 -7
2 3 1
1 4 7
3 6 9
1 -3 2
6 7 5
2 5 8
2 3 6
4 5 2
7 8 5
4 5 1
-1 2 3
0 0 0
-100 0 4
9 5 1
7 5 3
输出:
119.058
复制代码

货郎担问题

样例(来自:章斌斌 链接: https://www.jianshu.com/p/30ba1d66c729 ):

样例输入:
10

42 160 34 136 134 94 78 18 196

42 166 66 106 87 11 122 195 32

160 166 4 98 198 3 154 75 121

34 66 4 187 112 52 94 36 144

136 106 98 187 12 64 45 46 48

134 87 198 112 12 154 109 196 131

94 11 3 52 64 154 11 79 80

78 122 154 94 45 109 11 13 86

18 195 75 36 46 196 79 13 162

196 32 121 144 48 131 80 86 162

样例输出:
284
复制代码

思考:

  1. 每个点都需要连接
  2. 把问题看做多阶段决策问题:先确定 P0 和谁连接,接下来是 P1 ……最后是 P_{n-1}
  3. d(i) 表示前 i 个点的经过且只经过一次的最短道路总长度
  4. 考虑第 i 点的决策,假设点 i 和点 j 连接( j<i
  5. 那么接下来的问题应是“从 j 点出发到 i 点经过所有其他点且只经过一次的最短路径”
  6. 无法用之前定义的状态描述该转移
  7. 重新设置状态 d(i,S) 表示从子集 Si 点出发到第 0 点(经过所有点)的最短路径
  8. 则状态转移方程为:
d(i,S)=min \left \{|P_iP_j|+d(j,S-\{j\})|i,0\not\in S,j \in S \right \}
  1. 最终结果为 d(0,S-{0}) ,此处的 S 指的是全集
  2. 边界为 d(i,\varnothing) ,此时返回为 P_iP_0

完整程序:

#define LOCAL
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <math.h>
#include <vector>
#include <map>

using namespace std;
const int maxn = 15 + 2;
const int INF = 1 << 30;
int n;
int data[maxn][maxn];
int d[maxn][1 << 15 + 2];
int S;
int dp(int i, int S)
{
    if (!S)
    {
        return data[i][0];
    }
    int &ans = d[i][S];
    if (ans != -1)
    {
        return ans;
    }
    ans = INF;
    for (int j = 0; j < n; j++)
    {
        if (S & (1 << j))
        {
            ans = min(ans, data[i][j] + dp(j, S ^ (1 << j)));
        }
    }
    return ans;
}
int main()
{
#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif // LOCAL
    while (cin >> n && n)
    {
        memset(d, -1, sizeof(d));
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                if (i == j)
                    continue;
                cin >> data[i][j];
            }
        }
        S = ((1 << n) - 1) ^ (1 << 0);
        cout << dp(0, S) << endl;
    }
    return 0;
}
复制代码

图的色数

给一个无向图 G ,把图中的结点染成尽量少的颜色,使得相邻结点颜色不同。

样例:

输入:
5
1 3
2 5
5 3
4 2
1 2
输出:
2
复制代码

本题容易落入惯性思维:先确定 P_0 和谁配对,接下来是 P_1 ……最后是 P_{n-1}
状态: d(S) 表示把结点集合 S 染色所需要的颜色数的最小值。 状态转移方程:

d(S)=min \left \{d(S-S')+1 \right \}

S'是 S 的子集,并且内部没有边。 首先通过预处理保存每个结点集是否可以染成同一种颜色(即“内部没有边”)(也可以将每个边连接起来的两个点组成一个子集,判断这些子集是否存在结点集的子集,如果存在,则内部有边,如果不存在,则内部没有边),则算法的主要时间取决于“高效枚举一个集合的所有子集”。
那么,如何判断一个集合是另一个集合的子集?如要判断 A_0 是否为 A 的子集,则用二进制表示之后,判断 (A0|A)==A 或者 (A0|A)==A0 是否成立,若成立,则 A_0A 的子集。
完整程序:

#define LOCAL
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <math.h>
#include <vector>
#include <map>
#include <bitset>

using namespace std;
const int maxn = 15 + 2;
const int INF = 1 << 30;
int n;
int d[(1 << 15) + 2];
int a, b;
vector<int> edges;
int es; //边的数量
int note_edges_inside(int S)
{
    for (int i = 0; i < es; i++)
    {
        if ((edges[i] | S) == S)
        {
            return 0;
        }
    }
    return 1;
}
int dp(int S)
{
    if (!S)
        return 0;
    int &ans = d[S];
    if (ans != -1)
    {
        return ans;
    }
    ans = INF;
    for (int S0 = S; S0; S0 = (S0 - 1) & S)
    {
        if (note_edges_inside(S0))
        {
            ans = min(ans, dp(S ^ S0) + 1);
        }
    }
    //    cout<<bitset<sizeof(int)*8>(S)<<" "<<ans<<endl;
    return ans;
}
int main()
{
#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif // LOCAL
    int S;
    while (cin >> n && n)
    {
        memset(d, -1, sizeof(d));
        edges.clear();
        while (cin >> a >> b)
        {
            a--;
            b--;
            edges.push_back((1 << a) | (1 << b));
        }
        es = edges.size();
        S = (1 << n) - 1;
        cout << dp(S) << endl;
    }
    return 0;
}
复制代码

为什么 for (int S0 = S; S0; S0 = (S0 - 1) & S) 可以高效枚举 S 的所有子集?

  1. S 的真子集一定小于 S ,所以我们从大到小枚举,即从 S-1 开始枚举
  2. 若想从集合 A 中找出集合 S 的数值最大的子集,取 AS 的交集即可,即 A & B
  3. 所以 S 的数值最大的真子集为 S0 = (S - 1) & S
  4. S的次一个数值最大的真子集即为S0以下的最大的真子集,所以从S0-1$ 开始枚举。
  5. 所以 S 的数值第二大的真子集为 S0 = (S0 - 1) & S
  6. 跳回 4 继续循环。

上述算法的时间复杂度的计算:

  1. 状态数为全集的所有子集的个数,状态的转移数为该子集的子集个数,每个状态的状态转移数是不同的,所以算法的时间复杂度为全集的所有子集的“子集个数”之和。
  2. c(S) 为集合 S 的子集的个数(它等于 2^{|S|}(|S|为集合 S 的元素个数))
  3. 则时间复杂度为 sum{c(S0)|S0 是全集的子集}
  4. 元素个数相同的集合,子集个数也相同,可以按照元素个数“合并同类项”
  5. 元素个数为 k 的集合有 C(n,k) 个,其中每个集合有 2^k 个子集
  6. 因此本题的时间复杂度为 sum{C(n,k)2^k}=sum{C(n,k)2^k1^{n-k}}=(2+1)^n=3^n
  7. 所以枚举 1~n 的每个集合 S 的所有子集的总时间复杂度为 $$O(3^n)$

校长的烦恼(UVa10817

某校有 m 个教师和 n 个求职者,需讲授 s 个课程( 1s81m201n100 )。已知每人的工资 c10000c50000 )和能教的课程集合,要求支付最少的工资使得每门课都至少有两名教师能教。在职教师不能辞退。

样例:

Sample Input
2 2 2
10000 1
20000 2
30000 1 2
40000 1 2
0 0 0
Sample Output
60000
复制代码

如果将题目改为“每门课都至少有一名教师能教”,则可以首先算出 m 个教师用完之后剩余的课程集合,则用 d(A,B) 表示求职者集合 A 完成课程集合 B 的最小花费。
则状态转移方程为:

d(A,B)=min \left \{ w(i)+d(A-\{i\},B-B_i)|i\in A \right \}

但是由于集合中元素的唯一性的要求,每门课都至少有两名教师,破坏了唯一性,无法直接用集合来表示课程。
其次,该方式毫无必要的增添了“录取顺序”,使用该方式的时间复杂度为 O(n^2) ,忽略“录取顺序”之后的时间复杂度应为 2n ,因而增加了时间复杂度。
为了解决集合中元素唯一性的问题,我们用两个集合, s1 表示恰好有一个人教的科目集合, s2 表示至少有两个人教的科目集合, d(i,s1,s2) 表示前 i 个人时候的最小花费。状态转移方程为 d(i,s1,s1)=min(d(i+1,s1',s2')+c[i],d(i+1,s1,s2)),其中第一项表示聘用,第二项表示不聘用。当 i>=m 的时候状态转移方程才出现第二项。 完整程序:

#define LOCAL
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <math.h>
#include <vector>
#include <map>
#include <bitset>
#include <sstream>

using namespace std;
const int maxn = 120 + 2;
const int INF = 1 << 30;
int s, m, n;
int c[maxn];  //工资
int st[maxn]; //课程
int d[maxn][1 << 8 + 2][1 << 8 + 2];
int dp(int i, int s0, int s1, int s2)
{
    if (i == m + n)
        return s2 == (1 << n) - 1 ? 0 : INF;
    int &ans = d[i][s1][s2];
    if (ans != -1)
    {
        return ans;
    }
    ans = INF;
    if (i >= m)
        ans = dp(i + 1, s0, s1, s2);
    int m0 = s0 & st[i]; //取 $i$ 所能讲授的课程和暂时没有人能够讲授的课程的交集,即 $i$ 所能改变的 $s0$ 的元素,这些元素都去了 $s1
    int m1 = s1 & st[i]; //i$ 所能减去的 $s0$ 的元素
    s0 ^= m0;            //差集
    s1 = (s1 ^ m1) | m0;
    s2 |= m1;
    ans = min(ans, c[i] + dp(i + 1, s0, s1, s2));
    //    cout<<i<<" "<<bitset<sizeof(int)*8>(s0)<<" "<<bitset<sizeof(int)*8>(s1)<<" "<<bitset<sizeof(int)*8>(s2)<<" " <<ans<<endl;
    return ans;
}
int main()
{
#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif // LOCAL
    int x;
    string line;
    while (getline(cin, line))
    {
        memset(st, 0, sizeof(st));
        memset(d, -1, sizeof(d));
        stringstream ss(line);
        ss >> s >> m >> n;
        if (s == 0)
            break;
        for (int i = 0; i < m + n; i++)
        {
            getline(cin, line);
            stringstream ss(line);
            ss >> c[i];
            while (ss >> x)
            {
                x--;
                st[i] |= (1 << x);
            }
        }
        cout << dp(0, (1 << n) - 1, 0, 0) << endl;
    }
    return 0;
}
复制代码

使用三进制数保存状态的方法:

#define LOCAL
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <math.h>
#include <vector>
#include <map>
#include <bitset>
#include <sstream>

using namespace std;
const int maxn = 120 + 2;
const int INF = 1 << 30;
int s, m, n;
int c[maxn];       //工资
int st[maxn];      //课程
int d[maxn][6561]; //3^8=6561
int an(int i, int j)
{
    //    cout << " "<<i<<" "<<j;
    int t = 0;
    for (int k = 0; k < s; k++)
    {
        if (i % (int)pow(3, k + 1) / (int)pow(3, k) < 2)
        {
            t += (i % (int)pow(3, k + 1) / (int)pow(3, k) + j % (int)pow(2, k + 1) / (int)pow(2, k)) * (int)pow(3, k);
        }
        else
        {
            t += i % (int)pow(3, k + 1) / (int)pow(3, k) * (int)pow(3, k);
        }
    }
    //    cout <<" "<<t<<endl;
    return t;
}
int dp(int i, int s)
{
    if (i == m + n)
    {
        //        cout<<i<<" "<<s<<endl;
        return s == (int)pow(3, n) - 1 ? 0 : INF;
    }

    int &ans = d[i][s];
    if (ans != -1)
    {
        return ans;
    }
    ans = INF;
    if (i >= m)
    {
        ans = dp(i + 1, s);
    }
    ans = min(ans, c[i] + dp(i + 1, an(s, st[i])));
    //    cout<<i<<" "<<s<<" " <<ans<<endl;
    return ans;
}
int main()
{
#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif // LOCAL
    int x;
    string line;
    while (getline(cin, line))
    {
        memset(st, 0, sizeof(st));
        memset(d, -1, sizeof(d));
        stringstream ss(line);
        ss >> s >> m >> n;
        if (s == 0)
            break;
        for (int i = 0; i < m + n; i++)
        {
            getline(cin, line);
            stringstream ss(line);
            ss >> c[i];
            while (ss >> x)
            {
                x--;
                st[i] |= (1 << x);
            }
        }
        cout << dp(0, 0) << endl;
    }
    return 0;
}
复制代码

20 个问题

Sample Input
8 1
11010101
11 4
00111001100
01001101011
01010000011
01100110001
11 16
01000101111
01011000000
01011111001
01101101001
01110010111
01110100111
10000001010
10010001000
10010110100
10100010100
10101010110
10110100010
11001010011
11011001001
11111000111
11111011101
11 12
10000000000
01000000000
00100000000
00010000000
00001000000
00000100000
00000010000
00000001000
00000000100
00000000010
00000000001
00000000000
9 32
001000000
000100000
000010000
000001000
000000100
000000010
000000001
000000000
011000000
010100000
010010000
010001000
010000100
010000010
010000001
010000000
101000000
100100000
100010000
100001000
100000100
100000010
100000001
100000000
111000000
110100000
110010000
110001000
110000100
110000010
110000001
110000000
0 0
Sample Output
0
2
4
11
9
复制代码

设“心里想的物体“为 W 。用集合 s 表示已经询问的特征集,用集合 a 来表示”以确认物体 W 具备的特征集“,则 a 一定是 s 的子集。
d(s,a) 表示已经问了特征集 s ,其中确认 W 具有的特征集为 a 时,还需要询问的最小次数(即继续询问完这最小次数以后,所有给定的物体都能够被完全区分开,即任何一个物体都能够被确定下来)。如果下一次询问的对象是特征 k ,则询问次数为:

max\{d(s+\{k\},a+\{k\}),d(s+\{k\},a)\}+1

因为需要“保证能够猜到”,所以取 max 。 状态转移方程为:

d(s,a)=min \left \{ max\{d(s+\{k\},a+\{k\}),d(s+\{k\},a)\}+1,k \not \in s \right \}

边界条件为:如果只有一个物体满足“具备集合 a 中的所有特征,但不具备集合 s-a 中的所有特征”这一条件,则 d(s,a)=0 ,因为无须进一步询问,问了 s 中的所有特征就可以确实是或者不是这个物体。如果只有两个物体满足“具备集合 a 中的所有特征,但不具备集合 s-a 中的所有特征”这一条件,则 d(s,a)=1 ,因为知道是哪两个物体,还知道已经问了哪些特征,只要再问一个不在 s 中但是两个物体在该特征方面有不同的特征,就一定能知道是哪个物体。
完整程序:

#define LOCAL
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <math.h>
#include <vector>
#include <map>
#include <bitset>
#include <sstream>

using namespace std;
const int maxn = 128 + 2;
const int maxm = 11 + 1;
const int INF = 1 << 30;
int m, n; //m$ 个特征, $n$ 个物体
char object[maxn][maxm];
int cnt[1 << 11 + 2][1 << 11 + 2];
int d[1 << 11 + 2][1 << 11 + 2];
/** \brief 已经问了特征集 $s$ ,其中具备的特征为 $a$ ,(区分满足这些条件的物体)还需要询问的最小次数
 *
 * \param 已经问了特征集
 * \param 其中具备的特征
 * \return 还需要询问的最小次数
 *
 */

int dp(int s, int a)
{
    //边界条件:已经问了特征集 $s$ ,其中具备的特征为 $a$ 的物体为零个、一个或两个的时候
    if (cnt[s][a] <= 1)
        return 0;
    if (cnt[s][a] == 2)
        return 1;
    int &ans = d[s][a];
    if (ans != -1)
        return ans;
    ans = INF;
    for (int k = 0; k < m; k++)
    {
        //如果 $k$ 不是 $s$ 的元素
        if (!(s & (1 << k)))
        {
            ans = min(ans, max(dp(s | (1 << k), a | (1 << k)), dp(s | (1 << k), a)) + 1);
        }
    }
    return ans;
}
/** \brief 统计“问了特征集 $s$ ,所具备的特征集为 $a$ ”的物体个数
 *
 * \param
 * \param
 * \return
 *
 */
void init()
{
    memset(cnt, 0, sizeof(cnt));
    for (int i = 0; i < n; i++)
    {
        //首先用二进制保存该物体的特征集
        int feature = 0;
        for (int j = 0; j < m; j++)
        {
            if (object[i][j] == '1')
            {
                feature |= (1 << j); //反向保存特征的效果是一样的
            }
        }
        //遍历所有特征集合获取统计,该物体对统计量的影响
        for (int s = 0; s < (1 << m); s++)
        {
            cnt[s][s & feature]++;
        }
    }
}
int main()
{
#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif // LOCAL
    while (cin >> m >> n && m && n)
    {
        memset(d, -1, sizeof(d));
        for (int i = 0; i < n; i++)
        {
            cin >> object[i];
        }
        init();
        cout << dp(0, 0) << endl;
    }
    return 0;
}
复制代码

基金管理(UVa1412)

样例:

Sample Input
144624.00 9 5 3
IBM 500 3
97.27 98.31 97.42 98.9 100.07 98.89 98.65 99.34 100.82
GOOG 100 1
467.59 483.26 487.19 483.58 485.5 489.46 499.72 505 504.28
JAVA 1000 2
5.54 5.69 5.6 5.65 5.73 6 6.14 6.06 6.06
MSFT 250 1
29.86 29.81 29.64 29.93 29.96 29.66 30.7 31.21 31.16
ORCL 300 3
17.51 17.68 17.64 17.86 17.82 17.77 17.39 17.5 17.3
Sample Output
151205.00
BUY GOOG
BUY IBM
BUY IBM
HOLD
SELL IBM
BUY MSFT
SELL MSFT
SELL GOOG
SELL IBM
复制代码

书中的方法: 状态: d(i,p) 表示经过 i 天后,资产组合为 p 时的现金的最大值,其中 p 是一个 n 元组, p_i \leq k_i 表示第 i 只股票有 p_i 手。 p_1+p_2+...+p_n \leq k 。因为 0 \leq p_i \leq 8 (在一种资产组合中,第 i 只股票最少有 0 手,最多有 8 手, 1 \leq k \leq 8 ),又因为 1 \leq n \leq 8 ,就是说一个 89 进制数可以表示,理论上最多只有 9^8 中可能性, int 类型可以保存。
状态转移:一共有三种决策: HOLD 、 BUY 和 SELL ,分别进行状态转移即可。
在购买股票的时候需要判断现金是否足够。
不能“反向定义”
九进制整数无法直接进行“买卖股票”的操作,需要解码成 n 元组,几乎每次状态转移都会涉及编码解码操作,状态转移时间大幅提升,最后导致超时。
解决方法是事先计算好所有可能的状态并编号。
超时程序:

#define LOCAL
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <math.h>
#include <vector>
#include <map>
#include <bitset>
#include <sstream>
#include <map>

using namespace std;
const int maxn = 8 ;
const int maxm=100+2;
double c;//总现金金额
double price[maxn][maxm];
int m,n,kk,s[maxn],k[maxn];
char name[maxn][10];
map<int,double> d[maxm];//使用 $map$ 减小空间,经过 $i$ 天之后,<资产组合 现金最大值>
map<int ,int> opt[maxm],pre[maxm];//opt$ 记录经过哪种操作到达这个资产组合, $prev$ 记录上一步的资产组合。 $i$ 天后,<int, int>中键表示资产组合为 $p(int)$ 时的操作和之前的资产组合


int encode(int* portfolio)
{
    int hh=0;
    for(int i=0;i<n;i++){
        hh=hh*9+portfolio[i];
    }
    return hh;
}
/** \brief 解码
 *
 * \param h 需解码的编号
 * \param portfolio 解码结果放在这个地址所指的数组中
 * \return 资产组合中所有股票的手数
 *
 */
int decode(int h,int* portfolio)
{
    int totlot=0;
    for(int i=n-1;i>=0;i--)
    {
        portfolio[i]=h%9;
        totlot+=portfolio[i];
        h/=9;
    }
    return totlot;
}
/** \brief 更新资产组合和剩余现金
 *
 * \param oldh 旧的资产组合
 * \param day 经过 $day$ 天之后的情况
 * \param h 更新之后的资产组合
 * \param v 更新之后的剩余现金
 * \param o 使用的操作
 *
 */
void update(int oldh,int day,int h,double v,int o)
{
    //d$ 表示经过 $i$ 天之后,该资产组合的剩余现金最大值,所以要么 $h$ 资产组合没有计入,要么有更大的剩余值出现
    if(d[day].count(h)==0 || v>d[day][h]){
        opt[day][h]=o;
        pre[day][h]=oldh;
        d[day][h]=v;
    }
}
double dp()
{
    int portfolio[maxn];
    d[0][0]=c;
    for(int day=0;day<m;day++)
    {
        for(map<int ,double>::iterator it=d[day].begin();it != d[day].end();it++)
        {
            int h=it->first;
            double v=it->second;
            int totlot=decode(h,portfolio);
            update(h,day+1,h,v,0);
            for(int i=0;i<n;i++)
            {
                //如果可以买这只股票
                if(portfolio[i]<k[i] && totlot<kk && v>=price[i][day]-1e-3){
                    //每个源 $double$ 数据进行运算后,其精度会进一步发生改变,导致结果误差较大,而每天的价格 $0.01 ≤ c ≤ 100000000.00$ ,所以减去 $1e-3$ 不会对结果造成影响
                    portfolio[i]++;
                    update(h,day+1,encode(portfolio),v-price[i][day],i+1);
                    portfolio[i]--;
                }
                //如果可以卖出这只股票,即手中持有该股票
                if(portfolio[i]>0)
                {
                    portfolio[i]--;
                    update(h,day+1,encode(portfolio),v+price[i][day],-i-1);
                    portfolio[i]++;
                }
            }
        }
    }
    return d[m][0];
}

void print_ans(int m,int h){
    if(m<0)
        return;
    print_ans(m-1,pre[m][h]);
    if(!opt[m][h])
        cout<<"HOLD";
    else if(opt[m][h]>0)
    {
        cout<<"BUY ";
        cout<<name[opt[m][h]-1];
    }else{
        cout<<"SELL ";
        cout<<name[-opt[m][h]-1];
    }
    cout<<endl;
}

int main()
{
#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif // LOCAL
    int kase=0;
    while(cin>> c>>m>>n>>kk)
    {
        if(kase++)
            cout<<endl;
        for(int i=0;i<n;i++)
        {
            //name[i]本来指的就是一个数组的起始地址
            scanf("%s%d%d",name[i],&s[i],&k[i]);
            for(int j=0;j<m;j++)
            {
                cin>>price[i][j];
                price[i][j]*=s[i];
            }
        }
        double ans=dp();
        printf("%.2lf",ans);
        print_ans(m,0);
    }
    return 0;
}
复制代码

该程序会造成超时,原因是每次状态转移都会涉及编码解码操作。
解决方法是事先计算出所有可能的状态并编号,然后构造一个状态转移表。
这个方法需要首先估算状态总数的最大值,即求以下问题:
8 只股票,总股票数最多持有 8 手,共有多少种状态? 可以使用程序计算,使用深度优先搜索算法,最多有 12870 中状态

#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <math.h>
#include <vector>
#include <map>
#include <bitset>
#include <sstream>
#include <map>

using namespace std;

int all;
void dfs(int stock,int totlot){
    if(stock==8)
    {
        all++;
    }
    else for(int i=0;totlot+i<=8;i++)
    {
        dfs(stock+1,totlot+i);
    }
}
int calculate(){
    all=0;
    dfs(0,0);
    return all;
}
int main()
{
cout<<calculate();
    return 0;
}
复制代码

高效率完整程序:

#define LOCAL
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <math.h>
#include <vector>
#include <map>
#include <bitset>
#include <sstream>
#include <map>

using namespace std;
const int maxn = 8;
const int maxm = 100 + 2;
const double INF = 1e30;
const int maxstate = 13000;
double c; //总现金金额
double price[maxn][maxm];
int m, n, kk, s[maxn], k[maxn];
char name[maxn][10];
vector<vector<int>> states; //所有投资组合的集合
map<vector<int>, int> ID;
int buy_next[maxstate][maxn], sell_next[maxstate][maxn];

double d[maxm][maxstate]; //d$ 表示经过 $i$ 天之后,资产组合为 $j$ 的剩余现金最大值
int opt[maxm][maxstate], pre[maxm][maxstate];
//使用深度优先搜索算法遍历所有可能的状态
void dfs(int stock, vector<int> &lots, int totlot)
{
    if (stock == n)
    {
        ID[lots] = states.size();
        states.push_back(lots);
    }
    else
        for (int i = 0; i <= k[stock] && totlot + i <= kk; i++)
        {
            lots[stock] = i;
            dfs(stock + 1, lots, totlot + i);
        }
}
void init()
{
    memset(buy_next, -1, sizeof(buy_next));
    memset(sell_next, -1, sizeof(sell_next));
    vector<int> lots(n, 0);
    states.clear();
    ID.clear();
    dfs(0, lots, 0);
    //遍历所有可能的投资组合,计算所有投资组合买入股票 $i,卖出股票 $i$ 之后转移到的状态编号
    for (int s = 0; s < states.size(); s++)
    {
        int totlot = 0; //记录该投资组合的总的股票手数
        for (int i = 0; i < n; i++)
        {
            totlot += states[s][i];
        }
        for (int i = 0; i < n; i++)
        {
            //如果从股票手数上讲可以买入
            if (states[s][i] < k[i] && totlot < kk)
            {
                vector<int> newstate = states[s];
                newstate[i]++;
                buy_next[s][i] =
                    ID[newstate];
            }
            if (states[s][i] > 0)
            {
                vector<int> newstate = states[s];
                newstate[i]--;
                sell_next[s][i] =
                    ID[newstate];
            }
        }
    }
}
void update(int day, int oldh, int h, double v, int o)
{
    double &dd = d[day][h];
    if (v > dd)
    {
        dd = v;
        opt[day][h] = o;
        pre[day][h] = oldh;
        cout << day << " ";
        cout << h << " ";
        cout << v << " ";
        cout << o << " ";
        cout << oldh << endl;
    }
}
double dp()
{
    for (int day = 0; day <= m; day++)
    {
        for (int s = 0; s < states.size(); s++)
        {
            d[day][s] = -INF;
        }
    }

    vector<int> lots(n, 0);
    d[0][0] = c;
    for (int day = 0; day < m; day++)
    {
        for (int s = 0; s < states.size(); s++)
        {
            double v = d[day][s];
            if (v < -1)
                continue;
            update(day + 1, s, s, v, 0);
            for (int i = 0; i < n; i++)
            {
                //如果从股票手数上说是可以的,
                if (buy_next[s][i] > -1 && v >= price[i][day] - 1e-3)
                {
                    update(day + 1, s, buy_next[s][i], v - price[i][day], i + 1);
                }
                if (sell_next[s][i] > -1)
                {
                    update(day + 1, s, sell_next[s][i], v + price[i][day], -i - 1);
                }
            }
        }
    }
    return d[m][0];
}
void print_ans(int day, int h)
{
    if (day < 0)
        return;
    print_ans(day - 1, pre[day][h]);
    if (opt[day][h] == 0)
        cout << "HOLD";
    else if (opt[day][h] > 0)
    {
        cout << "BUY ";
        cout << name[opt[day][h] - 1];
    }
    else if (opt[day][h] < 0)
    {
        cout << "SELL ";
        cout << name[-opt[day][h] - 1];
    }
    cout << endl;
}
int main()
{
#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif // LOCAL
    int kase = 0;
    while (cin >> c >> m >> n >> kk)
    {
        if (kase++)
            cout << endl;
        for (int i = 0; i < n; i++)
        {
            //name[i]本来指的就是一个数组的起始地址
            scanf("%s%d%d", name[i], &s[i], &k[i]);
            for (int j = 0; j < m; j++)
            {
                cin >> price[i][j];
                price[i][j] *= s[i];
            }
        }
        init();
        double ans = dp();
        printf("%.2lf\n", ans);
        print_ans(m, 0);
    }
    return 0;
}
复制代码

转载于:https://juejin.im/post/5a97cd16518825557e77c92e

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值