数位 DP

简介 #

如何统计区间 $[l,r]$ 中有多少整数符合某条件?

  1. 暴力算法,枚举 $[l,r]$ 中的每一个整数,逐个判断是否满足条件,此方法遇大数据必 $gg$.
  2. 优雅地使用数位 DP.

问题 #

统计区间 $[l,r]$($0≤l<r≤100$)中有多少整数符合「相邻两个数字之差 $≥2$」.

预处理 #

采用「试填法」:从个位填到最高位,如果第 $d$ 位填了 $i$,那么第 $d+1$ 位只能填 $[0,i-2]$ 或 $[i+2,9]$ 中的整数.

$f[i,d]$ 表示「所有最高位为 $i$ 的 $d$ 位数中,符合条件的个数」. 通过给定条件可推出:

$$ f[i,d]=\sum_{|k-i|≥2} f[k,d-1] $$

初始条件 $f[i,1]=1$ 计算顺序 $f[0→9,2→n]$
时间复杂度 $O(10^2\log{n})$
int f[][];

for(int i = 0; i <= 9; i ++) f[i][1] = 1; // 初始条件
for(int d = 2; d <= N; d ++) // N : 位数的上限,N ≈ log(r)
	for(int i = 0; i <= 9; i ++)
		for(int k = 0; k <= 9; k ++)
			if(abs(k - i) >= 2)
				f[i][d] += f[k][d - 1];

数位统计 #

考虑 前缀和 思想:

$dp(n)$ 表示 $[0,n]$ 中有多少个数满足条件. $[l,r]$ 中符合条件的个数 $=dp(r)-dp(l-1)$.

$dp(n)$ 的实现步骤:

step 1 #

提取 $n$ 每一位上的数字,存入数组 $at[ \ ]$:

int cap = 0, at[];
// cap : n 的位数;at[i] : n 的第 i 位数字
while(n) at[++ cap] = n % 10, n /= 10;

step 2 #

所有 $1\cdots cap-1$ 位数都被包含于 $[0,n]$ 区间中. 统计它们中符合条件的个数:

int ans = 0; // ans : 符合条件的个数
for(int d = 1; d < cap; d ++) // d : 位数
    for(int i = 1; i <= 9; i ++) // i : 最高位填的数
        ans += f[i][d];

step 3 #

统计所有 $cap$ 位数中符合条件的个数.

使用「试填法」,枚举 $d=cap→1$,从最高位填到最低位,并使填的数 $<n$:

  • 若 $d=cap$,该位不能填 $0$,只能填 $1\cdots at[d]-1$. 统计符合条件的情况;

  • 若 $d\not=cap$,该位只能填 $0\cdots at[d]-1$. 统计符合条件的情况;

    • 若此时 $|at[d+1]-at[d]|<2$,下一位无论怎么填都不符合条件,跳出循环;

    • 若上一步未跳出循环且 $d=1$,说明 $n$ 本身也符合条件. 但「试填法」最多只填到 $n-1$,故还要多算一个.

for(int d = cap; d >= 1; d --) { // d : 当前填到第 d 位
    for(int i = (d == cap); i < at[d]; i ++)
        if(abs(at[d + 1] - i) >= 2) ans += f[i][d];
    if(d != cap && abs(at[d + 1] - at[d]) < 2) break;
    if(d == 1) ans ++;
}

模板 #

int dp(int n) { // 求 [0, n] 中有几个数符合条件
    if(n <= 0) return !n; // 特判

    int cap = 0, ans = 0, at[];
    while(n) at[++ cap] = n % 10, n /= 10;

    for(int d = 1; d < cap; d ++)
        for(int i = 1; i <= 9; i ++)
            ans += f[i][d];
    
    for(int d = cap; d >= 1; d --) {
        for(int i = (d == cap); i < at[d]; i ++)
            if(abs(last - i) >= 2) ans += f[i][d]; // 条件按照题目的需要
        if(d != cap && abs(at[d + 1] - at[d]) < 2) break;
        if(d == 1) ans ++;
    }

    return ++ ans;
}