- 作为一个算法小白,本人第一次接触大模拟的题,本题的算法参考自:【CSP】202309-3 梯度求解
解题思路
1.输入处理
-
getchar();
:从标准输入读取一个字符。这里它的作用可能是用来“吃掉”(消耗)前一个输入后留下的换行符。确保getline
能正确读取到下一行文本。 -
getline(cin, op);
:从标准输入流cin
中读取一行文本,直到遇到换行符(用户按下回车键),然后将读取到的文本(不包括换行符)存储到之前声明的字符串变量中。
2.逐个处理表达式中的元素
(1)变量 x i x_i xi的表示
struct elem {
int index; // 变量索引,即x的下标
long long value; // 变量值
long long derivative; // 对应变量的导数值
};
- 注意,这里因为最后求的是导函数的值,而无需记录导数的形式。例如,对于
f
(
x
)
=
x
2
,
x
=
1
f(x)=x^2,x=1
f(x)=x2,x=1,其导函数
f
′
(
x
)
=
2
x
f'(x)=2x
f′(x)=2x,这里我们直接记录
f
′
(
1
)
=
2
f'(1)=2
f′(1)=2,即
derivative=2
。
(2)将数字字符串转换为长整型
long long str2ll(string a) {
int sign = 1; // 判断是正数还是负数
long long ans = 0;
if (a[0] == '-')
sign = -1;
for (int i = 0; i < a.length(); i++) {
if (a[i] != '-')
ans = 10 * ans + (a[i] - '0');
}
return ans * sign;
}
(3)使用 stringstream
逐个处理 op
字符串中的元素
-
std::stringstream
是一个流类,可以像输入/输出流一样操作字符串。允许把字符串分割成多个部分,根据空白字符(如空格、制表符等)来拆分原始字符串。 -
循环的作用是逐个读取
op
字符串中的每个以空格分隔的子字符串,并在每次迭代中处理这些子字符串。这种处理方式对于解析和执行基于逆波兰表示法(RPN)的算术表达式非常有效,因为它允许程序按照操作的顺序(从左到右)逐步计算表达式的结果。
stringstream ss(op);
string s;
while (ss >> s) {}
(4)逆波兰式的处理逻辑
- 题目所给的字符串只涉及:
x,x的索引,运算符+-*,常数
。 - 整理来看可以分为两类:变量+运算符,这也就明确了处理的逻辑。
-
变量 x i x_i xi,存入
elem
中。if (s[0] == 'x') { elem a; a.index = str2ll(s.substr(1, s.length() - 1)); // 得到变量下标 a.derivative = xIndex == a.index ? 1 : 0; // 该变量是否要被求偏导(导数是 1,否则为 0) a.value = value[a.index]; // 变量在给定的值数组中的值 st.push(a); // 将包含变量信息的结构体 a 压入栈中,以便后续计算表达式的值和导数(和数字运算一样) }
-
运算符,由于求导运算本质上还是算数运算,并且是给定了变量值,本质上还是我们之前遇到过的算数运算的规则:遇到运算符,移出栈顶的两个操作数,进行对应的运算,
res
用于保存运算结果。elem op2 = st.top(); st.pop(); elem op1 = st.top(); st.pop(); elem res;
-
由于求的是导数,这里的
+-*
不再是普通意义上的+-*
,要符合导数运算的规则。switch (s[0]) { case '+': { res.value = ((op1.value + op2.value) % MOD + MOD) % MOD; res.derivative = ((op1.derivative + op2.derivative) % MOD + MOD) % MOD; break; } case '-': { res.value = ((op1.value - op2.value) % MOD + MOD) % MOD; res.derivative = ((op1.derivative - op2.derivative) % MOD + MOD) % MOD; break; } case '*': { res.value = ((op1.value * op2.value) % MOD + MOD) % MOD; res.derivative = ((op1.derivative * op2.value + op1.value * op2.derivative) % MOD + MOD) % MOD; } } st.push(res);
-
常数,类似于非变量的 x i x_i xi。
else { elem a; a.value = str2ll(s); a.derivative = 0; st.push(a); }
-
- 最终,栈顶的结果即为运算结果。
3.完善代码
#include <iostream>
#include <vector>
#include <stack>
#include <sstream>
using namespace std;
// 定义一个结构体 elem,用于表示表达式中的元素
struct elem {
int index; // 变量索引
long long value; // 变量值
long long derivative; // 对应变量的导数
};
const long long MOD = 1000000007; // 模数
// 将字符串转换为长整型
long long str2ll(string a) {
int sign = 1; // 判断是正数还是负数
long long ans = 0;
if (a[0] == '-')
sign = -1;
for (int i = 0; i < a.length(); i++) {
if (a[i] != '-')
ans = 10 * ans + (a[i] - '0');
}
return ans * sign;
}
int main() {
int n, m;
cin >> n >> m; // 输入变量个数和表达式数量
string op;
getchar();
getline(cin, op); // 获取表达式字符串
vector<elem> expr;
stack<elem> st;
for (int i = 0; i < m; i++) {
int xIndex;
vector<long long> value(n + 1);
cin >> xIndex; // 输入变量xi,其余均视为常量
for (int j = 1; j <= n; j++)
cin >> value[j]; // 输入每个变量的值
stringstream ss(op);
string s;
while (ss >> s) {
// 判断是否是变量
if (s[0] == 'x') {
elem a;
a.index = str2ll(s.substr(1, s.length() - 1)); // 得到变量的索引
a.derivative = xIndex == a.index ? 1 : 0; // 变量对目标变量的导数是 1,否则为 0
a.value = value[a.index]; // 变量在给定的值数组中的值
st.push(a); // 将包含变量信息的结构体 a 压入栈中,以便后续计算表达式的值和导数
}
// 检查当前读取到的字符串 s 是否只有一个字符且为加号、减号或乘号。如果是这三个运算符之一,就执行相应的运算逻辑
else if (s.length() == 1 && (s[0] == '+' || s[0] == '-' || s[0] == '*')) {
// 处理运算符的逻辑
elem op2 = st.top();
st.pop();
elem op1 = st.top();
st.pop();
elem res;
switch (s[0]) {
case '+': {
res.value = ((op1.value + op2.value) % MOD + MOD) % MOD;
res.derivative = ((op1.derivative + op2.derivative) % MOD + MOD) % MOD;
break;
}
case '-': {
res.value = ((op1.value - op2.value) % MOD + MOD) % MOD;
res.derivative = ((op1.derivative - op2.derivative) % MOD + MOD) % MOD;
break;
}
case '*': {
res.value = ((op1.value * op2.value) % MOD + MOD) % MOD;
res.derivative = ((op1.derivative * op2.value + op1.value * op2.derivative) % MOD + MOD) % MOD;
}
}
st.push(res);
}
else {
elem a;
a.value = str2ll(s);
a.derivative = 0;
st.push(a);
}
}
long long ans = st.top().derivative;
cout << ((ans % MOD) + MOD) % MOD << endl; // 输出结果取模
}
return 0;
}