题目链接
题意
定义和7有关的数字是满足下列条件之一的数字:
- 整数中某一位是7;
- 整数的每一位加起来的和是7的整数倍;
- 这个整数是7的整数倍;
给一个区间
[L,R]
,求区间内所有和7无关的数字的平方和。对1e9+7
取模。
数据范围:
1≤L≤R≤1018
分析
这个题目比较难了。。。我觉得。
无论是求和7有关的还是和7无关的数字个数都是比较简单的,但是求平方和就比较麻烦了。
仍然满足区间减法。我们要对dfs
的返回结果(低位的情况已经解决)记录三个属性(用结构体保存):
cnt
: 和7无关的数字个数sum
: 和7无关的所有数字之和sqsum
: 和7无关的所有数字的平方和
记当前答案为
ret
,dfs
返回答案为:
nxt
,当前是
pos
位并且枚举的数字是
i
,pw10[p]
表示
ret.cnt+=nxt.cnt;
ret.sum+=nxt.sum+i*pw10[p]*nxt.cnt;
ret.sqsum
现在我们考虑 pos 位数: i∗10pos+x ,平方得: (i2∗102∗pos+x2+2∗10pos∗i∗x) ,但是 x 肯可能有很多个。所以得:
∑(i2∗102∗pos+x2+2∗10pos∗i∗x),其中x是所有低位符合条件的数
ret.sum+=∑(i2∗102∗pos∗nxt.cnt),x一共nxt.cnt个
ret.sum+=∑x2=nxt.sqsum
ret.sum+=∑2∗10pos∗i∗x=2∗10pos∗i∗∑x=2∗10pos∗i∗nxt.sum
注意别爆long long
。对于dfs
的函数参数我们需要存的是高位数字和和高位数字对7取模后的结果,这样子在dfs
的最后一层来判断是否和7有关。
Code
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <math.h>
using namespace std;
typedef long long ll;
const ll mod = (ll)(1e9 + 7);
int T, digit[20], vis[20][10][10];
ll L, R, pw10[20];
struct Node {
ll cnt, sum, sqsum;
Node () {}
Node (ll _cnt, ll _sum, ll _sqsum) : cnt(_cnt), sum(_sum), sqsum(_sqsum) { }
} dp[20][10][10];
Node dfs(int pos, int sum_rem, int num_rem, int limit)
{ //sum_ren:高位数字和对7取模余数 num_rem:高位数字对7取模余数
if (pos == -1) {
if (sum_rem == 0 || num_rem == 0) return Node(0, 0, 0);
else return Node(1, 0, 0);
}
if (!limit && vis[pos][sum_rem][num_rem]) return dp[pos][sum_rem][num_rem];
int last = limit ? digit[pos] : 9;
Node ret = Node(0, 0, 0);
for (int i = 0; i <= last; ++i) {
if (i == 7) continue;
Node nxt = dfs(pos - 1, (sum_rem + i) % 7, (num_rem * 10 + i) % 7, limit && (i == last));
ret.cnt = (ret.cnt + nxt.cnt) % mod;
ret.sum = ((ret.sum + pw10[pos] * i % mod * nxt.cnt % mod) % mod + nxt.sum) % mod;
ret.sqsum = (ret.sqsum + nxt.sqsum) % mod;
ret.sqsum = (ret.sqsum + pw10[pos] * pw10[pos] % mod * i * i % mod * nxt.cnt % mod) % mod;
ret.sqsum = (ret.sqsum + pw10[pos] * 2 * i % mod * nxt.sum % mod) % mod;
}
if (!limit) {
vis[pos][sum_rem][num_rem] = 1;
dp[pos][sum_rem][num_rem] = ret;
}
return ret;
}
ll solve(ll x)
{
memset(digit, 0, sizeof (digit));
int len = 0;
while (x) {
digit[len++] = x % 10;
x /= 10;
}
return dfs(len - 1, 0, 0, 1).sqsum;
}
int main()
{
pw10[0] = 1;
for (int i = 1; i <= 18; ++i) { pw10[i] = pw10[i - 1] * 10 % mod; }
memset(vis, 0, sizeof (vis));
scanf("%d", &T);
while (T--) {
scanf("%lld%lld", &L, &R);
ll ans = (solve(R) - solve(L - 1)) % mod;
printf("%lld\n", (ans + mod) % mod);
}
return 0;
}