title: 动态规划之KMP字符匹配算法
author: 远方
tags:

  • LeetCode
  • 算法
    categories:
  • LeetCode破局攻略
    date: 2016-01-01 19:20:00

动态规划之KMP字符匹配算法

KMP 算法(Knuth-Morris-Pratt 算法)是一个著名的字符串匹配算法,效率很高,但是确实有点复杂。
很多读者抱怨 KMP 算法无法理解,这很正常,想到大学教材上关于 KMP 算法的讲解,也不知道有多少未来的 Knuth、Morris、Pratt 被提前劝退了。有一些优秀的同学通过手推 KMP 算法的过程来辅助理解该算法,这是一种办法,不过本文要从逻辑层面帮助读者理解算法的原理。十行代码之间,KMP 灰飞烟灭。
先在开头约定,本文用 pat 表示模式串,长度为 Mtxt 表示文本串,长度为 N。KMP 算法是在 txt 中查找子串 pat,如果存在,返回这个子串的起始索引,否则返回 -1
为什么我认为 KMP 算法就是个动态规划问题呢,等会再解释。对于动态规划,之前多次强调了要明确 dp 数组的含义,而且同一个问题可能有不止一种定义 dp 数组含义的方法,不同的定义会有不同的解法。
读者见过的 KMP 算法应该是,一波诡异的操作处理 pat 后形成一个一维的数组 next,然后根据这个数组经过又一波复杂操作去匹配 txt。时间复杂度 O(N),空间复杂度 O(M)。其实它这个 next 数组就相当于 dp 数组,其中元素的含义跟 pat 的前缀和后缀有关,判定规则比较复杂,不好理解。本文则用一个二维的 dp 数组(但空间复杂度还是 O(M)),重新定义其中元素的含义,使得代码长度大大减少,可解释性大大提高
PS:本文的代码参考《算法4》,原代码使用的数组名称是 dfa(确定有限状态机),因为我们的公众号之前有一系列动态规划的文章,就不说这么高大上的名词了,我对书中代码进行了一点修改,并沿用 dp 数组的名称。

一、KMP 算法概述

首先还是简单介绍一下 KMP 算法和暴力匹配算法的不同在哪里,难点在哪里,和动态规划有啥关系。
暴力的字符串匹配算法很容易写,看一下它的运行逻辑:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// 暴力匹配(伪码)
int search(String pat, String txt) {
int M = pat.length;
int N = txt.length;
for (int i = 0; i <= N - M; i++) {
int j;
for (j = 0; j < M; j++) {
if (pat[j] != txt[i+j])
break;
}
// pat 全都匹配了
if (j == M) return i;
}
// txt 中不存在 pat 子串
return -1;
}

对于暴力算法,如果出现不匹配字符,同时回退 txtpat 的指针,嵌套 for 循环,时间复杂度 $O(MN)$,空间复杂度$O(1)$。最主要的问题是,如果字符串中重复的字符比较多,该算法就显得很蠢。
比如 txt = “aaacaaab” pat = “aaab”:
brutal
很明显,pat 中根本没有字符 c,根本没必要回退指针 i,暴力解法明显多做了很多不必要的操作。
KMP 算法的不同之处在于,它会花费空间来记录一些信息,在上述情况中就会显得很聪明:
kmp1
再比如类似的 txt = “aaaaaaab” pat = “aaab”,暴力解法还会和上面那个例子一样蠢蠢地回退指针 i,而 KMP 算法又会耍聪明:
kmp2
因为 KMP 算法知道字符 b 之前的字符 a 都是匹配的,所以每次只需要比较字符 b 是否被匹配就行了。
KMP 算法永不回退 txt 的指针 i,不走回头路(不会重复扫描 txt),而是借助 dp 数组中储存的信息把 pat 移到正确的位置继续匹配,时间复杂度只需 O(N),用空间换时间,所以我认为它是一种动态规划算法。
KMP 算法的难点在于,如何计算 dp 数组中的信息?如何根据这些信息正确地移动 pat 的指针?这个就需要确定有限状态自动机来辅助了,别怕这种高大上的文学词汇,其实和动态规划的 dp 数组如出一辙,等你学会了也可以拿这个词去吓唬别人。
还有一点需要明确的是:计算这个 dp 数组,只和 pat 串有关。意思是说,只要给我个 pat,我就能通过这个模式串计算出 dp 数组,然后你可以给我不同的 txt,我都不怕,利用这个 dp 数组我都能在 O(N) 时间完成字符串匹配。
具体来说,比如上文举的两个例子:

1
2
3
4
txt1 = "aaacaaab" 
pat = "aaab"
txt2 = "aaaaaaab"
pat = "aaab"

我们的 txt 不同,但是 pat 是一样的,所以 KMP 算法使用的 dp 数组是同一个。
只不过对于 txt1 的下面这个即将出现的未匹配情况:

dp 数组指示 pat 这样移动:

PS:这个j 不要理解为索引,它的含义更准确地说应该是状态(state),所以它会出现这个奇怪的位置,后文会详述。
而对于 txt2 的下面这个即将出现的未匹配情况:

dp 数组指示 pat 这样移动:

明白了 dp 数组只和 pat 有关,那么我们这样设计 KMP 算法就会比较漂亮:

1
2
3
4
5
6
7
8
9
10
11
12
13
public class KMP {
private int[][] dp;
private String pat;
public KMP(String pat) {
this.pat = pat;
// 通过 pat 构建 dp 数组
// 需要 O(M) 时间
}
public int search(String txt) {
// 借助 dp 数组去匹配 txt
// 需要 O(N) 时间
}
}

这样,当我们需要用同一 pat 去匹配不同 txt 时,就不需要浪费时间构造 dp 数组了:

1
2
3
KMP kmp = new KMP("aaab");
int pos1 = kmp.search("aaacaaab"); //4
int pos2 = kmp.search("aaaaaaab"); //4

二、状态机概述

为什么说 KMP 算法和状态机有关呢?是这样的,我们可以认为 pat 的匹配就是状态的转移。比如当 pat = “ABABC”:

如上图,圆圈内的数字就是状态,状态 0 是起始状态,状态 5(pat.length)是终止状态。开始匹配时 pat 处于起始状态,一旦转移到终止状态,就说明在 txt 中找到了 pat。比如说当前处于状态 2,就说明字符 “AB” 被匹配:

另外,处于不同状态时,pat 状态转移的行为也不同。比如说假设现在匹配到了状态 4,如果遇到字符 A 就应该转移到状态 3,遇到字符 C 就应该转移到状态 5,如果遇到字符 B 就应该转移到状态 0:

具体什么意思呢,我们来一个个举例看看。用变量 j 表示指向当前状态的指针,当前 pat 匹配到了状态 4:

如果遇到了字符 “A”,根据箭头指示,转移到状态 3 是最聪明的:

如果遇到了字符 “B”,根据箭头指示,只能转移到状态 0(一夜回到解放前):

如果遇到了字符 “C”,根据箭头指示,应该转移到终止状态 5,这也就意味着匹配完成:

当然了,还可能遇到其他字符,比如 Z,但是显然应该转移到起始状态 0,因为 pat 中根本都没有字符 Z:

这里为了清晰起见,我们画状态图时就把其他字符转移到状态 0 的箭头省略,只画 pat 中出现的字符的状态转移:

KMP 算法最关键的步骤就是构造这个状态转移图。要确定状态转移的行为,得明确两个变量,一个是当前的匹配状态,另一个是遇到的字符;确定了这两个变量后,就可以知道这个情况下应该转移到哪个状态。
下面看一下 KMP 算法根据这幅状态转移图匹配字符串 txt 的过程:

请记住这个 GIF 的匹配过程,这就是 KMP 算法的核心逻辑
为了描述状态转移图,我们定义一个二维 dp 数组,它的含义如下:

1
2
3
4
5
6
7
8
9
10
dp[j][c] = next
0 <= j < M,代表当前的状态
0 <= c < 256,代表遇到的字符(ASCII 码)
0 <= next <= M,代表下一个状态
dp[4]['A'] = 3 表示:
当前是状态 4,如果遇到字符 A,
pat 应该转移到状态 3
dp[1]['B'] = 2 表示:
当前是状态 1,如果遇到字符 B,
pat 应该转移到状态 2

根据我们这个 dp 数组的定义和刚才状态转移的过程,我们可以先写出 KMP 算法的 search 函数代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public int search(String txt) {
int M = pat.length();
int N = txt.length();
// pat 的初始态为 0
int j = 0;
for (int i = 0; i < N; i++) {
// 当前是状态 j,遇到字符 txt[i],
// pat 应该转移到哪个状态?
j = dp[j][txt.charAt(i)];
// 如果达到终止态,返回匹配开头的索引
if (j == M) return i - M + 1;
}
// 没到达终止态,匹配失败
return -1;
}

到这里,应该还是很好理解的吧,dp 数组就是我们刚才画的那幅状态转移图,如果不清楚的话回去看下 GIF 的算法演进过程。下面讲解:如何通过 pat 构建这个 dp 数组?

三、构建状态转移图

回想刚才说的:要确定状态转移的行为,必须明确两个变量,一个是当前的匹配状态,另一个是遇到的字符,而且我们已经根据这个逻辑确定了 dp 数组的含义,那么构造 dp 数组的框架就是这样:

1
2
3
for 0 <= j < M: # 状态
for 0 <= c < 256: # 字符
dp[j][c] = next

这个 next 状态应该怎么求呢?显然,如果遇到的字符 cpat[j] 匹配的话,状态就应该向前推进一个,也就是说 next = j + 1,我们不妨称这种情况为状态推进

如果字符 cpat[j] 不匹配的话,状态就要回退(或者原地不动),我们不妨称这种情况为状态重启

那么,如何得知在哪个状态重启呢?解答这个问题之前,我们再定义一个名字:影子状态(我编的名字),用变量 X 表示。所谓影子状态,就是和当前状态具有相同的前缀。比如下面这种情况:

当前状态 j = 4,其影子状态为 X = 2,它们都有相同的前缀 “AB”。因为状态 X 和状态 j 存在相同的前缀,所以当状态 j 准备进行状态重启的时候(遇到的字符 cpat[j] 不匹配),可以通过 X 的状态转移图来获得最近的重启位置
比如说刚才的情况,如果状态 j 遇到一个字符 “A”,应该转移到哪里呢?首先只有遇到 “C” 才能推进状态,遇到 “A” 显然只能进行状态重启。状态 j 会把这个字符委托给状态 X 处理,也就是 dp[j]['A'] = dp[X]['A']**:

为什么这样可以呢?因为:既然 j 这边已经确定字符 “A” 无法推进状态,
只能回退,而且 KMP 就是要尽可能少的回退**,以免多余的计算。那么 j 就可以去问问和自己具有相同前缀的 X,如果 X 遇见 “A” 可以进行「状态推进」,那就转移过去,因为这样回退最少。

当然,如果遇到的字符是 “B”,状态 X 也不能进行「状态推进」,只能回退,j 只要跟着 X 指引的方向回退就行了:

你也许会问,这个 X 怎么知道遇到字符 “B” 要回退到状态 0 呢?因为 X 永远跟在 j 的身后,状态 X 如何转移,在之前就已经算出来了。动态规划算法不就是利用过去的结果解决现在的问题吗?
这样,我们就细化一下刚才的框架代码:

1
2
3
4
5
6
7
8
9
10
int X # 影子状态
for 0 <= j < M:
for 0 <= c < 256:
if c == pat[j]:
# 状态推进
dp[j][c] = j + 1
else:
# 状态重启
# 委托 X 计算重启位置
dp[j][c] = dp[X][c]

四、代码实现

如果之前的内容你都能理解,恭喜你,现在就剩下一个问题:影子状态 X 是如何得到的呢?下面先直接看完整代码吧。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
public class KMP {
private int[][] dp;
private String pat;
public KMP(String pat) {
this.pat = pat;
int M = pat.length();
// dp[状态][字符] = 下个状态
dp = new int[M][256];
// base case
dp[0][pat.charAt(0)] = 1;
// 影子状态 X 初始为 0
int X = 0;
// 当前状态 j 从 1 开始
for (int j = 1; j < M; j++) {
for (int c = 0; c < 256; c++) {
if (pat.charAt(j) == c)
dp[j][c] = j + 1;
else
dp[j][c] = dp[X][c];
}
// 更新影子状态
X = dp[X][pat.charAt(j)];
}
}
public int search(String txt) {...}
}

先解释一下这一行代码:

1
2
// base case
dp[0][pat.charAt(0)] = 1;

这行代码是 base case,只有遇到 pat[0] 这个字符才能使状态从 0 转移到 1,遇到其它字符的话还是停留在状态 0(Java 默认初始化数组全为 0)。
影子状态 X 是先初始化为 0,然后随着 j 的前进而不断更新的。下面看看到底应该**如何更新影子状态 X**:

1
2
3
4
5
6
7
8
int X = 0;
for (int j = 1; j < M; j++) {
...
// 更新影子状态
// 当前是状态 X,遇到字符 pat[j],
// pat 应该转移到哪个状态?
X = dp[X][pat.charAt(j)];
}

更新 X 其实和 search 函数中更新状态 j 的过程是非常相似的:

1
2
3
4
5
6
7
int j = 0;
for (int i = 0; i < N; i++) {
// 当前是状态 j,遇到字符 txt[i],
// pat 应该转移到哪个状态?
j = dp[j][txt.charAt(i)];
...
}

其中的原理非常微妙,注意代码中 for 循环的变量初始值,可以这样理解:后者是在 txt 中匹配 pat,前者是在 pat 中匹配 pat[1..end],状态 X 总是落后状态 j 一个状态,与 j 具有最长的相同前缀。所以我把 X 比喻为影子状态,似乎也有一点贴切。
另外,构建 dp 数组是根据 base case dp[0][..] 向后推演。这就是我认为 KMP 算法就是一种动态规划算法的原因。
下面来看一下状态转移图的完整构造过程,你就能理解状态 X 作用之精妙了:

至此,KMP 算法的核心终于写完啦啦啦啦!看下 KMP 算法的完整代码吧:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
public class KMP {
private int[][] dp;
private String pat;
public KMP(String pat) {
this.pat = pat;
int M = pat.length();
// dp[状态][字符] = 下个状态
dp = new int[M][256];
// base case
dp[0][pat.charAt(0)] = 1;
// 影子状态 X 初始为 0
int X = 0;
// 构建状态转移图(稍改的更紧凑了)
for (int j = 1; j < M; j++) {
for (int c = 0; c < 256; c++)
dp[j][c] = dp[X][c];
dp[j][pat.charAt(j)] = j + 1;
// 更新影子状态
X = dp[X][pat.charAt(j)];
}
}
public int search(String txt) {
int M = pat.length();
int N = txt.length();
// pat 的初始态为 0
int j = 0;
for (int i = 0; i < N; i++) {
// 计算 pat 的下一个状态
j = dp[j][txt.charAt(i)];
// 到达终止态,返回结果
if (j == M) return i - M + 1;
}
// 没到达终止态,匹配失败
return -1;
}
}

经过之前的详细举例讲解,你应该可以理解这段代码的含义了,当然你也可以把 KMP 算法写成一个函数。核心代码也就是两个函数中 for 循环的部分,数一下有超过十行吗?

五、最后总结

传统的 KMP 算法是使用一个一维数组 next 记录前缀信息,而本文是使用一个二维数组 dp 以状态转移的角度解决字符匹配问题,但是空间复杂度仍然是 O(256M) = O(M)。
pat 匹配 txt 的过程中,只要明确了「当前处在哪个状态」和「遇到的字符是什么」这两个问题,就可以确定应该转移到哪个状态(推进或回退)。
对于一个模式串 pat,其总共就有 M 个状态,对于 ASCII 字符,总共不会超过 256 种。所以我们就构造一个数组 dp[M][256] 来包含所有情况,并且明确 dp 数组的含义:
dp[j][c] = next 表示,当前是状态 j,遇到了字符 c,应该转移到状态 next
明确了其含义,就可以很容易写出 search 函数的代码。
对于如何构建这个 dp 数组,需要一个辅助状态 X,它永远比当前状态 j 落后一个状态,拥有和 j 最长的相同前缀,我们给它起了个名字叫「影子状态」。
在构建当前状态 j 的转移方向时,只有字符 pat[j] 才能使状态推进(dp[j][pat[j]] = j+1);而对于其他字符只能进行状态回退,应该去请教影子状态 X 应该回退到哪里(dp[j][other] = dp[X][other],其中 other 是除了 pat[j] 之外所有字符)。
对于影子状态 X,我们把它初始化为 0,并且随着 j 的前进进行更新,更新的方式和 search 过程更新 j 的过程非常相似(X = dp[X][pat[j]])。
KMP 算法也就是动态规划那点事,我们的公众号文章目录有一系列专门讲动态规划的,而且都是按照一套框架来的,无非就是描述问题逻辑,明确 dp 数组含义,定义 base case 这点破事。希望这篇文章能让大家对动态规划有更深的理解。

上一篇:贪心算法之区间调度问题
下一篇:团灭 LeetCode 股票买卖问题
目录

动态规划详解

这篇文章是我们号半年前一篇 200 多赞赏的成名之作「动态规划详解」的进阶版。由于账号迁移的原因,旧文无法被搜索到,所以我润色了本文,并添加了更多干货内容,希望本文成为解决动态规划的一部「指导方针」。

再说句题外话,我们的公众号开号至今写了起码十几篇文章拆解动态规划问题,我都整理到了公众号菜单的「文章目录」中,它们都提到了动态规划的解题框架思维,本文就系统总结一下。这段时间本人也从非科班小白成长到刷通半个 LeetCode,所以我总结的套路可能不适合各路大神,但是应该适合大众,毕竟我自己也是一路摸爬滚打过来的。

算法技巧就那几个套路,如果你心里有数,就会轻松很多,本文就来扒一扒动态规划的裤子,形成一套解决这类问题的思维框架。废话不多说了,上干货。

动态规划问题的一般形式就是求最值。动态规划其实是运筹学的一种最优化方法,只不过在计算机问题上应用比较多,比如说让你求最长递增子序列呀,最小编辑距离呀等等。

既然是要求最值,核心问题是什么呢?求解动态规划的核心问题是穷举。因为要求最值,肯定要把所有可行的答案穷举出来,然后在其中找最值呗。

动态规划就这么简单,就是穷举就完事了?我看到的动态规划问题都很难啊!

首先,动态规划的穷举有点特别,因为这类问题存在「重叠子问题」,如果暴力穷举的话效率会极其低下,所以需要「备忘录」或者「DP table」来优化穷举过程,避免不必要的计算。

而且,动态规划问题一定会具备「最优子结构」,才能通过子问题的最值得到原问题的最值。

另外,虽然动态规划的核心思想就是穷举求最值,但是问题可以千变万化,穷举所有可行解其实并不是一件容易的事,只有列出正确的「状态转移方程」才能正确地穷举。

以上提到的重叠子问题、最优子结构、状态转移方程就是动态规划三要素。具体什么意思等会会举例详解,但是在实际的算法问题中,写出状态转移方程是最困难的,这也就是为什么很多朋友觉得动态规划问题困难的原因,我来提供我研究出来的一个思维框架,辅助你思考状态转移方程:

明确「状态」 -> 定义 dp 数组/函数的含义 -> 明确「选择」-> 明确 base case。

下面通过斐波那契数列问题和凑零钱问题来详解动态规划的基本原理。前者主要是让你明白什么是重叠子问题(斐波那契数列严格来说不是动态规划问题),后者主要举集中于如何列出状态转移方程。

请读者不要嫌弃这个例子简单,只有简单的例子才能让你把精力充分集中在算法背后的通用思想和技巧上,而不会被那些隐晦的细节问题搞的莫名其妙。想要困难的例子,历史文章里有的是。

一、斐波那契数列

1、暴力递归

斐波那契数列的数学形式就是递归的,写成代码就是这样:

1
2
3
4
int fib(int N) {
if (N == 1 || N == 2) return 1;
return fib(N - 1) + fib(N - 2);
}

这个不用多说了,学校老师讲递归的时候似乎都是拿这个举例。我们也知道这样写代码虽然简洁易懂,但是十分低效,低效在哪里?假设 n = 20,请画出递归树。

PS:但凡遇到需要递归的问题,最好都画出递归树,这对你分析算法的复杂度,寻找算法低效的原因都有巨大帮助。

这个递归树怎么理解?就是说想要计算原问题 f(20),我就得先计算出子问题 f(19)f(18),然后要计算 f(19),我就要先算出子问题 f(18)f(17),以此类推。最后遇到 f(1) 或者 f(2) 的时候,结果已知,就能直接返回结果,递归树不再向下生长了。

递归算法的时间复杂度怎么计算?子问题个数乘以解决一个子问题需要的时间。

子问题个数,即递归树中节点的总数。显然二叉树节点总数为指数级别,所以子问题个数为 O(2^n)。

解决一个子问题的时间,在本算法中,没有循环,只有 f(n - 1) + f(n - 2) 一个加法操作,时间为 O(1)。

所以,这个算法的时间复杂度为 O(2^n),指数级别,爆炸。

观察递归树,很明显发现了算法低效的原因:存在大量重复计算,比如 f(18) 被计算了两次,而且你可以看到,以 f(18) 为根的这个递归树体量巨大,多算一遍,会耗费巨大的时间。更何况,还不止 f(18) 这一个节点被重复计算,所以这个算法及其低效。

这就是动态规划问题的第一个性质:重叠子问题。下面,我们想办法解决这个问题。

2、带备忘录的递归解法

明确了问题,其实就已经把问题解决了一半。即然耗时的原因是重复计算,那么我们可以造一个「备忘录」,每次算出某个子问题的答案后别急着返回,先记到「备忘录」里再返回;每次遇到一个子问题先去「备忘录」里查一查,如果发现之前已经解决过这个问题了,直接把答案拿出来用,不要再耗时去计算了。

一般使用一个数组充当这个「备忘录」,当然你也可以使用哈希表(字典),思想都是一样的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
int fib(int N) {
if (N < 1) return 0;
// 备忘录全初始化为 0
vector<int> memo(N + 1, 0);
// 初始化最简情况
return helper(memo, N);
}

int helper(vector<int>& memo, int n) {
// base case
if (n == 1 || n == 2) return 1;
// 已经计算过
if (memo[n] != 0) return memo[n];
memo[n] = helper(memo, n - 1) +
helper(memo, n - 2);
return memo[n];
}

现在,画出递归树,你就知道「备忘录」到底做了什么。

实际上,带「备忘录」的递归算法,把一棵存在巨量冗余的递归树通过「剪枝」,改造成了一幅不存在冗余的递归图,极大减少了子问题(即递归图中节点)的个数。

递归算法的时间复杂度怎么算?子问题个数乘以解决一个子问题需要的时间。

子问题个数,即图中节点的总数,由于本算法不存在冗余计算,子问题就是 f(1), f(2), f(3)f(20),数量和输入规模 n = 20 成正比,所以子问题个数为 O(n)。

解决一个子问题的时间,同上,没有什么循环,时间为 O(1)。

所以,本算法的时间复杂度是 O(n)。比起暴力算法,是降维打击。

至此,带备忘录的递归解法的效率已经和迭代的动态规划解法一样了。实际上,这种解法和迭代的动态规划已经差不多了,只不过这种方法叫做「自顶向下」,动态规划叫做「自底向上」。

啥叫「自顶向下」?注意我们刚才画的递归树(或者说图),是从上向下延伸,都是从一个规模较大的原问题比如说 f(20),向下逐渐分解规模,直到 f(1)f(2) 触底,然后逐层返回答案,这就叫「自顶向下」。

啥叫「自底向上」?反过来,我们直接从最底下,最简单,问题规模最小的 f(1)f(2) 开始往上推,直到推到我们想要的答案 f(20),这就是动态规划的思路,这也是为什么动态规划一般都脱离了递归,而是由循环迭代完成计算。

3、dp 数组的迭代解法

有了上一步「备忘录」的启发,我们可以把这个「备忘录」独立出来成为一张表,就叫做 DP table 吧,在这张表上完成「自底向上」的推算岂不美哉!

1
2
3
4
5
6
7
8
int fib(int N) {
vector<int> dp(N + 1, 0);
// base case
dp[1] = dp[2] = 1;
for (int i = 3; i <= N; i++)
dp[i] = dp[i - 1] + dp[i - 2];
return dp[N];
}

画个图就很好理解了,而且你发现这个 DP table 特别像之前那个「剪枝」后的结果,只是反过来算而已。实际上,带备忘录的递归解法中的「备忘录」,最终完成后就是这个 DP table,所以说这两种解法其实是差不多的,大部分情况下,效率也基本相同。

这里,引出「状态转移方程」这个名词,实际上就是描述问题结构的数学形式:

为啥叫「状态转移方程」?为了听起来高端。你把 f(n) 想做一个状态 n,这个状态 n 是由状态 n - 1 和状态 n - 2 相加转移而来,这就叫状态转移,仅此而已。

你会发现,上面的几种解法中的所有操作,例如 return f(n - 1) + f(n - 2),dp[i] = dp[i - 1] + dp[i - 2],以及对备忘录或 DP table 的初始化操作,都是围绕这个方程式的不同表现形式。可见列出「状态转移方程」的重要性,它是解决问题的核心。很容易发现,其实状态转移方程直接代表着暴力解法。

千万不要看不起暴力解,动态规划问题最困难的就是写出状态转移方程,即这个暴力解。优化方法无非是用备忘录或者 DP table,再无奥妙可言。

这个例子的最后,讲一个细节优化。细心的读者会发现,根据斐波那契数列的状态转移方程,当前状态只和之前的两个状态有关,其实并不需要那么长的一个 DP table 来存储所有的状态,只要想办法存储之前的两个状态就行了。所以,可以进一步优化,把空间复杂度降为 O(1):

1
2
3
4
5
6
7
8
9
10
11
int fib(int n) {
if (n == 2 || n == 1)
return 1;
int prev = 1, curr = 1;
for (int i = 3; i <= n; i++) {
int sum = prev + curr;
prev = curr;
curr = sum;
}
return curr;
}

有人会问,动态规划的另一个重要特性「最优子结构」,怎么没有涉及?下面会涉及。斐波那契数列的例子严格来说不算动态规划,因为没有涉及求最值,以上旨在演示算法设计螺旋上升的过程。下面,看第二个例子,凑零钱问题。

二、凑零钱问题

先看下题目:给你 k 种面值的硬币,面值分别为 c1, c2 ... ck,每种硬币的数量无限,再给一个总金额 amount,问你最少需要几枚硬币凑出这个金额,如果不可能凑出,算法返回 -1 。算法的函数签名如下:

1
2
// coins 中是可选硬币面值,amount 是目标金额
int coinChange(int[] coins, int amount);

比如说 k = 3,面值分别为 1,2,5,总金额 amount = 11。那么最少需要 3 枚硬币凑出,即 11 = 5 + 5 + 1。

你认为计算机应该如何解决这个问题?显然,就是把所有肯能的凑硬币方法都穷举出来,然后找找看最少需要多少枚硬币。

1、暴力递归

首先,这个问题是动态规划问题,因为它具有「最优子结构」的。要符合「最优子结构」,子问题间必须互相独立。啥叫相互独立?你肯定不想看数学证明,我用一个直观的例子来讲解。

比如说,你的原问题是考出最高的总成绩,那么你的子问题就是要把语文考到最高,数学考到最高…… 为了每门课考到最高,你要把每门课相应的选择题分数拿到最高,填空题分数拿到最高…… 当然,最终就是你每门课都是满分,这就是最高的总成绩。

得到了正确的结果:最高的总成绩就是总分。因为这个过程符合最优子结构,“每门科目考到最高”这些子问题是互相独立,互不干扰的。

但是,如果加一个条件:你的语文成绩和数学成绩会互相制约,此消彼长。这样的话,显然你能考到的最高总成绩就达不到总分了,按刚才那个思路就会得到错误的结果。因为子问题并不独立,语文数学成绩无法同时最优,所以最优子结构被破坏。

回到凑零钱问题,为什么说它符合最优子结构呢?比如你想求 amount = 11 时的最少硬币数(原问题),如果你知道凑出 amount = 10 的最少硬币数(子问题),你只需要把子问题的答案加一(再选一枚面值为 1 的硬币)就是原问题的答案,因为硬币的数量是没有限制的,子问题之间没有相互制,是互相独立的。

那么,既然知道了这是个动态规划问题,就要思考如何列出正确的状态转移方程

先确定「状态」,也就是原问题和子问题中变化的变量。由于硬币数量无限,所以唯一的状态就是目标金额 amount

然后确定 dp 函数的定义:当前的目标金额是 n,至少需要 dp(n) 个硬币凑出该金额。

然后确定「选择」并择优,也就是对于每个状态,可以做出什么选择改变当前状态。具体到这个问题,无论当的目标金额是多少,选择就是从面额列表 coins 中选择一个硬币,然后目标金额就会减少:

1
2
3
4
5
6
7
8
9
10
# 伪码框架
def coinChange(coins: List[int], amount: int):
# 定义:要凑出金额 n,至少要 dp(n) 个硬币
def dp(n):
# 做选择,选择需要硬币最少的那个结果
for coin in coins:
res = min(res, 1 + dp(n - coin))
return res
# 我们要求的问题是 dp(amount)
return dp(amount)

最后明确 base case,显然目标金额为 0 时,所需硬币数量为 0;当目标金额小于 0 时,无解,返回 -1:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def coinChange(coins: List[int], amount: int):

def dp(n):
# base case
if n == 0: return 0
if n < 0: return -1
# 求最小值,所以初始化为正无穷
res = float('INF')
for coin in coins:
subproblem = dp(n - coin)
# 子问题无解,跳过
if subproblem == -1: continue
res = min(res, 1 + subproblem)

return res if res != float('INF') else -1

return dp(amount)

至此,状态转移方程其实已经完成了,以上算法已经是暴力解法了,以上代码的数学形式就是状态转移方程:

至此,这个问题其实就解决了,只不过需要消除一下重叠子问题,比如 amount = 11, coins = {1,2,5} 时画出递归树看看:

时间复杂度分析:子问题总数 x 每个子问题的时间

子问题总数为递归树节点个数,这个比较难看出来,是 O(n^k),总之是指数级别的。每个子问题中含有一个 for 循环,复杂度为 O(k)。所以总时间复杂度为 O(k * n^k),指数级别。

2、带备忘录的递归

只需要稍加修改,就可以通过备忘录消除子问题:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def coinChange(coins: List[int], amount: int):
# 备忘录
memo = dict()
def dp(n):
# 查备忘录,避免重复计算
if n in memo: return memo[n]

if n == 0: return 0
if n < 0: return -1
res = float('INF')
for coin in coins:
subproblem = dp(n - coin)
if subproblem == -1: continue
res = min(res, 1 + subproblem)

# 记入备忘录
memo[n] = res if res != float('INF') else -1
return memo[n]

return dp(amount)

不画图了,很显然「备忘录」大大减小了子问题数目,完全消除了子问题的冗余,所以子问题总数不会超过金额数 n,即子问题数目为 O(n)。处理一个子问题的时间不变,仍是 O(k),所以总的时间复杂度是 O(kn)。

3、dp 数组的迭代解法

当然,我们也可以自底向上使用 dp table 来消除重叠子问题,dp 数组的定义和刚才 dp 函数类似,定义也是一样的:

dp[i] = x 表示,当目标金额为 i 时,至少需要 x 枚硬币

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int coinChange(vector<int>& coins, int amount) {
// 数组大小为 amount + 1,初始值也为 amount + 1
vector<int> dp(amount + 1, amount + 1);
// base case
dp[0] = 0;
for (int i = 0; i < dp.size(); i++) {
// 内层 for 在求所有子问题 + 1 的最小值
for (int coin : coins) {
// 子问题无解,跳过
if (i - coin < 0) continue;
dp[i] = min(dp[i], 1 + dp[i - coin]);
}
}
return (dp[amount] == amount + 1) ? -1 : dp[amount];
}

PS:为啥 dp 数组初始化为 amount + 1 呢,因为凑成 amount 金额的硬币数最多只可能等于 amount(全用 1 元面值的硬币),所以初始化为 amount + 1 就相当于初始化为正无穷,便于后续取最小值。

三、最后总结

第一个斐波那契数列的问题,解释了如何通过「备忘录」或者「dp table」的方法来优化递归树,并且明确了这两种方法本质上是一样的,只是自顶向下和自底向上的不同而已。

第二个凑零钱的问题,展示了如何流程化确定「状态转移方程」,只要通过状态转移方程写出暴力递归解,剩下的也就是优化递归树,消除重叠子问题而已。

如果你不太了解动态规划,还能看到这里,真得给你鼓掌,相信你已经掌握了这个算法的设计技巧。

计算机解决问题其实没有任何奇技淫巧,它唯一的解决办法就是穷举,穷举所有可能性。算法设计无非就是先思考“如何穷举”,然后再追求“如何聪明地穷举”。

列出动态转移方程,就是在解决“如何穷举”的问题。之所以说它难,一是因为很多穷举需要递归实现,二是因为有的问题本身的解空间复杂,不那么容易穷举完整。

备忘录、DP table 就是在追求“如何聪明地穷举”。用空间换时间的思路,是降低时间复杂度的不二法门,除此之外,试问,还能玩出啥花活?

上一篇:学习数据结构和算法读什么书

下一篇:动态规划答疑篇

目录

title: 团灭股票问题
author: 远方
tags:

  • LeetCode
  • 算法
    categories:
  • LeetCode破局攻略
    date: 2016-01-01 19:20:00

团灭 LeetCode 股票买卖问题

很多读者抱怨 LeetCode 的股票系列问题奇技淫巧太多,如果面试真的遇到这类问题,基本不会想到那些巧妙的办法,怎么办?所以本文拒绝奇技淫巧,而是稳扎稳打,只用一种通用方法解决所用问题,以不变应万变
这篇文章用状态机的技巧来解决,可以全部提交通过。不要觉得这个名词高大上,文学词汇而已,实际上就是 DP table,看一眼就明白了。
PS:本文参考自英文版 LeetCode 的一篇题解
先随便抽出一道题,看看别人的解法:

1
2
3
4
5
6
7
8
9
10
11
12
int maxProfit(vector<int>& prices) {
if(prices.empty()) return 0;
int s1=-prices[0],s2=INT_MIN,s3=INT_MIN,s4=INT_MIN;

for(int i=1;i<prices.size();++i) {
s1 = max(s1, -prices[i]);
s2 = max(s2, s1+prices[i]);
s3 = max(s3, s2-prices[i]);
s4 = max(s4, s3+prices[i]);
}
return max(0,s4);
}

能看懂吧?会做了吗?不可能的,你看不懂,这才正常。就算你勉强看懂了,下一个问题你还是做不出来。为什么别人能写出这么诡异却又高效的解法呢?因为这类问题是有框架的,但是人家不会告诉你的,因为一旦告诉你,你五分钟就学会了,该算法题就不再神秘,变得不堪一击了。
本文就来告诉你这个框架,然后带着你一道一道秒杀。这篇文章用状态机的技巧来解决,可以全部提交通过。不要觉得这个名词高大上,文学词汇而已,实际上就是 DP table,看一眼就明白了。
这 6 道题目是有共性的,我就抽出来第 4 道题目,因为这道题是一个最泛化的形式,其他的问题都是这个形式的简化,看下题目:

第一题是只进行一次交易,相当于 k = 1;第二题是不限交易次数,相当于 k = +infinity(正无穷);第三题是只进行 2 次交易,相当于 k = 2;剩下两道也是不限次数,但是加了交易「冷冻期」和「手续费」的额外条件,其实就是第二题的变种,都很容易处理。
如果你还不熟悉题目,可以去 LeetCode 查看这些题目的内容,本文为了节省篇幅,就不列举这些题目的具体内容了。下面言归正传,开始解题。
一、穷举框架
首先,还是一样的思路:如何穷举?这里的穷举思路和上篇文章递归的思想不太一样。
递归其实是符合我们思考的逻辑的,一步步推进,遇到无法解决的就丢给递归,一不小心就做出来了,可读性还很好。缺点就是一旦出错,你也不容易找到错误出现的原因。比如上篇文章的递归解法,肯定还有计算冗余,但确实不容易找到。
而这里,我们不用递归思想进行穷举,而是利用「状态」进行穷举。我们具体到每一天,看看总共有几种可能的「状态」,再找出每个「状态」对应的「选择」。我们要穷举所有「状态」,穷举的目的是根据对应的「选择」更新状态。听起来抽象,你只要记住「状态」和「选择」两个词就行,下面实操一下就很容易明白了。

1
2
3
4
for 状态1 in 状态1的所有取值:
for 状态2 in 状态2的所有取值:
for ...
dp[状态1][状态2][...] = 择优(选择1,选择2...)

比如说这个问题,每天都有三种「选择」:买入、卖出、无操作,我们用 buy, sell, rest 表示这三种选择。但问题是,并不是每天都可以任意选择这三种选择的,因为 sell 必须在 buy 之后,buy 必须在 sell 之后。那么 rest 操作还应该分两种状态,一种是 buy 之后的 rest(持有了股票),一种是 sell 之后的 rest(没有持有股票)。而且别忘了,我们还有交易次数 k 的限制,就是说你 buy 还只能在 k > 0 的前提下操作。
很复杂对吧,不要怕,我们现在的目的只是穷举,你有再多的状态,老夫要做的就是一把梭全部列举出来。这个问题的「状态」有三个,第一个是天数,第二个是允许交易的最大次数,第三个是当前的持有状态(即之前说的 rest 的状态,我们不妨用 1 表示持有,0 表示没有持有)。然后我们用一个三维数组就可以装下这几种状态的全部组合:

1
2
3
4
5
6
7
8
dp[i][k][0 or 1]
0 <= i <= n-1, 1 <= k <= K
n 为天数,大 K 为最多交易数
此问题共 n × K × 2 种状态,全部穷举就能搞定。
for 0 <= i < n:
for 1 <= k <= K:
for s in {0, 1}:
dp[i][k][s] = max(buy, sell, rest)

而且我们可以用自然语言描述出每一个状态的含义,比如说 dp[3][2][1] 的含义就是:今天是第三天,我现在手上持有着股票,至今最多进行 2 次交易。再比如 dp[2][3][0] 的含义:今天是第二天,我现在手上没有持有股票,至今最多进行 3 次交易。很容易理解,对吧?
我们想求的最终答案是 dp[n - 1][K][0],即最后一天,最多允许 K 次交易,最多获得多少利润。读者可能问为什么不是 dp[n - 1][K][1]?因为 [1] 代表手上还持有股票,[0] 表示手上的股票已经卖出去了,很显然后者得到的利润一定大于前者。
记住如何解释「状态」,一旦你觉得哪里不好理解,把它翻译成自然语言就容易理解了。
二、状态转移框架
现在,我们完成了「状态」的穷举,我们开始思考每种「状态」有哪些「选择」,应该如何更新「状态」。只看「持有状态」,可以画个状态转移图。

通过这个图可以很清楚地看到,每种状态(0 和 1)是如何转移而来的。根据这个图,我们来写一下状态转移方程:

1
2
3
4
5
6
7
8
9
10
dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1] + prices[i])
max( 选择 rest , 选择 sell )
解释:今天我没有持有股票,有两种可能:
要么是我昨天就没有持有,然后今天选择 rest,所以我今天还是没有持有;
要么是我昨天持有股票,但是今天我 sell 了,所以我今天没有持有股票了。
dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0] - prices[i])
max( 选择 rest , 选择 buy )
解释:今天我持有着股票,有两种可能:
要么我昨天就持有着股票,然后今天选择 rest,所以我今天还持有着股票;
要么我昨天本没有持有,但今天我选择 buy,所以今天我就持有股票了。

这个解释应该很清楚了,如果 buy,就要从利润中减去 prices[i],如果 sell,就要给利润增加 prices[i]。今天的最大利润就是这两种可能选择中较大的那个。而且注意 k 的限制,我们在选择 buy 的时候,把 k 减小了 1,很好理解吧,当然你也可以在 sell 的时候减 1,一样的。
现在,我们已经完成了动态规划中最困难的一步:状态转移方程。如果之前的内容你都可以理解,那么你已经可以秒杀所有问题了,只要套这个框架就行了。不过还差最后一点点,就是定义 base case,即最简单的情况。

1
2
3
4
5
6
7
8
dp[-1][k][0] = 0
解释:因为 i 是从 0 开始的,所以 i = -1 意味着还没有开始,这时候的利润当然是 0 。
dp[-1][k][1] = -infinity
解释:还没开始的时候,是不可能持有股票的,用负无穷表示这种不可能。
dp[i][0][0] = 0
解释:因为 k 是从 1 开始的,所以 k = 0 意味着根本不允许交易,这时候利润当然是 0 。
dp[i][0][1] = -infinity
解释:不允许交易的情况下,是不可能持有股票的,用负无穷表示这种不可能。

把上面的状态转移方程总结一下:

1
2
3
4
5
6
base case:
dp[-1][k][0] = dp[i][0][0] = 0
dp[-1][k][1] = dp[i][0][1] = -infinity
状态转移方程:
dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1] + prices[i])
dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0] - prices[i])

读者可能会问,这个数组索引是 -1 怎么编程表示出来呢,负无穷怎么表示呢?这都是细节问题,有很多方法实现。现在完整的框架已经完成,下面开始具体化。
三、秒杀题目
第一题,k = 1
直接套状态转移方程,根据 base case,可以做一些化简:

1
2
3
4
5
6
7
8
dp[i][1][0] = max(dp[i-1][1][0], dp[i-1][1][1] + prices[i])
dp[i][1][1] = max(dp[i-1][1][1], dp[i-1][0][0] - prices[i])
= max(dp[i-1][1][1], -prices[i])
解释:k = 0 的 base case,所以 dp[i-1][0][0] = 0。
现在发现 k 都是 1,不会改变,即 k 对状态转移已经没有影响了。
可以进行进一步化简去掉所有 k:
dp[i][0] = max(dp[i-1][0], dp[i-1][1] + prices[i])
dp[i][1] = max(dp[i-1][1], -prices[i])

直接写出代码:

1
2
3
4
5
6
7
int n = prices.length;
int[][] dp = new int[n][2];
for (int i = 0; i < n; i++) {
dp[i][0] = Math.max(dp[i-1][0], dp[i-1][1] + prices[i]);
dp[i][1] = Math.max(dp[i-1][1], -prices[i]);
}
return dp[n - 1][0];

显然 i = 0 时 dp[i-1] 是不合法的。这是因为我们没有对 i 的 base case 进行处理。可以这样处理:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
for (int i = 0; i < n; i++) {
if (i - 1 == -1) {
dp[i][0] = 0;
// 解释:
// dp[i][0]
// = max(dp[-1][0], dp[-1][1] + prices[i])
// = max(0, -infinity + prices[i]) = 0
dp[i][1] = -prices[i];
//解释:
// dp[i][1]
// = max(dp[-1][1], dp[-1][0] - prices[i])
// = max(-infinity, 0 - prices[i])
// = -prices[i]
continue;
}
dp[i][0] = Math.max(dp[i-1][0], dp[i-1][1] + prices[i]);
dp[i][1] = Math.max(dp[i-1][1], -prices[i]);
}
return dp[n - 1][0];

第一题就解决了,但是这样处理 base case 很麻烦,而且注意一下状态转移方程,新状态只和相邻的一个状态有关,其实不用整个 dp 数组,只需要一个变量储存相邻的那个状态就足够了,这样可以把空间复杂度降到 O(1):

1
2
3
4
5
6
7
8
9
10
11
12
13
// k == 1
int maxProfit_k_1(int[] prices) {
int n = prices.length;
// base case: dp[-1][0] = 0, dp[-1][1] = -infinity
int dp_i_0 = 0, dp_i_1 = Integer.MIN_VALUE;
for (int i = 0; i < n; i++) {
// dp[i][0] = max(dp[i-1][0], dp[i-1][1] + prices[i])
dp_i_0 = Math.max(dp_i_0, dp_i_1 + prices[i]);
// dp[i][1] = max(dp[i-1][1], -prices[i])
dp_i_1 = Math.max(dp_i_1, -prices[i]);
}
return dp_i_0;
}

两种方式都是一样的,不过这种编程方法简洁很多。但是如果没有前面状态转移方程的引导,是肯定看不懂的。后续的题目,我主要写这种空间复杂度 O(1) 的解法。
第二题,k = +infinity
如果 k 为正无穷,那么就可以认为 k 和 k - 1 是一样的。可以这样改写框架:

1
2
3
4
5
6
dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1] + prices[i])
dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0] - prices[i])
= max(dp[i-1][k][1], dp[i-1][k][0] - prices[i])
我们发现数组中的 k 已经不会改变了,也就是说不需要记录 k 这个状态了:
dp[i][0] = max(dp[i-1][0], dp[i-1][1] + prices[i])
dp[i][1] = max(dp[i-1][1], dp[i-1][0] - prices[i])

直接翻译成代码:

1
2
3
4
5
6
7
8
9
10
int maxProfit_k_inf(int[] prices) {
int n = prices.length;
int dp_i_0 = 0, dp_i_1 = Integer.MIN_VALUE;
for (int i = 0; i < n; i++) {
int temp = dp_i_0;
dp_i_0 = Math.max(dp_i_0, dp_i_1 + prices[i]);
dp_i_1 = Math.max(dp_i_1, temp - prices[i]);
}
return dp_i_0;
}

第三题,k = +infinity with cooldown
每次 sell 之后要等一天才能继续交易。只要把这个特点融入上一题的状态转移方程即可:

1
2
3
dp[i][0] = max(dp[i-1][0], dp[i-1][1] + prices[i])
dp[i][1] = max(dp[i-1][1], dp[i-2][0] - prices[i])
解释:第 i 天选择 buy 的时候,要从 i-2 的状态转移,而不是 i-1 。

翻译成代码:

1
2
3
4
5
6
7
8
9
10
11
12
int maxProfit_with_cool(int[] prices) {
int n = prices.length;
int dp_i_0 = 0, dp_i_1 = Integer.MIN_VALUE;
int dp_pre_0 = 0; // 代表 dp[i-2][0]
for (int i = 0; i < n; i++) {
int temp = dp_i_0;
dp_i_0 = Math.max(dp_i_0, dp_i_1 + prices[i]);
dp_i_1 = Math.max(dp_i_1, dp_pre_0 - prices[i]);
dp_pre_0 = temp;
}
return dp_i_0;
}

第四题,k = +infinity with fee
每次交易要支付手续费,只要把手续费从利润中减去即可。改写方程:

1
2
3
4
dp[i][0] = max(dp[i-1][0], dp[i-1][1] + prices[i])
dp[i][1] = max(dp[i-1][1], dp[i-1][0] - prices[i] - fee)
解释:相当于买入股票的价格升高了。
在第一个式子里减也是一样的,相当于卖出股票的价格减小了。

直接翻译成代码:

1
2
3
4
5
6
7
8
9
10
int maxProfit_with_fee(int[] prices, int fee) {
int n = prices.length;
int dp_i_0 = 0, dp_i_1 = Integer.MIN_VALUE;
for (int i = 0; i < n; i++) {
int temp = dp_i_0;
dp_i_0 = Math.max(dp_i_0, dp_i_1 + prices[i]);
dp_i_1 = Math.max(dp_i_1, temp - prices[i] - fee);
}
return dp_i_0;
}

第五题,k = 2
k = 2 和前面题目的情况稍微不同,因为上面的情况都和 k 的关系不太大。要么 k 是正无穷,状态转移和 k 没关系了;要么 k = 1,跟 k = 0 这个 base case 挨得近,最后也没有存在感。
这道题 k = 2 和后面要讲的 k 是任意正整数的情况中,对 k 的处理就凸显出来了。我们直接写代码,边写边分析原因。

1
2
3
原始的动态转移方程,没有可化简的地方
dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1] + prices[i])
dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0] - prices[i])

按照之前的代码,我们可能想当然这样写代码(错误的):

1
2
3
4
5
6
7
8
int k = 2;
int[][][] dp = new int[n][k + 1][2];
for (int i = 0; i < n; i++)
if (i - 1 == -1) { /* 处理一下 base case*/ }
dp[i][k][0] = Math.max(dp[i-1][k][0], dp[i-1][k][1] + prices[i]);
dp[i][k][1] = Math.max(dp[i-1][k][1], dp[i-1][k-1][0] - prices[i]);
}
return dp[n - 1][k][0];

为什么错误?我这不是照着状态转移方程写的吗?
还记得前面总结的「穷举框架」吗?就是说我们必须穷举所有状态。其实我们之前的解法,都在穷举所有状态,只是之前的题目中 k 都被化简掉了。比如说第一题,k = 1:
「代码截图」
这道题由于没有消掉 k 的影响,所以必须要对 k 进行穷举:

1
2
3
4
5
6
7
8
9
10
11
int max_k = 2;
int[][][] dp = new int[n][max_k + 1][2];
for (int i = 0; i < n; i++) {
for (int k = max_k; k >= 1; k--) {
if (i - 1 == -1) { /*处理 base case */ }
dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1] + prices[i]);
dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0] - prices[i]);
}
}
// 穷举了 n × max_k × 2 个状态,正确。
return dp[n - 1][max_k][0];

如果你不理解,可以返回第一点「穷举框架」重新阅读体会一下。
这里 k 取值范围比较小,所以可以不用 for 循环,直接把 k = 1 和 2 的情况全部列举出来也可以:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
dp[i][2][0] = max(dp[i-1][2][0], dp[i-1][2][1] + prices[i])
dp[i][2][1] = max(dp[i-1][2][1], dp[i-1][1][0] - prices[i])
dp[i][1][0] = max(dp[i-1][1][0], dp[i-1][1][1] + prices[i])
dp[i][1][1] = max(dp[i-1][1][1], -prices[i])
int maxProfit_k_2(int[] prices) {
int dp_i10 = 0, dp_i11 = Integer.MIN_VALUE;
int dp_i20 = 0, dp_i21 = Integer.MIN_VALUE;
for (int price : prices) {
dp_i20 = Math.max(dp_i20, dp_i21 + price);
dp_i21 = Math.max(dp_i21, dp_i10 - price);
dp_i10 = Math.max(dp_i10, dp_i11 + price);
dp_i11 = Math.max(dp_i11, -price);
}
return dp_i20;
}

有状态转移方程和含义明确的变量名指导,相信你很容易看懂。其实我们可以故弄玄虚,把上述四个变量换成 a, b, c, d。这样当别人看到你的代码时就会大惊失色,对你肃然起敬。
第六题,k = any integer
有了上一题 k = 2 的铺垫,这题应该和上一题的第一个解法没啥区别。但是出现了一个超内存的错误,原来是传入的 k 值会非常大,dp 数组太大了。现在想想,交易次数 k 最多有多大呢?
一次交易由买入和卖出构成,至少需要两天。所以说有效的限制 k 应该不超过 n/2,如果超过,就没有约束作用了,相当于 k = +infinity。这种情况是之前解决过的。
直接把之前的代码重用:

1
2
3
4
5
6
7
8
9
10
11
12
13
int maxProfit_k_any(int max_k, int[] prices) {
int n = prices.length;
if (max_k > n / 2)
return maxProfit_k_inf(prices);
int[][][] dp = new int[n][max_k + 1][2];
for (int i = 0; i < n; i++)
for (int k = max_k; k >= 1; k--) {
if (i - 1 == -1) { /* 处理 base case */ }
dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1] + prices[i]);
dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0] - prices[i]);
}
return dp[n - 1][max_k][0];
}

至此,6 道题目通过一个状态转移方程全部解决。

四、最后总结
本文给大家讲了如何通过状态转移的方法解决复杂的问题,用一个状态转移方程秒杀了 6 道股票买卖问题,现在想想,其实也不算难对吧?这已经属于动态规划问题中较困难的了。
关键就在于列举出所有可能的「状态」,然后想想怎么穷举更新这些「状态」。一般用一个多维 dp 数组储存这些状态,从 base case 开始向后推进,推进到最后的状态,就是我们想要的答案。想想这个过程,你是不是有点理解「动态规划」这个名词的意义了呢?
具体到股票买卖问题,我们发现了三个状态,使用了一个三维数组,无非还是穷举 + 更新,不过我们可以说的高大上一点,这叫「三维 DP」,怕不怕?这个大实话一说,立刻显得你高人一等,名利双收有没有,所以给个在看/分享吧,鼓励一下我。

上一篇:动态规划之KMP字符匹配算法
下一篇:团灭 LeetCode 打家劫舍问题
目录

二分查找详解

先给大家讲个笑话乐呵一下:
有一天阿东到图书馆借了 N 本书,出图书馆的时候,警报响了,于是保安把阿东拦下,要检查一下哪本书没有登记出借。阿东正准备把每一本书在报警器下过一下,以找出引发警报的书,但是保安露出不屑的眼神:你连二分查找都不会吗?于是保安把书分成两堆,让第一堆过一下报警器,报警器响;于是再把这堆书分成两堆…… 最终,检测了 logN 次之后,保安成功的找到了那本引起警报的书,露出了得意和嘲讽的笑容。于是阿东背着剩下的书走了。
从此,图书馆丢了 N - 1 本书。
二分查找并不简单,Knuth 大佬(发明 KMP 算法的那位)都说二分查找:思路很简单,细节是魔鬼。很多人喜欢拿整型溢出的 bug 说事儿,但是二分查找真正的坑根本就不是那个细节问题,而是在于到底要给 mid 加一还是减一,while 里到底用 <= 还是 <
你要是没有正确理解这些细节,写二分肯定就是玄学编程,有没有 bug 只能靠菩萨保佑。我特意写了一首诗来歌颂该算法,概括本文的主要内容,建议保存

本文就来探究几个最常用的二分查找场景:寻找一个数、寻找左侧边界、寻找右侧边界。而且,我们就是要深入细节,比如不等号是否应该带等号,mid 是否应该加一等等。分析这些细节的差异以及出现这些差异的原因,保证你能灵活准确地写出正确的二分查找算法。

零、二分查找框架

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int binarySearch(int[] nums, int target) {
int left = 0, right = ...;
while(...) {
int mid = left + (right - left) / 2;
if (nums[mid] == target) {
...
} else if (nums[mid] < target) {
left = ...
} else if (nums[mid] > target) {
right = ...
}
}
return ...;
}

分析二分查找的一个技巧是:不要出现 else,而是把所有情况用 else if 写清楚,这样可以清楚地展现所有细节。本文都会使用 else if,旨在讲清楚,读者理解后可自行简化。
其中 ... 标记的部分,就是可能出现细节问题的地方,当你见到一个二分查找的代码时,首先注意这几个地方。后文用实例分析这些地方能有什么样的变化。
另外声明一下,计算 mid 时需要防止溢出,代码中 left + (right - left) / 2 就和 (left + right) / 2 的结果相同,但是有效防止了 leftright 太大直接相加导致溢出。

一、寻找一个数(基本的二分搜索)

这个场景是最简单的,肯能也是大家最熟悉的,即搜索一个数,如果存在,返回其索引,否则返回 -1。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int binarySearch(int[] nums, int target) {
int left = 0;
int right = nums.length - 1; // 注意
while(left <= right) {
int mid = left + (right - left) / 2;
if(nums[mid] == target)
return mid;
else if (nums[mid] < target)
left = mid + 1; // 注意
else if (nums[mid] > target)
right = mid - 1; // 注意
}
return -1;
}

1、为什么 while 循环的条件中是 <=,而不是 <**?
答:因为初始化 right 的赋值是 nums.length - 1,即最后一个元素的索引,而不是 nums.length
这二者可能出现在不同功能的二分查找中,区别是:前者相当于两端都闭区间 [left, right],后者相当于左闭右开区间 [left, right),因为索引大小为 nums.length 是越界的。
我们这个算法中使用的是前者 [left, right] 两端都闭的区间。
这个区间其实就是每次进行搜索的区间**。
什么时候应该停止搜索呢?当然,找到了目标值的时候可以终止:

1
2
if(nums[mid] == target)
return mid;

但如果没找到,就需要 while 循环终止,然后返回 -1。那 while 循环什么时候应该终止?搜索区间为空的时候应该终止,意味着你没得找了,就等于没找到嘛。
while(left <= right) 的终止条件是 left == right + 1,写成区间的形式就是 [right + 1, right],或者带个具体的数字进去 [3, 2],可见这时候区间为空,因为没有数字既大于等于 3 又小于等于 2 的吧。所以这时候 while 循环终止是正确的,直接返回 -1 即可。
while(left < right) 的终止条件是 left == right,写成区间的形式就是 [left, right],或者带个具体的数字进去 [2, 2]这时候区间非空,还有一个数 2,但此时 while 循环终止了。也就是说这区间 [2, 2] 被漏掉了,索引 2 没有被搜索,如果这时候直接返回 -1 就是错误的。
当然,如果你非要用 while(left < right) 也可以,我们已经知道了出错的原因,就打个补丁好了:

1
2
3
4
5
//...
while(left < right) {
// ...
}
return nums[left] == target ? left : -1;

2、为什么 left = mid + 1right = mid - 1?我看有的代码是 right = mid 或者 left = mid,没有这些加加减减,到底怎么回事,怎么判断
答:这也是二分查找的一个难点,不过只要你能理解前面的内容,就能够很容易判断。
刚才明确了「搜索区间」这个概念,而且本算法的搜索区间是两端都闭的,即 [left, right]。那么当我们发现索引 mid 不是要找的 target 时,下一步应该去搜索哪里呢?
当然是去搜索 [left, mid-1] 或者 [mid+1, right] 对不对?因为 mid 已经搜索过,应该从搜索区间中去除
3、此算法有什么缺陷
答:至此,你应该已经掌握了该算法的所有细节,以及这样处理的原因。但是,这个算法存在局限性。
比如说给你有序数组 nums = [1,2,2,2,3]target 为 2,此算法返回的索引是 2,没错。但是如果我想得到 target 的左侧边界,即索引 1,或者我想得到 target 的右侧边界,即索引 3,这样的话此算法是无法处理的。
这样的需求很常见,你也许会说,找到一个 target,然后向左或向右线性搜索不行吗?可以,但是不好,因为这样难以保证二分查找对数级的复杂度了
我们后续的算法就来讨论这两种二分查找的算法。

二、寻找左侧边界的二分搜索

以下是最常见的代码形式,其中的标记是需要注意的细节:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
int left_bound(int[] nums, int target) {
if (nums.length == 0) return -1;
int left = 0;
int right = nums.length; // 注意

while (left < right) { // 注意
int mid = (left + right) / 2;
if (nums[mid] == target) {
right = mid;
} else if (nums[mid] < target) {
left = mid + 1;
} else if (nums[mid] > target) {
right = mid; // 注意
}
}
return left;
}

1、为什么 while 中是 < 而不是 <=?
答:用相同的方法分析,因为 right = nums.length 而不是 nums.length - 1。因此每次循环的「搜索区间」是 [left, right) 左闭右开。
while(left < right) 终止的条件是 left == right,此时搜索区间 [left, left) 为空,所以可以正确终止。
PS:这里先要说一个搜索左右边界和上面这个算法的一个区别,也是很多读者问的:刚才的 right 不是 nums.length - 1 吗,为啥这里非要写成 nums.length 使得「搜索区间」变成左闭右开呢
因为对于搜索左右侧边界的二分查找,这种写法比较普遍,我就拿这种写法举例了,保证你以后遇到这类代码可以理解。你非要用两端都闭的写法反而更简单,我会在后面写相关的代码,把三种二分搜索都用一种两端都闭的写法统一起来,你耐心往后看就行了。
2、为什么没有返回 -1 的操作?如果 nums 中不存在 target 这个值,怎么办
答:因为要一步一步来,先理解一下这个「左侧边界」有什么特殊含义:

对于这个数组,算法会返回 1。这个 1 的含义可以这样解读:nums 中小于 2 的元素有 1 个。
比如对于有序数组 nums = [2,3,5,7], target = 1,算法会返回 0,含义是:nums 中小于 1 的元素有 0 个。
再比如说 nums = [2,3,5,7], target = 8,算法会返回 4,含义是:nums 中小于 8 的元素有 4 个。
综上可以看出,函数的返回值(即 left 变量的值)取值区间是闭区间 [0, nums.length],所以我们简单添加两行代码就能在正确的时候 return -1:

1
2
3
4
5
6
7
while (left < right) {
//...
}
// target 比所有数都大
if (left == nums.length) return -1;
// 类似之前算法的处理方式
return nums[left] == target ? left : -1;

3、为什么 left = mid + 1right = mid ?和之前的算法不一样
答:这个很好解释,因为我们的「搜索区间」是 [left, right) 左闭右开,所以当 nums[mid] 被检测之后,下一步的搜索区间应该去掉 mid 分割成两个区间,即 [left, mid)[mid + 1, right)
4、为什么该算法能够搜索左侧边界
答:关键在于对于 nums[mid] == target 这种情况的处理:

1
2
if (nums[mid] == target)
right = mid;

可见,找到 target 时不要立即返回,而是缩小「搜索区间」的上界 right,在区间 [left, mid) 中继续搜索,即不断向左收缩,达到锁定左侧边界的目的。
5、为什么返回 left 而不是 right**?
答:都是一样的,因为 while 终止的条件是 left == right
**6、能不能想办法把 right 变成 nums.length - 1,也就是继续使用两边都闭的「搜索区间」?这样就可以和第一种二分搜索在某种程度上统一起来了

答:当然可以,只要你明白了「搜索区间」这个概念,就能有效避免漏掉元素,随便你怎么改都行。下面我们严格根据逻辑来修改:
因为你非要让搜索区间两端都闭,所以 right 应该初始化为 nums.length - 1,while 的终止条件应该是 left == right + 1,也就是其中应该用 <=

1
2
3
4
5
6
7
int left_bound(int[] nums, int target) {
// 搜索区间为 [left, right]
int left = 0, right = nums.length - 1;
while (left <= right) {
int mid = left + (right - left) / 2;
// if else ...
}

因为搜索区间是两端都闭的,且现在是搜索左侧边界,所以 leftright 的更新逻辑如下:

1
2
3
4
5
6
7
8
9
10
if (nums[mid] < target) {
// 搜索区间变为 [mid+1, right]
left = mid + 1;
} else if (nums[mid] > target) {
// 搜索区间变为 [left, mid-1]
right = mid - 1;
} else if (nums[mid] == target) {
// 收缩右侧边界
right = mid - 1;
}

由于 while 的退出条件是 left == right + 1,所以当 targetnums 中所有元素都大时,会存在以下情况使得索引越界:

因此,最后返回结果的代码应该检查越界情况:

1
2
3
if (left >= nums.length || nums[left] != target)
return -1;
return left;

至此,整个算法就写完了,完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
int left_bound(int[] nums, int target) {
int left = 0, right = nums.length - 1;
// 搜索区间为 [left, right]
while (left <= right) {
int mid = left + (right - left) / 2;
if (nums[mid] < target) {
// 搜索区间变为 [mid+1, right]
left = mid + 1;
} else if (nums[mid] > target) {
// 搜索区间变为 [left, mid-1]
right = mid - 1;
} else if (nums[mid] == target) {
// 收缩右侧边界
right = mid - 1;
}
}
// 检查出界情况
if (left >= nums.length || nums[left] != target)
return -1;
return left;
}

这样就和第一种二分搜索算法统一了,都是两端都闭的「搜索区间」,而且最后返回的也是 left 变量的值。只要把住二分搜索的逻辑,两种形式大家看自己喜欢哪种记哪种吧。

三、寻找右侧边界的二分查找

类似寻找左侧边界的算法,这里也会提供两种写法,还是先写常见的左闭右开的写法,只有两处和搜索左侧边界不同,已标注:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
int right_bound(int[] nums, int target) {
if (nums.length == 0) return -1;
int left = 0, right = nums.length;

while (left < right) {
int mid = (left + right) / 2;
if (nums[mid] == target) {
left = mid + 1; // 注意
} else if (nums[mid] < target) {
left = mid + 1;
} else if (nums[mid] > target) {
right = mid;
}
}
return left - 1; // 注意
}

1、为什么这个算法能够找到右侧边界
答:类似地,关键点还是这里:

1
2
if (nums[mid] == target) {
left = mid + 1;

nums[mid] == target 时,不要立即返回,而是增大「搜索区间」的下界 left,使得区间不断向右收缩,达到锁定右侧边界的目的。
2、为什么最后返回 left - 1 而不像左侧边界的函数,返回 left?而且我觉得这里既然是搜索右侧边界,应该返回 right 才对
答:首先,while 循环的终止条件是 left == right,所以 leftright 是一样的,你非要体现右侧的特点,返回 right - 1 好了。
至于为什么要减一,这是搜索右侧边界的一个特殊点,关键在这个条件判断:

1
2
3
if (nums[mid] == target) {
left = mid + 1;
// 这样想: mid = left - 1


因为我们对 left 的更新必须是 left = mid + 1,就是说 while 循环结束时,nums[left] 一定不等于 target 了,而 nums[left-1] 可能是 target
至于为什么 left 的更新必须是 left = mid + 1,同左侧边界搜索,就不再赘述。
3、为什么没有返回 -1 的操作?如果 nums 中不存在 target 这个值,怎么办
答:类似之前的左侧边界搜索,因为 while 的终止条件是 left == right,就是说 left 的取值范围是 [0, nums.length],所以可以添加两行代码,正确地返回 -1:

1
2
3
4
5
while (left < right) {
// ...
}
if (left == 0) return -1;
return nums[left-1] == target ? (left-1) : -1;

4、是否也可以把这个算法的「搜索区间」也统一成两端都闭的形式呢?这样这三个写法就完全统一了,以后就可以闭着眼睛写出来了
答:当然可以,类似搜索左侧边界的统一写法,其实只要改两个地方就行了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
int right_bound(int[] nums, int target) {
int left = 0, right = nums.length - 1;
while (left <= right) {
int mid = left + (right - left) / 2;
if (nums[mid] < target) {
left = mid + 1;
} else if (nums[mid] > target) {
right = mid - 1;
} else if (nums[mid] == target) {
// 这里改成收缩左侧边界即可
left = mid + 1;
}
}
// 这里改为检查 right 越界的情况,见下图
if (right < 0 || nums[right] != target)
return -1;
return right;
}

target 比所有元素都小时,right 会被减到 -1,所以需要在最后防止越界:

至此,搜索右侧边界的二分查找的两种写法也完成了,其实将「搜索区间」统一成两端都闭反而更容易记忆,你说是吧?

四、逻辑统一

来梳理一下这些细节差异的因果逻辑:
第一个,最基本的二分查找算法

1
2
3
4
5
6
因为我们初始化 right = nums.length - 1
所以决定了我们的「搜索区间」是 [left, right]
所以决定了 while (left <= right)
同时也决定了 left = mid+1 和 right = mid-1
因为我们只需找到一个 target 的索引即可
所以当 nums[mid] == target 时可以立即返回

第二个,寻找左侧边界的二分查找

1
2
3
4
5
6
7
因为我们初始化 right = nums.length
所以决定了我们的「搜索区间」是 [left, right)
所以决定了 while (left < right)
同时也决定了 left = mid + 1 和 right = mid
因为我们需找到 target 的最左侧索引
所以当 nums[mid] == target 时不要立即返回
而要收紧右侧边界以锁定左侧边界

第三个,寻找右侧边界的二分查找

1
2
3
4
5
6
7
8
9
因为我们初始化 right = nums.length
所以决定了我们的「搜索区间」是 [left, right)
所以决定了 while (left < right)
同时也决定了 left = mid + 1 和 right = mid
因为我们需找到 target 的最右侧索引
所以当 nums[mid] == target 时不要立即返回
而要收紧左侧边界以锁定右侧边界
又因为收紧左侧边界时必须 left = mid + 1
所以最后无论返回 left 还是 right,必须减一

对于寻找左右边界的二分搜索,常见的手法是使用左闭右开的「搜索区间」,我们还根据逻辑将「搜索区间」全都统一成了两端都闭,便于记忆,只要修改两处即可变化出三种写法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
int binary_search(int[] nums, int target) {
int left = 0, right = nums.length - 1;
while(left <= right) {
int mid = left + (right - left) / 2;
if (nums[mid] < target) {
left = mid + 1;
} else if (nums[mid] > target) {
right = mid - 1;
} else if(nums[mid] == target) {
// 直接返回
return mid;
}
}
// 直接返回
return -1;
}
int left_bound(int[] nums, int target) {
int left = 0, right = nums.length - 1;
while (left <= right) {
int mid = left + (right - left) / 2;
if (nums[mid] < target) {
left = mid + 1;
} else if (nums[mid] > target) {
right = mid - 1;
} else if (nums[mid] == target) {
// 别返回,锁定左侧边界
right = mid - 1;
}
}
// 最后要检查 left 越界的情况
if (left >= nums.length || nums[left] != target)
return -1;
return left;
}

int right_bound(int[] nums, int target) {
int left = 0, right = nums.length - 1;
while (left <= right) {
int mid = left + (right - left) / 2;
if (nums[mid] < target) {
left = mid + 1;
} else if (nums[mid] > target) {
right = mid - 1;
} else if (nums[mid] == target) {
// 别返回,锁定右侧边界
left = mid + 1;
}
}
// 最后要检查 right 越界的情况
if (right < 0 || nums[right] != target)
return -1;
return right;
}

如果以上内容你都能理解,那么恭喜你,二分查找算法的细节不过如此。
通过本文,你学会了:
1、分析二分查找代码时,不要出现 else,全部展开成 else if 方便理解。
2、注意「搜索区间」和 while 的终止条件,如果存在漏掉的元素,记得在最后检查。
3、如需定义左闭右开的「搜索区间」搜索左右边界,只要在 nums[mid] == target 时做修改即可,搜索右侧时需要减一。
4、如果将「搜索区间」全都统一成两端都闭,好记,只要稍改 nums[mid] == target 条件处的代码和返回的逻辑即可,推荐拿小本本记下,作为二分搜索模板
呵呵,此文对二分查找的问题无敌好吧!

上一篇:回溯算法解题框架
下一篇:滑动窗口解题框架
目录