一、考察频率以及难度
难度:⭐⭐⭐
记忆化搜索属于动态规划的范畴,在笔试面试的考察频率都非常大。比较简单的情况下,可能出现在前3道题目,出的比较难也有可能在最后一道题目。
二、学习技巧
记忆化搜索最主要的技巧是要掌握好递归枚举,掌握好快速定义递归函数以及枚举过程是做好这一类题目的关键。当然,动态规划的模块都需要有大量的题量支撑。
三、应用场景
如果你可以将一个问题转换成可分解、可枚举的子问题,那么就可以考虑使用动态规划来做,一般是解决最大、最小、方案数。
四、算法讲解
首先,在学习这个模块之前,建议大家先学好递归,至少你对于递归过程比较清晰,当我们需要枚举或者遍历某个结构的时候,可以比较快速地写出代码。
我们先看一下这道题目:https://leetcode.cn/problems/climbing-stairs/
我们不妨使用一个基础的dfs枚举来进行枚举,结构图如下:
我们假设从0点出发,每次的选择可以是走1步或者走2步,如此一来可以抽象出一棵二叉树,我们可以利用dfs来遍历这棵二叉树,不难写出以下代码:
int dfs(int i) {
if (i == n) return 1;
if (i > n) return 0;
return dfs(i+1)+dfs(i+2);
}
这便是一个最暴力的解法,接下来我们分析一下时间复杂度,递归的参数i的范围是[0,n],对于每个i都有2个分叉,因此计算的次数接近于 2的n次方,也就是O(2^n),大家也可以这样理解:高度为n的二叉树节点数是2的n次方个。
那么如果题目给定的n比较大,是无法通过的,我们应该考虑优化。
不妨观察上述途中几个红色的节点,不难发现,这些节点的计算过程和计算结果是一致的,也就是说这个过程会有大量的重复的运算,避免这些重复的运算,就是记忆化搜索的核心了,这个思想其实也就是动态规划,代码改造如下:
int dp[N];//N是输入的数据大小
memset(dp,-1,sizeof dp); // 初始值设置为-1,目的是区分出计算过的状态和未计算的状态,-1表示未计算的状态
int dfs(int i) {
if (i == n) return 1;
if (i > n) return 0;
if (dp[i]!=-1) return dp[i];//如果该状态计算过,那么直接返回
return dp[i] = dfs(i+1)+dfs(i+2);//对于递归的结果,记录下来
}
如此一来,复杂度就变成了O(n),因为总共只有n个状态,我们只需要计算n个状态即可。
接下来我们不妨抽象一下记忆化搜索的通用思考方式:
找到一个枚举思路,写出dfs函数。dfs函数注意要将结果作为返回值返回。
根据dfs函数的参数的个数和范围,开辟相应纬度和大小的dp数组,初始值一般为-1。
递归函数之前判断状态是否计算过,递归计算结束以后记录递归结果。
注意事项:
记忆化搜索在状态相同以及枚举方式相同的情况下,时间复杂度和迭代填表的dp的复杂度是一样的,也就是笔试的时候是完全可以通过的!唯一的缺点就是有额外的空间开销,毕竟是递归就有额外的栈空间的消耗,但是这个并不影响。
对于某些需要对依赖的状态来进行优化的dp,可能记忆化搜索做不到,因为在计算递归的过程中,依赖的状态并未计算。
记忆化搜索的复杂度:假设说递归的参数是a,b,那么复杂度=a的范围*b的范围*每次递归的计算次数。
递归函数没设计好、枚举过程没设计好都可能导致复杂度超。
五、例题与解析
习题1 采集蜂蜜
https://oj.niumacode.com/problem/P1101
你是一名经验丰富的昆虫学家,正在研究一群蜜蜂的采蜜行为。这些蜜蜂飞行在一排花朵之间,每朵花上都有一定数量的蜜糖。蜜蜂们有一个独特的采蜜规则:如果两朵相邻的花同时被蜜蜂采蜜,会导致它们的花粉产生冲突,进而让蜜糖变质。
你的任务是帮助蜜蜂们设计一个采蜜策略,使得在不触发花粉冲突的情况下,蜜蜂们能够采集到最多的蜜糖。
输入:
第一行包含一个整数 ,表示花朵的数量。
第二行包含 个非负整数,表示每朵花所含的蜜糖量。
输出:
输出一个整数,表示不触动花粉冲突的情况下能够采集到的最高蜜糖量。
示例 1:
输入
4
1 2 3 1
输出
4
解释
蜜蜂采集第 1 朵花 (蜜糖量 = 1) ,然后采集第 3 朵花 (蜜糖量 = 3)。
采集到的最高蜜糖量 = 1 + 3 = 4。
示例 2:
输入
5
2 7 9 3 1
输出
12
提示:
1 <= nums.length <= 100
0 <= nums[i] <= 400
解析
我们不妨抽象一下这道题目:给定一个数组,挑选若干个互不相邻数字,使得总和最大。
接下来我们考虑如何枚举这个过程。
从头开始考虑每一个数字,每一个数字都有两种选择,要么选,要么不选,在两条不同的递归路径中选取最大的即可。式子如下
f(i) = max(f(i + 1), f(i + 2) + nums[i])
Java
import java.util.Scanner;
import java.util.Arrays;
public class Main {
static int[] dp;
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt();
int[] nums = new int[n];
for (int i = 0; i < n; i++) {
nums[i] = scanner.nextInt();
}
scanner.close();
dp = new int[nums.length];
Arrays.fill(dp, -1);
System.out.println(dfs(0, nums));
}
static int dfs(int index, int[] nums) {
if (index >= nums.length) return 0;
if (dp[index] != -1) return dp[index];
return dp[index] = Math.max(dfs(index + 1, nums), dfs(index + 2, nums) + nums[index]);
}
}
Python
def dfs(index, nums, dp):
if index >= len(nums):
return 0
if dp[index] != -1:
return dp[index]
dp[index] = max(dfs(index + 1, nums, dp), dfs(index + 2, nums, dp) + nums[index])
return dp[index]
if __name__ == "__main__":
n = int(input())
nums = list(map(int, input().split()))
dp = [-1] * len(nums)
print(dfs(0, nums, dp))
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
vector<int> dp;
int dfs(int index, const vector<int>& nums) {
if (index >= nums.size()) return 0;
if (dp[index] != -1) return dp[index];
return dp[index] = max(dfs(index + 1, nums), dfs(index + 2, nums) + nums[index]);
}
int main() {
int n;
cin >> n;
vector<int> nums(n);
for (int i = 0; i < n; ++i) {
cin >> nums[i];
}
dp = vector<int>(nums.size(), -1);
cout <<dfs(0,nums)<< endl;
return 0;
}
习题2 魔法石碑
https://oj.niumacode.com/problem/P1102
在一个神秘的魔法世界里,有一个奇怪的魔法数字石碑,上面刻着一个正整数 。这个数字石碑有一种神奇的属性,可以通过特定的魔法操作将数字变成1。石碑的守护者希望你能找到一种方法,以最少的操作次数将石碑上的数字变成1,从而解开古老的谜题。
你可以做如下操作:
如果 n 是偶数,则用 n/2 替换 n 。
如果 n 是奇数,则可以用 n+1 或 n-1 替换 n 。
你的任务是找到将 变为 1 所需的最小替换次数。
输入:
输入包含一个正整数 。
输出:
输出一个整数,表示将 变为 1 所需的最小替换次数。
示例 1:
输入
8
输出
3
解释
8 -> 4 -> 2 -> 1
示例 2:
输入
7
输出
4
解释
7 -> 8 -> 4 -> 2 -> 1
或 7 -> 6 -> 3 -> 2 -> 1
提示:
1 <= n <= 2^31 - 1
题目解析
此题很明显可以使用递归来做,题目已经明确告诉我们递归的方向:如果n是偶数的话递归方向就是n / 2,如果是奇数的话应该是n + 1 和 n - 1。不难发现如果n是奇数,那么n - 1和n + 1是偶数,那么下一次递归方向就必然是 (n + 1) / 2 和 (n - 1) / 2,因此不难写出以下代码。
其中mem的含义是mem[n]:n变为1的最少操作数。
注意:这题如果采用自底向上的动态规划的式子如下:
if i % 2 == 0: dp[i] = dp[i / 2] + 1
else : dp[i] = min(dp[(i + 1)/2], dp[(i - 1]/2]) + 2
但是这样做的时间复杂度是O(n),而n的范围是2^31次方,所以是会超时的。采用递归的做法则是logn的复杂度,所以此题要使用递归的方式进行求解。
代码展示
Java
import java.util.HashMap;
import java.util.Scanner;
public class Main {
static HashMap<Long, Integer> mem = new HashMap<>();
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt();
scanner.close();
System.out.println( dfs(n));
}
private static int dfs(long n) {
if (n == 1) return 0;
if (mem.containsKey(n)) return mem.get(n);
int ans = 0;
if (n % 2 == 0) {
ans = dfs(n / 2) + 1;
} else {
ans = Math.min(dfs((n - 1) / 2), dfs((n + 1) / 2)) + 2;
}
mem.put(n, ans);
return ans;
}
}
Python
mem = {}
def dfs(n):
if n == 1:
return 0
if n in mem:
return mem[n]
if n % 2 == 0:
ans = dfs(n // 2) + 1
else:
ans = min(dfs((n - 1) // 2), dfs((n + 1) // 2)) + 2
mem[n] = ans
return ans
if __name__ == "__main__":
n = int(input())
print(dfs(n))
C++
#include <iostream>
#include <unordered_map>
using namespace std;
unordered_map<long long, int> mem;
int dfs(long long n) {
if (n == 1) return 0;
if (mem.count(n)) return mem[n];
int ans = 0;
if (n % 2 == 0) {
ans = dfs(n / 2) + 1;
} else {
ans = min(dfs((n - 1) / 2), dfs((n + 1) / 2)) + 2;
}
mem[n] = ans;
return ans;
}
int main() {
int n;
cin >> n;
cout << dfs(n) << endl;
return 0;
}
习题3 小马的数组构造 【美团】
https://oj.niumacode.com/problem/P1150
小马拿到了一个数组a,她准备构造一个数组b满足:
1. b的每一位都和a对应位置不同,即 bi != ai
2. b 的所有元素之和都和 a 相同。
3. b的数组均为正整数。请你告诉小马有多少种构造方式。由于答案过大,请对 10^9+7取模。
输入描述
第一行输入一个正整数n,代表数组的大小。第二行输入n个正整数ai,代表小美拿到的数组。
1<=n<=100, 1<=ai<=300, 1<=Σai <= 500
输出描述
一个整数,代表构造方式对 10^9+7取模的值。
示例1
输入
3
1 1 3
输出
1
说明
只有[2,2,1]这一种数组合法。
示例2
输入
3
1 1 1
输出
0
思路与代码
记忆化搜索。
对于每一个位置来说,可以填充的数据是 当前剩余的数字且不等于原数组的数字。(这个数据规模很明显都指向这个角度的DP)。
定义 dfs(i, j)
函数,其中 i
表示当前处理到 a
数组的第几个元素,j
表示剩余需要分配的值。
dfs(i, j)
的返回值是以 a[i]
为起始元素,剩余总和为 j
的满足条件的组合数。
如果 i
达到数组末尾 n-1
,则判断剩余的 j
是否为 a[-1]
,返回合法性。
否则,遍历从 1
到 j
的所有可能值作为 b[i]
,确保 b[i] != a[i]
,递归计算下一步的可能性并累加。
使用取模操作防止结果溢出,并保存计算结果以便下次直接使用。
CPP
#include <iostream>
#include <vector>
using namespace std;
const int MAX_N = 105;
const int MAX_SUM = 505;
const int MOD = 1000000007;
int n;
vector<int> a;
int dp[MAX_N][MAX_SUM];
int dfs(int i, int j) {
if (i >= n - 1) {
return (j > 0 && j != a[n - 1]) ? 1 : 0;
}
if (dp[i][j] != -1) {
return dp[i][j];
}
int cnt = 0;
for (int c = 1; c <= j; ++c) {
if (c != a[i]) {
cnt += dfs(i + 1, j - c);
cnt %= MOD;
}
}
dp[i][j] = cnt;
return cnt;
}
int main() {
cin >> n;
a.resize(n);
int sum_a = 0;
for (int i = 0; i < n; ++i) {
cin >> a[i];
sum_a += a[i];
}
for (int i = 0; i < MAX_N; ++i) {
for (int j = 0; j < MAX_SUM; ++j) {
dp[i][j] = -1;
}
}
cout << dfs(0, sum_a) << endl;
return 0;
}
Java
import java.util.*;
public class Main {
static final int MAX_N = 105;
static final int MAX_SUM = 505;
static final int MOD = 1000000007;
static int n;
static int[] a;
static int[][] dp;
static int dfs(int i, int j) {
if (i >= n - 1) {
return (j > 0 && j != a[n - 1]) ? 1 : 0;
}
if (dp[i][j] != -1) {
return dp[i][j];
}
int cnt = 0;
for (int c = 1; c <= j; ++c) {
if (c != a[i]) {
cnt += dfs(i + 1, j - c);
cnt %= MOD;
}
}
dp[i][j] = cnt;
return cnt;
}
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
n = scanner.nextInt();
a = new int[n];
dp = new int[MAX_N][MAX_SUM];
int sum_a = 0;
for (int i = 0; i < n; ++i) {
a[i] = scanner.nextInt();
sum_a += a[i];
}
for (int i = 0; i < MAX_N; ++i) {
Arrays.fill(dp[i], -1);
}
System.out.println(dfs(0, sum_a));
scanner.close();
}
}
Python
n = int(input())
a = [int(c) for c in input().split()]
dp = {}
def dfs(i,j):
if i >= n - 1:
return 1 if j > 0 and j != a[-1] else 0
if (i,j) in dp: return dp[(i,j)]
cnt = 0
for c in range(1,j+1):
if c != a[i]:
cnt += dfs(i+1, j-c)
cnt %= 10**9 + 7
dp[(i, j)] = cnt % (10**9 + 7)
return cnt
print(dfs(0, sum(a)))
习题4 二叉查找树的个数 【华为】
https://oj.niumacode.com/problem/P1162
二叉查找树,是具有下列性质的二叉树: 若它的左子树不空,则左子树上所有结点的值均小于它的根结点的值;若它的右子树不空,则右子树上所有结点的值均大于它的根结点的值; 它的左、右子树也分别为二叉查找树。
给定一个数n,表示值由1到n的节点构造成二叉查找树,问对应能构造的高度小于等于k的不同二叉查找树的个数,根节点的高度为1。0< n < 36,0< k< 36.
解答要求
时间限制: C/C++ 1000ms, 其他语言: 2000ms 内存限制: C/C++ 256MB,其他语言: 512MB
输入
树的节点个数n,树的高度k,用空格分割。
输出
不同二又查找树的个数。
样例1
输入
3 2
输出
1
思路与代码
这个题目需要构造高度小于等于k的不同二叉查找树,我们可以使用记忆化搜索的方法来解决。
定义状态:
f[l, r, k]
表示在区间[l, r]
内,高度不大于k
的二叉查找树的数量。
状态转移:
我们可以枚举当前子树的根节点
root
,将区间[l, r]
划分成两个子区间[l, root-1]
和[root+1, r]
。在这种划分下,左子树和右子树的高度最大都不能超过
k-1
,因此状态转移方程为:
f[l,r,k] += f[l,root-1,k-1] + f[root+1,r,k-1]
边界条件:
当
l > r
时,返回 1,因为空树也是一种有效的二叉查找树。当
lvl == 1
时,如果l == r
返回 1,因为这是一个单节点的树,否则返回 0,因为无法形成高度为 1 的树。当
l == r
时,返回 1,因为这也是一个单节点的树。
cpp
#include <iostream>
#include <vector>
#define ll long long
using namespace std;
// 定义记忆化数组,假设最大 n 和 k 的值为 36
ll memo[37][37][37];
ll dfs(int l, int r, int lvl) {
if (l > r) {
return 1;
}
if (lvl == 1) {
if (l == r) {
return 1;
}
return 0;
}
if (l == r) {
return 1;
}
if (memo[l][r][lvl] != -1) {
return memo[l][r][lvl];
}
ll ans = 0;
for (int root = l; root <= r; root++) {
ans += dfs(l, root - 1, lvl - 1) * dfs(root + 1, r, lvl - 1);
}
memo[l][r][lvl] = ans;
return ans;
}
int main() {
int n, k;
cin >> n >> k;
// 初始化记忆化数组
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= n; j++) {
for (int l = 0; l <= k; l++) {
memo[i][j][l] = -1;
}
}
}
cout << dfs(1, n, k) << endl;
return 0;
}
Java
import java.util.Scanner;
public class Main {
// 定义记忆化数组,假设最大 n 和 k 的值为 36
private static long[][][] memo;
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt();
int k = scanner.nextInt();
scanner.close();
// 初始化记忆化数组
memo = new long[n + 1][n + 1][k + 1];
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= n; j++) {
for (int l = 0; l <= k; l++) {
memo[i][j][l] = -1;
}
}
}
System.out.println(dfs(1, n, k));
}
private static long dfs(int l, int r, int lvl) {
if (l > r) {
return 1;
}
if (lvl == 1) {
if (l == r) {
return 1;
}
return 0;
}
if (l == r) {
return 1;
}
if (memo[l][r][lvl] != -1) {
return memo[l][r][lvl];
}
long ans = 0;
for (int root = l; root <= r; root++) {
ans += dfs(l, root - 1, lvl - 1) * dfs(root + 1, r, lvl - 1);
}
memo[l][r][lvl] = ans;
return ans;
}
}
Python
from functools import cache
n,k = map(int, input().split())
@cache
def dfs(l:int, r:int, lvl:int) -> int:
if l > r: return 1
if lvl == 1:
if l == r: return 1
return 0
if l == r: return 1
ans = 0
for root in range(l, r+1):
ans += dfs(l,root-1,lvl-1) * dfs(root+1, r, lvl-1)
return ans
print(dfs(1,n,k))
习题5 项目派遣 【华为】
https://oj.niumacode.com/problem/P1165
题目描述:
某公司有n名员工,第i名员工具有的能力可以用一个正整数ai描述,称为员工的能力值,现在,公司有一个项目需要交给恰好[n/2]名员工负责。为了保证项目能顺利进行,要求负责该项目的所有员工能力值之和大于等于x。
公司希望你可以帮忙求出,有多少种不同的派遣员工来负责这个项目的方案。
上文中,[x]风表示大于等于x的最小整数,例[4] =4,[4.2]=5。认为两个方案不同,当且仅当存在一名员工在一种方案中负责该项目,而在另一种方案中不负责.
输入描述
输入包含多组数据,输入第一行包含一个整数T (1<=T<=10) ,表示数据组数.
接下来2T行,每两行描述了一组数据.
每组数据第一行包含两个正整数n(1<=n<=16) 和x (1<=x<=2*10^4),分别表示公司的员工总数和项目对负责员工能力值之和的要求。
每组数据第二行包含n个整数,第i个整数表示第i名员工的能力值ai(1<=ai<=10^3)。
对于100%的数据,满足1<=n<=16,1<=x<=2*10^4,1<=T<=10,1<=ai<=10^3。
输出描述
输出包含T行。对于每组数据输出一行一个整数,表示可行的派遣方案数.
样例输入
3
5 10
3 2 3 4 5
3 3
1 1 1
10 10
3 1 2 8 5 4 2 9 12 7
样例输出
7
0
252
提示
对于样例的第一组数据,在所有选择3名员工的方案中,有3种选择方案不可行:
1.选择第1、2、3名员工
2.选择第1、2、4名员工
3.选择第2、3、4名员工
其余7种方案均可行,因此答案为7。
对于样例的第二组数据,所有选择2名员工的方案均不可行,因此答案为0。
思路与代码
核心就是计算找到n/2个员工,使得能力总和大于等于x的方案数。
n
最多为 16,这允许我们使用记忆化搜索的方法,因为组合数的数量在此范围内是可处理的。
dfs[i][j][k]
表示从第 i
个员工开始,当前能力总和为 j
,选择的员工人数为 k
时,满足条件的方案数。
每个员工有两种选择:选择或不选择。可以转移到如下状态:
不选择第
i
个员工:f[i][j][k] = f[i+1][j][k]
选择第
i
个员工:f[i][j][k] = f[i+1][j+A[i]][k+1]
,前提是选择人数k
不超过⌈n/2⌉
边界条件
当选够了目标人数且能力值满足要求:返回 1
当已经考虑完所有员工或选择人数超过目标值:返回 0
C++
#include <iostream>
#include <vector>
#include <cmath>
#include <cstring>
using namespace std;
const int MAXN = 17; // n <= 16
const int MAXX = 20001; // x <= 2 * 10^4
int dp[MAXN][MAXX][MAXN]; // dp[i][j][k] 表示使用第 i 个员工、总能力值为 j、选择人数为 k 的方案数
int dfs(int i, int j, int k, int n, int x, const vector<int>& A, int target) {
if (j >= x && k == target) return 1;
if (i >= n || k > target) return 0;
if (dp[i][j][k] != -1) return dp[i][j][k];
// 不选择第 i 个员工或选择第 i 个员工
int result = dfs(i + 1, j, k, n, x, A, target) + dfs(i + 1, j + A[i], k + 1, n, x, A, target);
dp[i][j][k] = result;
return result;
}
int main() {
int T;
cin >> T;
while (T--) {
int n, x;
cin >> n >> x;
vector<int> A(n);
for (int i = 0; i < n; ++i) {
cin >> A[i];
}
int target = (n + 1) / 2; // math.ceil(n / 2)
// 初始化 dp 数组
memset(dp, -1, sizeof(dp));
cout << dfs(0, 0, 0, n, x, A, target) ;
if (T != 0) cout << endl;
}
return 0;
}
Java
import java.util.Scanner;
public class Main {
static final int MAXN = 17; // n <= 16
static final int MAXX = 20001; // x <= 2 * 10^4
static int[][][] dp = new int[MAXN][MAXX][MAXN];
public static int dfs(int i, int j, int k, int n, int x, int[] A, int target) {
if (j >= x && k == target) return 1;
if (i >= n || k > target) return 0;
if (dp[i][j][k] != -1) return dp[i][j][k];
// 不选择第 i 个员工或选择第 i 个员工
int result = dfs(i + 1, j, k, n, x, A, target) + dfs(i + 1, j + A[i], k + 1, n, x, A, target);
dp[i][j][k] = result;
return result;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int T = sc.nextInt();
while (T-- > 0) {
int n = sc.nextInt();
int x = sc.nextInt();
int[] A = new int[n];
for (int i = 0; i < n; ++i) {
A[i] = sc.nextInt();
}
int target = (n + 1) / 2; // math.ceil(n / 2)
// 初始化 dp 数组
for (int[][] array2D : dp) {
for (int[] array1D : array2D) {
java.util.Arrays.fill(array1D, -1);
}
}
System.out.print(dfs(0, 0, 0, n, x, A, target));
if (T != 0) System.out.println();
}
sc.close();
}
}
Python
import math
from functools import cache
T = int(input())
for _ in range(T):
n,x = map(int, input().split())
A = [int(c) for c in input().split()]
""" 恰好选择 [n//2] 个员工,使得总和大于等于x """
target = math.ceil(n/2)
""" i下标 j能力值 k人数 """
@cache
def dfs(i,j,k):
if j>=x and k==target:return 1
if i>=n or k>target: return 0
return dfs(i+1,j,k) + dfs(i+1,j+A[i],k+1)
print(dfs(0,0,0))