[LeetCode][18. 四数之和] 4种方法:暴力,双指针,DFS,HashMap
By Long Luo
方法一:暴力枚举
思路与算法:
和 15. 三数之和 类似,我们先对数组进行排序,然后 \(4\) 层循环即可。
由于结果肯定会出现重复的数字,所以我们使用 \(\texttt{Set}\) 来去重,代码如下所示:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22public List<List<Integer>> fourSum(int[] nums, int target) {
	if (nums == null || nums.length < 4) {
		return new ArrayList<>();
	}
	Arrays.sort(nums);
	int n = nums.length;
	Set<List<Integer>> ans = new HashSet<>();
	for (int first = 0; first < n - 3; first++) {
		for (int second = first + 1; second < n - 2; second++) {
			for (int third = second + 1; third < n - 1; third++) {
				for (int fourth = third + 1; fourth < n; fourth++) {
					if (nums[first] + nums[second] + nums[third] + nums[fourth] == target) {
						ans.add(Arrays.asList(nums[first], nums[second], nums[third], nums[fourth]));
					}
				}
			}
		}
	}
	return new ArrayList<>(ans);
}
我们可以在每次循环中增加判断,防止出现重复四元组,使用 \(\texttt{List}\):1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62public List<List<Integer>> fourSum(int[] nums, int target) {
    if (nums == null || nums.length < 4) {
        return new ArrayList<>();
    }
    Arrays.sort(nums);
    int len = nums.length;
    List<List<Integer>> ans = new ArrayList<>();
    for (int i = 0; i < len - 3; i++) {
        if (i > 0 && nums[i] == nums[i - 1]) {
            continue;
        }
        if ((long)nums[i] + nums[i + 1] + nums[i + 2] + nums[i + 3] > target) {
            break;
        }
        if ((long)nums[i] + nums[len - 3] + nums[len - 2] + nums[len - 1] < target ) {
            continue;
        }
        for (int j = i + 1; j < len - 2; j++) {
            if (j > i + 1 && nums[j] == nums[j - 1]) {
                continue;
            }
            if ((long)nums[i] + nums[j] + nums[j + 1] + nums[j + 2] > target) {
                break;
            }
            if ((long)nums[i] + nums[j] + nums[len - 2] + nums[len - 1] < target) {
                continue;
            }
            for (int k = j + 1; k < len - 1; k++) {
                if (k > j + 1 && nums[k] == nums[k - 1]) {
                    continue;
                }
                if ((long)nums[i] + nums[j] + nums[k] + nums[k + 1] > target) {
                    break;
                }
                if ((long)nums[i] + nums[j] + nums[k] + nums[len - 1] < target) {
                    continue;
                }
                for (int l = k + 1; l < len; l++) {
                    if (l > k + 1 && nums[l] == nums[l - 1]) {
                        continue;
                    }
                    if (nums[i] + nums[j] + nums[k] + nums[l] == target) {
                        ans.add(Arrays.asList(nums[i], nums[j], nums[k], nums[l]));
                    }
                }
            }
        }
    }
    return ans;
}
复杂度分析:
- 时间复杂度:\(O(n^4)\) ,其中 \(n\) 是数组 \(\textit{nums}\) 的长度。
- 空间复杂度:\(O(\log n)\) , 空间复杂度主要取决于排序额外使用的空间 \(O(\log n)\) 。
方法二:双指针
思路与算法:
使用两重循环分别枚举前两个数,然后在两重循环枚举到的数之后使用双指针枚举剩下的两个数。
假设两重循环枚举到的前两个数分别位于下标 \(i\) 和 \(j\),其中 \(i \lt j\)。初始时,左右指针分别指向下标 \(j + 1\) 和下标 \(n - 1\)。
每次计算四个数的和,并进行如下操作:
- 如果 \(sum == target\) ,则将枚举到的四个数加到答案中,然后将左指针右移直到遇到不同的数,将右指针左移直到遇到不同的数;
- 如果 \(sum \lt target\) ,则将左指针右移一位;
- 如果 \(sum \gt target\) ,则将右指针左移一位。
使用双指针枚举剩下的两个数的时间复杂度是 \(O(n)\),因此总时间复杂度是 \(O(n^3)\) 。
具体实现时,还可以进行一些剪枝操作:
- 在确定第一个数之后,如果 \(nums[i]+nums[i+1]+nums[i+2]+nums[i+3] > \textit{target}\),说明此时剩下的三个数无论取什么值,四数之和一定大于 \(\textit{target}\),因此退出第一重循环; 
- 在确定第一个数之后,如果 \(nums[i]+nums[n-3]+nums[n-2]+nums[n-1] < \textit{target}\),说明此时剩下的三个数无论取什么值,四数之和一定小于 \(\textit{target}\),因此第一重循环直接进入下一轮,枚举\(nums[i+1]\); 
- 在确定前两个数之后,如果 \(nums[i]+nums[j]+nums[j+1]+nums[j+2] >\textit{target}\),说明此时剩下的两个数无论取什么值,四数之和一定大于 \(\textit{target}\),因此退出第二重循环; 
- 在确定前两个数之后,如果 \(nums[i]+nums[j]+nums[n-2]+nums[n-1] <\textit{target}\),说明此时剩下的两个数无论取什么值,四数之和一定小于 \(\textit{target}\),因此第二重循环直接进入下一轮,枚举 \(nums[j+1]\)。 
需要注意的是:由于可能出现的溢出,对数据需要转换成 \(\texttt{long}\) 型。
代码如下所示:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64public static List<List<Integer>> fourSum(int[] nums, int target) {
	if (nums == null || nums.length < 4) {
		return new ArrayList<>();
	}
	Arrays.sort(nums);
	int n = nums.length;
	List<List<Integer>> ans = new ArrayList<>();
	for (int first = 0; first < n - 3; first++) {
		if (first > 0 && nums[first] == nums[first - 1]) {
			continue;
		}
		// 部分用例溢出,改成target减去形式
		if (nums[first] + nums[first + 1] > target - nums[first + 3] - nums[first + 2]) {
			break;
		}
		// 部分用例溢出,改成target减去形式
		if (nums[first] + nums[n - 3] < target - nums[n - 1] - nums[n - 2]) {
			continue;
		}
		for (int second = first + 1; second < n - 2; second++) {
			if (second > first + 1 && nums[second] == nums[second - 1]) {
				continue;
			}
			// 部分用例溢出,改成target减去形式
			if (nums[first] + nums[second] > target - nums[second + 2] - nums[second + 1]) {
				break;
			}
			// 部分用例溢出,改成target减去形式
			if (nums[first] + nums[second] < target - nums[n - 1] - nums[n - 2]) {
				continue;
			}
			int left = second + 1;
			int right = n - 1;
			while (left < right) {
				// 防止溢出情况
				long sum = (long) nums[first] + nums[second] + nums[left] + nums[right];
				if (sum > target) {
					right--;
				} else if (sum < target) {
					left++;
				} else if (sum == target) {
					ans.add(Arrays.asList(nums[first], nums[second], nums[left], nums[right]));
					left++;
					right--;
					while (left < right && nums[left] == nums[left - 1]) {
						left++;
					}
					while (left < right && nums[right] == nums[right + 1]) {
						right--;
					}
				}
			}
		}
	}
	return ans;
}
复杂度分析:
- 时间复杂度:\(O(n^3)\) ,其中 \(n\) 是数组的长度。排序的时间复杂度是 \(O(n\log n)\),枚举四元组的时间复杂度是 \(O(n^3)\) ,总时间复杂度为 \(O(n^3 + n\log n) = O(n^3)\) 。
- 空间复杂度:\(O(\log n)\) ,其中 \(n\) 是数组的长度。 空间复杂度主要取决于排序额外使用的空间。此外排序修改了输入数组 \(\textit{nums}\) ,实际情况中不一定允许,因此也可以看成使用了一个额外的数组存储了数组 \(\textit{nums}\) 的副本并排序,空间复杂度为 \(O(n)\) 。
方法三:DFS
思路与算法:
这道题也可以使用回溯法,其实本质还是暴力法,主要问题就是参考暴力法进行剪枝。
可以使用 \(\texttt{Set}\) 来降低一个循环,这里也不写了,也很简单,大家参考代码看下思路就行。
代码如下所示:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45public List<List<Integer>> fourSum(int[] nums, int target) {
    if (nums == null || nums.length < 4) {
        return new ArrayList<>();
    }
    Arrays.sort(nums);
    List<List<Integer>> ans = new ArrayList<>();
    dfs(nums, ans, new ArrayList<>(), 0, target);
    return ans;
}
public void dfs(int[] nums, List<List<Integer>> res, List<Integer> list, int start, int target) {
    if (list.size() == 4 && target == 0) {
        res.add(new ArrayList<>(list));
        return;
    }
    if (list.size() >= 4) {
        return;
    }
    int len = nums.length;
    for (int i = start; i < len; i++) {
        if (len - i < 4 - list.size()) {
            return;
        }
        if (i > start && nums[i] == nums[i - 1]) {
            continue;
        }
        if (i < len - 1 && nums[i] + (long) (3 - list.size()) * nums[i + 1] > target) {
            return;
        }
        if (i < len - 1 && nums[i] + (long) (3 - list.size()) * nums[len - 1] < target) {
            continue;
        }
        list.add(nums[i]);
        dfs(nums, res, list, i + 1, target - nums[i]);
        list.remove(list.size() - 1);
    }
}
复杂度分析:
- 时间复杂度:\(O(n^4)\) ,其中 \(n\) 是数组的长度。排序的时间复杂度是 \(O(n\log n)\) ,枚举四元组的时间复杂度是 \(O(n^4)\) ,总时间复杂度为 \(O(n^4 + n\log n)=O(n^4)\) 。
- 空间复杂度:\(O(n)\) ,其中 \(n\) 是数组的长度。
方法四:HashMap
思路与算法:
使用 \(\texttt{HashMap}\) 的话,和 15. 三数之和 一致,我们可以将复杂度降到 \(O(n^3)\) 。
但可以更进一步,用空间换时间,降低枚举的复杂度,提升构建 \(\texttt{Hash}\) 的复杂度。
枚举前两个数 + 哈希后两个数,两个部分的时间复杂度都是 \(O(n^2)\) ,从而使总体的时间复杂度降为 \(O(n^2)\) ,但特殊情况下,如果数组元素都一致或者前后对称的话,时间复杂度还是 \(O(n^3)\) 。
代码如下所示:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82public List<List<Integer>> fourSum(int[] nums, int target) {
    if (nums == null || nums.length < 4) {
        return new ArrayList<>();
    }
    Arrays.sort(nums);
    int len = nums.length;
    List<List<Integer>> ans = new ArrayList<>();
    Map<Integer, List<int[]>> map = new HashMap<>();
    for (int j = len - 1; j > 2; j--) {
        if (j < len - 1 && nums[j] == nums[j + 1]) {
            continue;
        }
        if (nums[j] < target / 4) {
            break;
        }
        if ((long) nums[j] + 3L * nums[0] > target) {
            continue;
        }
        for (int i = j - 1; i > 1; i--) {
            if (i < j - 1 && nums[i] == nums[i + 1]) {
                continue;
            }
            if ((long) nums[j] + 3L * nums[i] < target) {
                break;
            }
            if (nums[j] + nums[i] > target - 2 * nums[0]) {
                continue;
            }
            int sum = nums[i] + nums[j];
            List<int[]> list = map.getOrDefault(sum, new ArrayList<>());
            list.add(new int[]{i, j});
            map.put(sum, list);
        }
    }
    for (int i = 0; i < len - 3; i++) {
        if (i > 0 && nums[i] == nums[i - 1]) {
            continue;
        }
        if (nums[i] > target / 4) {
            break;
        }
        if ((long) nums[i] + 3L * nums[len - 1] < target) {
            continue;
        }
        for (int j = i + 1; j < len - 2; j++) {
            if (j > i + 1 && nums[j] == nums[j - 1]) {
                continue;
            }
            if ((long) nums[i] + 3L * nums[j] > target) {
                break;
            }
            if (2 * nums[len - 1] < target - nums[i] - nums[j]) {
                continue;
            }
            int newTarget = target - nums[i] - nums[j];
            if (map.containsKey(newTarget)) {
                List<int[]> list = map.get(newTarget);
                for (int[] index : list) {
                    if (j < index[0]) {
                        ans.add(Arrays.asList(nums[i], nums[j], nums[index[0]], nums[index[1]]));
                    }
                }
            }
        }
    }
    return ans;
}
复杂度分析:
- 时间复杂度:\(O(n^3)\) ,其中 \(n\) 是数组的长度。排序的时间复杂度是 \(O(n \log n)\) ,两个数循环 \(O(n^2)\) ,但由于需要枚举列表,极端情况下需要 \(O(n)\) ,所以总时间复杂度为 \(O(n^3 + n \log n) = O(n^3)\) 。
- 空间复杂度:\(O(n)\),排序额外使用的空间 \(O(\log n)\) ,另外需要 \(O(n)\) 存储后面2个数之和,\(O(\log n) + O(n) = O(n)\) 。
All suggestions are welcome. If you have any query or suggestion please comment below. Please upvote👍 if you like💗 it. Thank you:-)
Explore More Leetcode Solutions. 😉😃💗