理解KMP算法:从入门到推导

By Long Luo

之前虽然知道 KMP 算法,但是一直无法深入理解其原理,最近看了 2 篇文章:从头到尾彻底理解KMP(2014年8月22日版)KMP算法教程 ,然后再实际写代码,终于对 \(\textit{KMP}\) 算法有了一定了解了,下面写一写我个人的学习过程。

这篇文章主要从我个人学习过程来叙述:

为了解决什么问题?

\(\textit{KMP}\) 算法是一种字符串匹配算法,可以在 \(O(n+m)\) 的时间复杂度内实现两个字符串的匹配。

所谓字符串匹配,是这样一种问题:

  1. 字符串 \(\textit{P}\) 是否为字符串 \(\textit{S}\) 的子串?
  2. 如果是,\(\textit{P}\) 出现在 \(\textit{S}\) 的哪些位置?

其中 \(\textit{S}\) 称为主串\(\textit{P}\) 称为模式串

最常见的就是经常要用的搜索功能,比如要在一大段文字中找到某个句子或者找到出现的全部位置。在这种场景下,要找的句子就是给定的模式 \(\textit{P}\),而大段文字就是目标字符串 \(\textit{S}\)

从暴力法开始

我们先从最直观的地方开始:

  1. 两个字符串 \(\textit{A}\)\(\textit{B}\) 的比较?
  2. 如果 \(\textit{P}\)\(\textit{S}\) 的字串,第一个出现的位置在哪里?

所谓字符串比较,就是问两个字符串是否相等

最朴素的思想,就是从前往后逐字符比较,一旦遇到不相同的字符,就返回 \(\textit{False}\) ;如果两个字符串都结束了,仍然没有出现不对应的字符,则返回 \(\textit{True}\)

代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
public static int Search_BruteForce(String targetStr, String patternStr) {
int targetLen = targetStr.length();
int patLen = patternStr.length();

for (int i = 0; i < targetLen; i++) {
if (targetStr.charAt(i) == patternStr.charAt(i)) {
for (int j = 1; j < patLen; j++) {
if (targetStr.charAt(i + j) != patternStr.charAt(j)) {
break;
}

if (j == patLen - 1) {
return i;
}
}

}
}

return -1;
}

或者双指针版:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
public static int ViolentMatch(String targetStr, String patternStr) {
int targetLen = targetStr.length();
int patLen = patternStr.length();

int i = 0;
int j = 0;
while (i < targetLen && j < patLen) {
if (targetStr.charAt(i) == patternStr.charAt(j)) {
i++;
j++;
} else {
i = i - j + 1;
j = 0;
}
}

if (j == patLen) {
return i - j;
} else {
return -1;
}
}

暴力解法复杂度分析

刚才我们成功实现了暴力算法,那么其时间复杂度如何?

\(n\) 为串 \(\textit{S}\) 的长度,\(m\) 为串 \(\textit{P}\) 的长度。

考虑“字符串比较”这个小任务的复杂度。最坏情况发生在:两个字符串唯一的差别在最后一个字符。这种情况下,字符串比较必须走完整个字符串,才能给出结果,因此复杂度是 \(O(len(str))\) 的。

由此,不难想到暴力算法所面对的最坏情况:

主串形如 “AAAAAAAAAAAAA…B” ,而模式串形如 “AAAAA…B” 。每次字符串比较都需要付出 \(O(m)\) 次字符比较的代价,总共需要比较 \(n\) 次,因此总时间复杂度是 \(O(nm)\)

那么如何改进呢?

信息熵冗余

我们很难降低字符串比较的复杂度(因为比较两个字符串,真的只能逐个比较字符)。因此,我们考虑降低比较的趟数。如果比较的趟数能降到足够低,那么总的复杂度也将会下降很多。

要优化一个算法,首先要回答的问题是“我手上有什么信息?” 我们手上的信息是否足够、是否有效,决定了我们能把算法优化到何种程度。请记住:尽可能利用残余的信息,是 \(\textit{KMP}\) 算法的思想所在。

很明显在暴力算法中,模式字符串每次都需要比较,非常复杂,那么这里面有没有优化的时间呢?

这里我直接引用参考文档里的说明:

假设现在文本串 \(\textit{S}\) 匹配到 \(i\) 位置,模式串 \(\textit{P}\) 匹配到 \(j\) 位置:

  • 如果 \(j = -1\),或者当前字符匹配成功(即 \(\textit{S}[i] = \textit{P}[j]\)),都令 \(\textit{i++}\) , \(j++\),继续匹配下一个字符;

  • 如果 \(j != -1\),且当前字符匹配失败(即 \(\textit{S}[i] != \textit{P}[j]\)),则令 \(i\) 不变,\(j = next[j - 1]\)。也就是说模式串 \(\textit{P}\) 相对于文本串 \(\textit{S}\) 向右移动了 \(j - next[j]\) 位。换言之,当匹配失败时,模式串向右移动的位数为:失配字符所在位置 - 失配字符对应的next值,即移动的实际位数为:\(j - next[j]\),且此值大于等于 \(1\)

\(\textit{next}\) 数组中值的含义是当前字符之前的字符串中,最长的相同前缀后缀

例如如果 \(next[j] = k\) ,代表j之前的字符串中有最大长度为 \(k\) 的相同前缀后缀。

此也意味着在某个字符失配时,该字符对应的 \(next\) 值会告诉你下一步匹配中,模式串应该跳到哪个位置(跳到 \(next[j]\) 的位置)。如果 \(next[j]\) 等于 \(0\)\(-1\),则跳到模式串的开头字符,若 \(next[j] = k\)\(k > 0\),代表下次匹配跳到 \(j\) 之前的某个字符,而不是跳到开头,且具体跳过了 \(k\) 个字符。

如何求Next数组?

字符串前缀和后缀

字符串前缀是指不包含最后一个字符的所有以第一个字符开头的连续子串,后缀是指不包含第一个字符的所有以最后一个字符结尾的连续子串。

例如: 比如说有一个长度为 \(5\) 字符串"ababc"

其前缀有"a", "ab", "aba", "abab"; 后缀有"c", "bc", "abc", "babc", "ababc"

那么

  • 长度为前 \(1\) 个字符的子串a,最长相同前后缀的长度为 \(0\)

以此类推:

  • 长度为前 \(4\) 个字符的子串 aaba ,最长相同前后缀的长度为 \(1\)
  • 长度为前 \(5\) 个字符的子串 aabaa ,最长相同前后缀的长度为 \(2\)
  • 长度为前 \(6\) 个字符的子串 aabaaf ,最长相同前后缀的长度为 \(0\)

获取Next数组过程

这个求 \(\textit{next}\) 数组曾经困扰了我很久,一直不太能理解。

  1. \(\textit{left}\) 指针表示当前字符串后缀末尾\(\textit{right}\) 指针表示前缀数组末尾,那么很容易知道 \(\textit{right}\) 指针要从左遍历到字符串末尾;

  2. \(\textit{left}\) 指针表示字符串后缀末尾,那么 \(\textit{left} < \textit{right}\)

  3. \(\textit{needle}.charAt(left) == \textit{needle}.charAt(right)\),那么 \(\textit{left}++\)

  4. \(\textit{right}\) 遍历到字符串末尾时,也就得到了 \(\textit{next}\) 数组。

代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
public int[] getNext(String needle) {
int len = needle.length();

// 定义好next数组
int[] next = new int[len];

for (int right = 1, left = 0; right < len; right++) {
// 定义好两个指针right与left
// 在for循环中初始化指针right为1,left=0,开始计算next数组,right始终在left指针的后面
while (left > 0 && needle.charAt(left) != needle.charAt(right)) {
// 如果不相等就让left指针回退,到0时就停止回退
left = next[left - 1]; //进行回退操作;
}

if (needle.charAt(left) == needle.charAt(right)) {
left++;
}

next[right] = left; // 这是从 1 开始的
}

return next;
}

KMP

\(\textit{KMP}\) 的实现其实和 \(\textit{next}\) 数组求解是相同的,只是当前匹配到相同时,跳出循环。

因为 \(\textit{next}\) 的求解本质也是用自身去匹配自身

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
public int kmp(String haystack, String needle) {
// 当needle是空字符串时,返回0
if (needle == null || needle.length() == 0) {
return 0;
}

int patLen = needle.length();

int[] next = new int[patLen];

// 定义好next数组
for (int right = 1, left = 0; right < patLen; right++) {
// 定义好两个指针right与left
// 在for循环中初始化指针right为1,left=0,开始计算next数组,right始终在left指针的后面
// 用while 而不是if,因为需要逐步回退到0
while (left > 0 && needle.charAt(left) != needle.charAt(right)) {
// 如果不相等就让left指针回退,到0时就停止回退
left = next[left - 1]; //进行回退操作;
}

if (needle.charAt(left) == needle.charAt(right)) {
left++;
}

// 这是从 1 开始的
next[right] = left;
}

// 循环结束的时候,next数组就已经计算完毕了

for (int i = 0, j = 0; i < haystack.length(); i++) {
while (j > 0 && haystack.charAt(i) != needle.charAt(j)) {
j = next[j - 1];
}
if (haystack.charAt(i) == needle.charAt(j)) {
j++;
}
if (j == patLen) {
return i - patLen + 1;
}
}

return -1;
}

KMP算法复杂度分析

  • 时间复杂度:\(O(n+m)\),其中 \(n\) 是字符串 \(\textit{haystack}\) 的长度,\(m\) 是字符串 \(\textit{needle}\) 的长度。我们至多需要遍历两字符串一次。

  • 空间复杂度:\(O(m)\),其中 \(m\) 是字符串 \(\textit{needle}\) 的长度。我们只需要保存字符串 \(\textit{needle}\) 的前缀函数。

应用 & 练习

LeetCode \(\textit{KMP}\) 练习题目:

  1. 28. 实现 strStr()
  2. 214. 最短回文串
  3. 459. 重复的子字符串
  4. 686. 重复叠加字符串匹配
  5. 1392. 最长快乐前缀

小结

\(\textit{KMP}\) 算法是一个很重要的算法,但是我们不光要知其然还要知其所以然,所以需要认真吃透其方法,详细了解其具体实现,才能真正掌握这一算法。不光知道,还需要达到可以直接手写代码的水平。

参考资料

  1. Knuth–Morris–Pratt algorithm
  2. Prefix function. Knuth–Morris–Pratt algorithm
  3. 多图预警👊🏻详解 KMP 算法
  4. 一文详解 KMP 算法
  5. KMP算法详解