hdu5045 Contest 状压dp/KM

链接:http://acm.hdu.edu.cn/showproblem.php?pid=5045

Contest

Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/65536 K (Java/Others)
Total Submission(s): 260    Accepted Submission(s): 113


Problem Description
In the ACM International Collegiate Programming Contest, each team consists of three students, and the teams are given 5 hours to solve between 8 and 12 programming problems.

On Mars, there is a programming contest, too. Each team consists of N students. The teams are given M hours to solve M programming problems. Each team can use only one computer, and the students cannot cooperate to solve a problem. At the beginning of the ith hour, the team gets the ith programming problem. They must choose one student to solve this problem while the others go out to rest. The chosen student spends one hour programming this problem. At the end of this hour, he must submit his program. The program is then run on the test data and cannot be modified any more.

Now, you have to help a team to find a strategy to maximize the expected number of correctly solved problems.

For each problem, each student has a certain probability of solving it correctly: if the ith student solves the jth problem, the probability of a correct solution is Pij.

At any time, the difference between the programming times of any two students is not more than 1 hour. For example, if there are 3 students and 5 problems, the strategies {1,2,3,1,2}, {1,3,2,2,3} and {2,1,3,3,1} are all legal, but {1,1,3,2,3}, {3,1,3,1,2} and {1,2,3,1,1} are all illegal.

You should find a strategy that maximizes the expected number of correctly solved problems, given all the probabilities.
 

Input
The first line of the input is T (1 ≤ T ≤ 20), which stands for the number of test cases you need to solve.

The first line of each case contains two integers N, M (1 ≤ N ≤ 10, 1 ≤ M ≤ 1000), denoting the number of students and programming problems, respectively.

The next N lines each contain M real numbers between 0 and 1; the jth number in the ith line is Pij.
 

Output
For each test case, print a line starting with “Case #t: ” (without quotes, where t is the index of the test case), followed by a single real number: the maximal expected number of correctly solved problems if the team follows the best strategy, printed with five digits after the decimal point. See the sample output for details.
 

Sample Input
1
2 3
0.6 0.3 0.4
0.3 0.7 0.9
 

Sample Output
Case #1: 2.20000
 

Source


题意:给定n个人和m道题,接下来的n*m矩阵给出每个人解出每道题的概率。一种分配方案可以写成一个长为m的序列(例如{1,2,3,1,2}),第j个数表示第j道题由谁来做。

每种合法方案都对应一个期望解题数,求所有合法方案中期望解题数的最大值。

限制:任意时刻任意两人的做题数之差不超过1。因此第[1,n]题必须恰好分给1~n的一个排列,第[n+1,2n]题也必须是一个排列,依此类推;最后不足n题的那一轮中每人至多做一题。所以像{1,1,2,3,1}这样的序列是不合法的。

题解:

状压dp:dp[i][j]表示处理到第i道题、状态为j时的最大期望值。j的第k位表示第k个人的做题数与当前最小做题数之差(只可能是0或1)。只有对应二进制位为0的人才能做下一道题;当所有位都变成1时,说明这一轮n个人都做过了,状态重置为0。

代码:

/**
 * @author neko01
 */
//#pragma comment(linker, "/STACK:102400000,102400000")
#include <cstdio>
#include <cstring>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <queue>
#include <vector>
#include <cmath>
#include <set>
#include <map>
using namespace std;
typedef long long LL;
// Generic contest-template helper macros; most are unused in this solution.
#define min3(a,b,c) min(a,min(b,c))
#define max3(a,b,c) max(a,max(b,c))
#define pb push_back
#define mp(a,b) make_pair(a,b)
// Zero-fill / 0xFF-fill a whole array with memset.
#define clr(a) memset(a,0,sizeof a)
#define clr1(a) memset(a,-1,sizeof a)
#define dbg(a) printf("%d\n",a)
typedef pair<int,int> pp;
const double eps=1e-9;
const double pi=acos(-1.0);
// N: maximum number of students (the statement guarantees n <= 10).
const int N=10;
// M: 1025 is >= 2^10 bitmasks and >= 1000 problems, so it bounds both dp dimensions.
const int M=1025;
// a[i][j]: probability that student i solves problem j correctly.
double a[N][M];
// dp[problem index][student bitmask]: best expected number of solved problems in that state.
double dp[M][M];
/*
 * Bitmask DP over problems.
 *
 * dp[i][mask] = best expected number of correctly solved problems after the
 * first i problems have been assigned. Bit k of mask is set iff student k has
 * already solved one problem in the current round of n problems. A student may
 * take the next problem only while his bit is clear. When every bit would
 * become set, the round is complete and the mask wraps back to 0.
 *
 * Unreachable states hold -1. All real values are >= 0, so a reachable state
 * whose value is exactly 0 is still expanded. The previous version used 0 as
 * "unreachable", which dropped such states. It also started at mask 1<<i, a
 * mask that is never wrapped when n == 1, so no transition could fire.
 */
int main()
{
    int t,cnt=0;
    scanf("%d",&t);
    while(t--)
    {
        int n,m;
        scanf("%d%d",&n,&m);
        for(int i=0;i<n;i++)
            for(int j=0;j<m;j++)
                scanf("%lf",&a[i][j]);
        const int full=(1<<n)-1;
        // Mark every state of rows 0..m unreachable; only the empty prefix is reachable.
        for(int i=0;i<=m;i++)
            for(int j=0;j<=full;j++)
                dp[i][j]=-1;
        dp[0][0]=0;
        for(int i=0;i<m;i++)
        {
            for(int j=0;j<=full;j++)
            {
                if(dp[i][j]<0)
                    continue;
                for(int k=0;k<n;k++)
                {
                    // Student k already worked in this round.
                    if(j&(1<<k))
                        continue;
                    int tmp=j|(1<<k);
                    // Round finished: everybody is even again.
                    if(tmp==full) tmp=0;
                    dp[i+1][tmp]=max(dp[i+1][tmp],dp[i][j]+a[k][i]);
                }
            }
        }
        double ans=0;
        for(int j=0;j<=full;j++)
            ans=max(ans,dp[m][j]);
        printf("Case #%d: %.5lf\n",++cnt,ans);
    }
    return 0;
}

还可以用费用流或KM来做:每n道题为一轮,每轮内是人与题的一一分配(二分图最大权匹配),各轮互相独立,共做⌈m/n⌉次KM即可(最后不足n题的一轮用概率为0的虚拟题补齐)。

/**
 * @author neko01
 */
//#pragma comment(linker, "/STACK:102400000,102400000")
#include <cstdio>
#include <cstring>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <queue>
#include <vector>
#include <cmath>
#include <set>
#include <map>
using namespace std;
typedef long long LL;
// Generic contest-template helper macros; most are unused in this solution.
#define min3(a,b,c) min(a,min(b,c))
#define max3(a,b,c) max(a,max(b,c))
#define pb push_back
#define mp(a,b) make_pair(a,b)
// Zero-fill / 0xFF-fill a whole array with memset.
#define clr(a) memset(a,0,sizeof a)
#define clr1(a) memset(a,-1,sizeof a)
#define dbg(a) printf("%d\n",a)
typedef pair<int,int> pp;
// Tolerance used to decide whether an edge is tight (label sum equals weight).
const double eps=1e-9;
const double pi=acos(-1.0);
// Large sentinel used for initial labels and slack values.
const int INF=0x3f3f3f3f;
// N: bound on the matching size (students / slots per round; n <= 10).
const int N=15;
// M: row width of a; problems are indexed 0..m-1 with m <= 1000.
const int M=1005;
// a[i][j]: probability that student i solves problem j correctly.
double a[N][M];
// g[x][y]: 1-based weight of giving slot y of the current round to student x.
double g[N][N];
// lx/ly: KM vertex labels for the left (x) and right (y) sides.
double lx[N],ly[N];
// match[y]: left vertex currently matched to right vertex y, or -1 if none.
int match[N];
// visx/visy: vertices visited by the current augmenting-path search.
bool visx[N],visy[N];
// slack[y]: smallest non-tight gap lx[x]+ly[y]-g[x][y] seen for right vertex y.
double slack[N];
// n1/n2: number of left/right vertices in the current KM instance.
int n1,n2;
/*
 * Try to find an augmenting path from left vertex x using only tight edges
 * (|lx[x]+ly[y]-g[x][y]| < eps). On success, flip the path into match[] and
 * return true. Every non-tight edge inspected lowers slack[y] to its gap.
 */
bool dfs(int x)
{
    visx[x]=true;
    for(int y=1;y<=n2;++y)
    {
        if(visy[y])
            continue;
        const double gap=lx[x]+ly[y]-g[x][y];
        if(!(fabs(gap)<eps))
        {
            // Not tight: remember how far this edge is from becoming usable.
            if(gap<slack[y])
                slack[y]=gap;
            continue;
        }
        visy[y]=true;
        const bool freeOrReroutable=(match[y]==-1)||dfs(match[y]);
        if(freeOrReroutable)
        {
            match[y]=x;
            return true;
        }
    }
    return false;
}
/*
 * Kuhn-Munkres (Hungarian) algorithm: maximum-weight matching on the 1-based
 * n1 x n2 weight matrix g, saturating every left vertex. This assumes
 * n1 <= n2; main always uses n1 == n2. Left labels are lx and right labels
 * are ly. An edge (x,y) is tight when lx[x]+ly[y] == g[x][y] within eps.
 * Returns the total weight of the matching left in match[].
 */
double KM()
{
    int i,j;
    memset(ly,0,sizeof(ly));
    memset(match,-1,sizeof(match));
    // Initial feasible labelling: lx[x] = heaviest edge leaving x, ly = 0.
    for (i=1;i<=n1;i++)
        for (j=1,lx[i]=-INF;j<=n2;j++)
            if(g[i][j]>lx[i])
                lx[i]=g[i][j];
    for(int x=1;x<=n1;x++)
    {
        for(i=1;i<=n2;i++)
            slack[i]=INF;
        while(true)
        {
            memset(visx,false,sizeof(visx));
            memset(visy,false,sizeof(visy));
            if(dfs(x))
                break;
            // No augmenting path yet: shift labels by the smallest slack of an
            // unvisited RIGHT vertex, so at least one new tight edge appears.
            // slack/visy are indexed by right vertices, so scan 1..n2 (not n1).
            double t=INF;
            for (i=1;i<=n2;i++)
                if (!visy[i]&&t>slack[i])
                    t=slack[i];
            for (i=1;i<=n1;i++)
                if (visx[i])
                    lx[i]-=t;
            for (i=1;i<=n2;i++)
                if(visy[i])
                    ly[i]+=t;
                else
                    slack[i]-=t;
        }
    }
    double ans=0;
    for (i=1;i<=n2;i++)
        if(match[i]!=-1)
            ans+=g[match[i]][i];
    return ans;
}
/*
 * Problems are handed out in rounds of n. Within a round every student solves
 * exactly one problem; the last, possibly shorter, round lets each student
 * solve at most one. Rounds do not interact, so the answer is the sum over
 * rounds of a maximum-weight student/problem assignment, computed with KM.
 * In a short final round the missing problems are padded with weight 0.
 */
int main()
{
    int t,cnt=0;
    scanf("%d",&t);
    while(t--)
    {
        int n,m;
        scanf("%d%d",&n,&m);
        clr(a);
        for(int i=0;i<n;i++)
            for(int j=0;j<m;j++)
                scanf("%lf",&a[i][j]);
        n1=n2=n;
        double ans=0;
        for(int i=0;i<m;i+=n)
        {
            for(int j=0;j<n;j++)
            {
                // Slots k >= m are padding and weigh 0. Never read a[j][k] for
                // such k: e.g. n=9, m=1000 gives k up to 1007 >= M, which
                // would read real probabilities from row j+1.
                for(int k=i;k<i+n;k++)
                    g[j+1][k-i+1]=(k<m)?a[j][k]:0.0;
            }
            ans+=KM();
        }
        printf("Case #%d: %.5lf\n",++cnt,ans);
    }
    return 0;
}


没有更多推荐了,返回首页