题目连接:
题目大意:
给出一个表达式,在当中添加一个括号,求得最大表达式的解
题目分析:
左括号一定在*号的后面或者表达式的开头,右括号一定在表达式的末尾或者*的前面
可以证明3*6+5*7一定小于3*(6+5)*7(通过乘法分配率可以证明)
加在加号两端,如果碰到乘法,不会影响运算顺序,碰到加法,因为加法交换律和结合律又不会影响最后结果
所以只需要枚举括号的位置,然后计算结果,算取最大的结果
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <stack>
#include <vector>
#define MAX 6007
using namespace std;
typedef long long LL;
char s[MAX];
char exp[MAX];
vector<int> pos;
stack<char> op;
stack<LL> num;
LL ans;
/*LL solve ( char s[] )
{
int m = strlen(s);
stack<char> op;
stack<LL> num;
for ( int i = 0 ; i < m ; i++ )
{
if ( s[i] == '*' )
{
if ( s[i+1] == '(' ) op.push('*');
else
{
int a = num.top();
num.pop();
int b = s[i+1]-48;
num.push ( a*b );
i++;
continue;
}
}
if ( s[i] == '(' )
op.push(s[i]);
if ( s[i] == '+' )
{
while ( !op.empty() && op.top() == '*' )
{
LL num1 = num.top();
num.pop();
LL num2 = num.top();
num.pop();
num.push ( num1*num2 );
op.pop();
}
op.push(s[i]);
}
if ( s[i] == ')' )
{
while ( op.top() != '(' )
{
LL num1 = num.top();
num.pop();
LL num2 = num.top();
num.pop();
if ( op.top() == '*' )
num.push(num1*num2);
else num.push(num1+num2);
op.pop();
}
}
if ( isdigit(s[i]) )
num.push(s[i]-48);
}
char temp;
while ( !op.empty() )
{
temp = op.top();
op.pop();
LL num1 = num.top();
num.pop();
LL num2 = num.top();
num.pop();
if ( temp == '*' )
num.push ( num1*num2 );
else num.push ( num1+num2 );
}
return num.top();
}*/
void calc ( )
{
LL num1 = num.top();
num.pop();
LL num2 = num.top();
num.pop();
if ( op.top() == '*' )
num.push( num1*num2);
else num.push( num1+num2);
op.pop();
}
void Clear ( )
{
while (!op.empty()) op.pop();
while (!num.empty()) num.pop();
}
LL solve ( char s[] )
{
int m = strlen(s);
Clear();
for ( int i = 0 ; i < m ; i++ )
{
if ( isdigit(s[i]) )
num.push ( s[i] - 48 );
else if ( s[i] == '(' )
op.push(s[i]);
else if ( s[i] == ')' )
{
while ( op.top() != '(' )
calc();
op.pop();
}
else
{
if ( s[i] == '+' )
while ( !op.empty() && op.top() == '*' )
calc();
op.push ( s[i] );
}
}
while ( !op.empty())calc();
return num.top();
}
int main ( )
{
while ( ~scanf ( "%s" , s ) )
{
ans = solve(s);
//cout << ans << endl;
pos.clear();
int len = strlen(s);
pos.push_back(-1);
for ( int i = 0 ; i < len ; i++ )
if ( s[i] == '*' )
pos.push_back(i);
pos.push_back(len);
int n = pos.size();
for ( int i = 0 ; i < n-1;i++)
for ( int j = i+1 ; j < n ; j++ )
{
int index = 0;
for ( int k = 0 ; k <= pos[i] ; k++ )
exp[index++] = s[k];
exp[index++] = '(';
for ( int k = pos[i]+1 ; k < pos[j]; k++ )
exp[index++] = s[k];
exp[index++] = ')';
for ( int k = pos[j] ; k <=len ; k++ )
exp[index++] = s[k];
ans = max ( ans , solve(exp) );
}
printf ( "%I64d\n" , ans );
}
}