题目链接:https://vjudge.net/problem/HDU-4507
题意:找出一个区间内与7无关的数
解题思路:
先说一下我自己的错误写法:
通过记录所有的数位和、对7取余后的结果以及当前的数值,当最后的情况是sum%7=0并且res%7=0时,将当前数值he的平方返回
但是这样做的时候记忆化搜索的时候就会出现错误,因为不同前缀可以通过取余得到相同的sum和res,但是最后在对he进行平方相加的时候前缀不同就会影响到结果的不同,所以这种方法不可行。
ll dfs(int pos, int limit, int sum, int res, ll he) {
if (!limit && dp[pos][sum][res] != -1) {
return dp[pos][sum][res];
}
if (pos == 0) {
if (sum == 0)
return 0;
if (res == 0)
return 0;
return he % mod * he % mod;
}
int up = limit ? num[pos] : 9;
ll ans = 0;
for (int i = 0; i <= up; i++) {
if (i == 7)
continue;
ans += dfs(pos - 1, limit && i == up, (sum + i)%7, (res * 10 + i) % 7,(he % mod * 10 % mod + i) % mod) % mod;
ans %= mod;
}
if (!limit)
return dp[pos][sum][res] = ans % mod;
else
return ans % mod;
}
正解:
本题主要难点在于求的是平方和相加,我们现在假设一个数的前缀为x 后缀为y,整体的平方和就是(x+y)^ 2=x^ 2+ 2xy+y^ 2,那么前面那种错误用法就可以利用起来,对于同一种sum和res情况下,它们的前缀可能不同,但是后缀一定相同,所以可以从后往前来逆推计算平方和,即先求出y和y^2的值,然后利用上面那个公式来计算当前位置处的平方和
每次我们采用整体处理的方法,相同前缀x对应cnt个后缀yi,所以我们整体求出yi的和,以及yi^ 2的和就可以转化成cnt*(x^ 2)+2*x * sum(yi)+ sum(yi^ 2),所以使用dp[pos][sum][res]来保存一个结构t体,进行记忆化搜索。
struct node {
ll cnt; //后缀满足条件的数量
ll sum; //记录当前和
ll qsum; //平方和
node() {
this->cnt = 0;
this->sum = 0;
this->qsum = 0;
}
node(int c, ll s, ll q) {
this->cnt = c;
this->sum = s;
this->qsum = q;
}
};
node dfs(int pos, int limit, int sum, int res) {
if (!limit && dp[pos][sum][res].cnt) {
return dp[pos][sum][res];
}
if (pos == 0) {
if (sum == 0)
return node(0,0,0);
if (res == 0)
return node(0,0,0);
return node(1,0,0);
}
int up = limit ? num[pos] : 9;
node ans = node(0, 0, 0);
for (int i = 0; i <= up; i++) {
if (i == 7)
continue;
//注意对于当前的前缀i要*10的pos次方
ll tmpi = i * ten[pos - 1] % mod;
node tmp = dfs(pos - 1, limit && i == up, (sum + i) % 7, (res * 10 + i) % 7);
//个数处理,直接相加就可以
ans.cnt += tmp.cnt;
ans.cnt %= mod;
//平方和处理,使用上述公式
ans.qsum += ((tmp.cnt * tmpi %mod * tmpi % mod + 2 * tmp.sum % mod * tmpi % mod) % mod + tmp.qsum) % mod;
ans.qsum %= mod;
//当前满足条件的个数的和
ans.sum += (tmp.sum + tmp.cnt * tmpi % mod) % mod;
ans.sum %= mod;
}
if (!limit)
return dp[pos][sum][res] = ans;
else
return ans;
}
AC代码:
#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
#include<cstdio>
#include<string>
#include<math.h>
#include<string.h>
using namespace std;
#define ll long long
const int mod = 1e9 + 7;
struct node {
ll cnt;
ll sum; //记录当前和
ll qsum; //平方和
node() {
this->cnt = 0;
this->sum = 0;
this->qsum = 0;
}
node(int c, ll s, ll q) {
this->cnt = c;
this->sum = s;
this->qsum = q;
}
};
int t;
ll l, r;
int num[20];
ll ten[20];
node dp[20][20][20];
node dfs(int pos, int limit, int sum, int res) {
if (!limit && dp[pos][sum][res].cnt) {
return dp[pos][sum][res];
}
if (pos == 0) {
if (sum == 0)
return node(0,0,0);
if (res == 0)
return node(0,0,0);
return node(1,0,0);
}
int up = limit ? num[pos] : 9;
node ans = node(0, 0, 0);
for (int i = 0; i <= up; i++) {
if (i == 7)
continue;
ll tmpi = i * ten[pos - 1] % mod;
node tmp = dfs(pos - 1, limit && i == up, (sum + i) % 7, (res * 10 + i) % 7);
ans.cnt += tmp.cnt;
ans.cnt %= mod;
ans.qsum += ((tmp.cnt * tmpi %mod * tmpi % mod + 2 * tmp.sum % mod * tmpi % mod) % mod + tmp.qsum) % mod;
ans.qsum %= mod;
ans.sum += (tmp.sum + tmp.cnt * tmpi % mod) % mod;
ans.sum %= mod;
}
if (!limit)
return dp[pos][sum][res] = ans;
else
return ans;
}
ll solve(ll x) {
int cnt = 0;
while (x) {
num[++cnt] = x % 10;
x /= 10;
}
node res = dfs(cnt, 1, 0, 0);
return res.qsum;
}
int main() {
ten[0] = 1;
for (int i = 1; i <= 18; i++)
ten[i] = ten[i - 1] * 10;
cin >> t;
while (t--) {
cin >> l >> r;
cout << (solve(r) + mod - solve(l - 1)) % mod << endl;
}
return 0;
}