数据结构和算法模板

并查集

Java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
/**
 * Disjoint-set (union-find) over the indices 0..n-1, stored as a parent array
 * (the parent-pointer representation of a forest).
 */
class UnionFind {
    // parent[i] is the index of i's parent; a root is its own parent.
    private int[] parent;

    /** Creates n singleton sets: every index starts as the root of its own tree. */
    public UnionFind(int n) {
        parent = new int[n];
        int idx = 0;
        while (idx < n) {
            parent[idx] = idx;
            idx++;
        }
    }

    /**
     * Returns the root index of the tree that contains x.
     * While walking up, every visited node is re-pointed at its grandparent,
     * which shortens the path for later lookups.
     */
    public int find(int x) {
        int cur = x;
        for (;;) {
            int up = parent[cur];
            if (up == cur) {
                return cur;
            }
            int grand = parent[up];
            parent[cur] = grand; // skip one generation
            cur = grand;
        }
    }

    /**
     * Merges the trees containing x and y by hanging x's root under y's root.
     * No union-by-rank is applied.
     */
    public void union(int x, int y) {
        int rootOfX = find(x);
        int rootOfY = find(y);
        parent[rootOfX] = rootOfY;
    }

    /** Returns true when x and y have the same root, i.e. belong to the same set. */
    public boolean isConnected(int x, int y) {
        int rootOfX = find(x);
        int rootOfY = find(y);
        return rootOfX == rootOfY;
    }
}

Go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
// UnionFind is a disjoint-set structure over the indices 0..n-1 that also
// tracks how many elements each set contains.
//
// Usage:
//
//	uf := &UnionFind{}
//	uf.init(n)
type UnionFind struct {
	n      int   // number of elements, as passed to init
	parent []int // parent[i] is the parent index of i; a root is its own parent
	size   []int // size[r] is the element count of the set rooted at r; only root entries are kept up to date
}

// init resets the structure to sz singleton sets.
func (uf *UnionFind) init(sz int) {
	uf.n = sz
	uf.parent = make([]int, sz)
	uf.size = make([]int, sz)
	for i := range uf.parent {
		uf.parent[i] = i
		uf.size[i] = 1
	}
}

// find returns the root of the set containing x. While walking up, every
// visited node is re-pointed at its grandparent to shorten later lookups.
func (uf *UnionFind) find(x int) int {
	for uf.parent[x] != x {
		uf.parent[x] = uf.parent[uf.parent[x]]
		x = uf.parent[x]
	}
	return x
}

// union merges the sets containing x and y. The root of the smaller set is
// attached under the root of the larger one (ties attach x's root under y's).
func (uf *UnionFind) union(x, y int) {
	rootX := uf.find(x)
	rootY := uf.find(y)
	if rootX == rootY {
		// Already in one set; merging again would double-count the size.
		return
	}
	if uf.size[rootX] > uf.size[rootY] {
		uf.parent[rootY] = rootX
		uf.size[rootX] += uf.size[rootY]
	} else {
		uf.parent[rootX] = rootY
		uf.size[rootY] += uf.size[rootX]
	}
}

// isConnected reports whether x and y belong to the same set.
func (uf *UnionFind) isConnected(x, y int) bool {
	return uf.find(x) == uf.find(y)
}

// getCap returns the number of elements in the set containing x. The size
// slice is only maintained for roots, so x is resolved to its root first.
func (uf *UnionFind) getCap(x int) int {
	return uf.size[uf.find(x)]
}

线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
/**
 * Segment tree over an int array. Supports point increment, point assignment,
 * point lookup and range sum / minimum / maximum queries, each by walking one
 * root-to-leaf path or a bounded number of subtrees.
 *
 * The tree is stored implicitly in arrays of length 4n; the root is at index 1
 * and the children of node idx are idx * 2 and idx * 2 + 1.
 */
class SegmentTree {
    int[] nums;    // input array the tree was built from (length n); not modified by add/update
    int[] treeSum; // treeSum[idx]: sum of the segment covered by node idx
    int[] treeMin; // treeMin[idx]: minimum of the segment covered by node idx
    int[] treeMax; // treeMax[idx]: maximum of the segment covered by node idx
    int n;

    /**
     * Builds the tree for the given array. An empty array yields an empty tree
     * on which every query or update is rejected.
     */
    public SegmentTree(int[] nums) {
        this.nums = nums;
        this.n = nums.length;
        this.treeSum = new int[4 * n];
        // All three aggregates are written by build(), so all three must be allocated.
        this.treeMin = new int[4 * n];
        this.treeMax = new int[4 * n];
        if (n > 0) {
            build(1, 0, n - 1);
        }
    }

    /**
     * Recursively fills node idx, which covers nums[left..right].
     * Leaves copy the array element; inner nodes combine their two children.
     */
    private void build(int idx, int left, int right) {
        if (left == right) {
            treeSum[idx] = nums[left];
            treeMin[idx] = nums[left];
            treeMax[idx] = nums[left];
            return;
        }
        int mid = left + (right - left) / 2;
        build(idx * 2, left, mid);
        build(idx * 2 + 1, mid + 1, right);
        pushUp(idx);
    }

    /** Recomputes the aggregates of node idx from its two children. */
    private void pushUp(int idx) {
        treeSum[idx] = treeSum[idx * 2] + treeSum[idx * 2 + 1];
        treeMax[idx] = Math.max(treeMax[idx * 2], treeMax[idx * 2 + 1]);
        treeMin[idx] = Math.min(treeMin[idx * 2], treeMin[idx * 2 + 1]);
    }

    /** Adds val to element i inside the subtree idx covering [left, right]. */
    private void add(int idx, int left, int right, int val, int i) {
        if (left == right) {
            treeSum[idx] += val;
            // A leaf covers a single element, so min and max equal its value.
            treeMax[idx] = treeSum[idx];
            treeMin[idx] = treeSum[idx];
            return;
        }
        int mid = left + (right - left) / 2;
        // The left child covers [left, mid], so the split point is mid.
        if (i <= mid) {
            add(idx * 2, left, mid, val, i);
        } else {
            add(idx * 2 + 1, mid + 1, right, val, i);
        }
        pushUp(idx);
    }

    /**
     * Point increment: element i becomes element i + val.
     *
     * @throws IndexOutOfBoundsException if i is not in [0, n)
     */
    public void add(int i, int val) {
        checkIndex(i);
        add(1, 0, n - 1, val, i);
        // Optionally mirror the change into the input array:
        // nums[i] += val;
    }

    /** Sets element i to val inside the subtree idx covering [left, right]. */
    private void update(int idx, int left, int right, int val, int i) {
        if (left == right) {
            treeSum[idx] = val;
            treeMin[idx] = val;
            treeMax[idx] = val;
            return;
        }
        int mid = left + (right - left) / 2;
        if (i <= mid) {
            update(idx * 2, left, mid, val, i);
        } else {
            update(idx * 2 + 1, mid + 1, right, val, i);
        }
        pushUp(idx);
    }

    /**
     * Point assignment: element i becomes val.
     *
     * @throws IndexOutOfBoundsException if i is not in [0, n)
     */
    public void update(int i, int val) {
        checkIndex(i);
        update(1, 0, n - 1, val, i);
        // Optionally mirror the change into the input array:
        // nums[i] = val;
    }

    /** Returns the current value of element i from the subtree idx covering [left, right]. */
    private int querySum(int idx, int left, int right, int i) {
        if (left == right) {
            // left == right == i here; the value lives at tree index idx, not at array index left.
            return treeSum[idx];
        }
        int mid = left + (right - left) / 2;
        if (i <= mid) {
            return querySum(idx * 2, left, mid, i);
        } else {
            return querySum(idx * 2 + 1, mid + 1, right, i);
        }
    }

    /**
     * Point lookup: returns the current value of element i.
     *
     * @throws IndexOutOfBoundsException if i is not in [0, n)
     */
    public int querySum(int i) {
        checkIndex(i);
        // If add/update mirror their changes into nums, this could return nums[i] directly.
        return querySum(1, 0, n - 1, i);
    }

    /** Sums the part of [L, R] that lies inside the subtree idx covering [left, right]. */
    private int querySum(int idx, int left, int right, int L, int R) {
        if (left >= L && right <= R) {
            // This node's segment lies entirely inside the queried range.
            return treeSum[idx];
        }
        int mid = left + (right - left) / 2;
        int ans = 0;
        if (L <= mid) {
            ans += querySum(idx * 2, left, mid, L, R);
        }
        if (R > mid) {
            ans += querySum(idx * 2 + 1, mid + 1, right, L, R);
        }
        return ans;
    }

    /**
     * Range sum over the elements with index in [L, R] (both inclusive).
     *
     * @throws IllegalArgumentException if L > R or the range contains no valid index
     */
    public int querySum(int L, int R) {
        checkRange(L, R);
        return querySum(1, 0, n - 1, L, R);
    }

    /** Minimum over the part of [L, R] inside the subtree idx covering [left, right]. */
    private int min(int idx, int left, int right, int L, int R) {
        if (left >= L && right <= R) {
            return treeMin[idx];
        }
        int mid = left + (right - left) / 2;
        int leftMin = Integer.MAX_VALUE;
        int rightMin = Integer.MAX_VALUE;
        if (L <= mid) {
            leftMin = min(idx * 2, left, mid, L, R);
        }
        if (R > mid) {
            rightMin = min(idx * 2 + 1, mid + 1, right, L, R);
        }
        return Math.min(leftMin, rightMin);
    }

    /** Maximum over the part of [L, R] inside the subtree idx covering [left, right]. */
    private int max(int idx, int left, int right, int L, int R) {
        if (left >= L && right <= R) {
            return treeMax[idx];
        }
        int mid = left + (right - left) / 2;
        int leftMax = Integer.MIN_VALUE;
        int rightMax = Integer.MIN_VALUE;
        if (L <= mid) {
            leftMax = max(idx * 2, left, mid, L, R);
        }
        if (R > mid) {
            rightMax = max(idx * 2 + 1, mid + 1, right, L, R);
        }
        return Math.max(leftMax, rightMax);
    }

    /**
     * Range minimum over the elements with index in [L, R] (both inclusive).
     *
     * @throws IllegalArgumentException if L > R or the range contains no valid index
     */
    public int min(int L, int R) {
        checkRange(L, R);
        return min(1, 0, n - 1, L, R);
    }

    /**
     * Range maximum over the elements with index in [L, R] (both inclusive).
     *
     * @throws IllegalArgumentException if L > R or the range contains no valid index
     */
    public int max(int L, int R) {
        checkRange(L, R);
        return max(1, 0, n - 1, L, R);
    }

    /** Rejects point indexes outside [0, n); without this the descent would silently hit an edge leaf. */
    private void checkIndex(int i) {
        if (i < 0 || i >= n) {
            throw new IndexOutOfBoundsException("index " + i + " is outside [0, " + n + ")");
        }
    }

    /**
     * Rejects ranges that are inverted or do not overlap [0, n - 1]; for those
     * the recursive descent never reaches a fully covered node and would not terminate.
     */
    private void checkRange(int L, int R) {
        if (L > R || R < 0 || L >= n) {
            throw new IllegalArgumentException("range [" + L + ", " + R + "] does not overlap [0, " + (n - 1) + "]");
        }
    }

}

前缀树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
/** One node of the prefix tree; each of the 26 child slots stands for a letter 'a'..'z'. */
class TrieNode {
    boolean isWord;      // true when the path from the root to this node spells an inserted word
    TrieNode[] children; // children[c - 'a'] is the node reached by letter c, or null

    TrieNode() {
        this.isWord = false;
        this.children = new TrieNode[26];
    }
}

/**
 * Prefix tree over words made of the lowercase letters 'a'..'z'.
 */
class Trie {
    TrieNode root;

    public Trie() {
        this.root = new TrieNode();
    }

    /**
     * Inserts word into the trie.
     *
     * @throws IllegalArgumentException if word contains a character outside 'a'..'z';
     *         the trie is left unchanged in that case
     */
    public void insert(String word) {
        // Validate first so a rejected word never leaves a half-built path behind.
        for (int i = 0; i < word.length(); i++) {
            if (slotOf(word.charAt(i)) < 0) {
                throw new IllegalArgumentException(
                        "only characters 'a'..'z' are supported, got '" + word.charAt(i) + "' at index " + i);
            }
        }
        var node = this.root;
        for (int i = 0; i < word.length(); i++) {
            int slot = slotOf(word.charAt(i));
            if (node.children[slot] == null) {
                node.children[slot] = new TrieNode();
            }
            node = node.children[slot];
        }
        node.isWord = true;
    }

    /** Returns true when exactly this word was inserted before. */
    public boolean search(String word) {
        var end = walk(word);
        return end != null && end.isWord;
    }

    /** Returns true when at least one inserted word starts with prefix. */
    public boolean startsWith(String prefix) {
        return walk(prefix) != null;
    }

    /**
     * Follows path from the root and returns the node it ends at, or null when
     * the path leaves the trie. A character outside 'a'..'z' can never have been
     * inserted, so it ends the walk with null as well.
     */
    private TrieNode walk(String path) {
        var node = this.root;
        for (int i = 0; i < path.length(); i++) {
            int slot = slotOf(path.charAt(i));
            if (slot < 0 || node.children[slot] == null) {
                return null;
            }
            node = node.children[slot];
        }
        return node;
    }

    /** Maps 'a'..'z' to 0..25 and every other character to -1. */
    private static int slotOf(char c) {
        return (c >= 'a' && c <= 'z') ? c - 'a' : -1;
    }
}

最短路

Dijkstra

均假设图以邻接矩阵作为入参

不带堆优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/**
 * Single-source shortest paths using the array-scan form of Dijkstra's
 * algorithm (no heap), which runs n rounds of an O(n) scan.
 *
 * @param n     number of nodes
 * @param k     index of the source node
 * @param graph n*n adjacency matrix, graph[u][v] being the weight of edge u -> v;
 *              weights are assumed to be non-negative, and a missing edge may be
 *              encoded with any large value up to Integer.MAX_VALUE
 * @return dis, where dis[v] is the shortest distance from k to v; a node that is
 *         never improved keeps the initial value Integer.MAX_VALUE / 2
 */
public int[] dijkstra(int n, int k, int[][] graph) {
    int[] dis = new int[n];
    Arrays.fill(dis, Integer.MAX_VALUE / 2);
    dis[k] = 0;
    boolean[] done = new boolean[n];
    // The number of rounds is fixed, so a disconnected graph needs no special handling.
    for (int p = 0; p < n; p++) {
        // Pick the not-yet-finalized node with the smallest tentative distance.
        int t = -1;
        for (int i = 0; i < n; i++) {
            if (!done[i] && (t == -1 || dis[i] < dis[t])) {
                t = i;
            }
        }
        done[t] = true;
        // Relax every edge leaving t.
        for (int i = 0; i < n; i++) {
            // Sum in 64 bits: dis[t] + Integer.MAX_VALUE would wrap to a negative int.
            long candidate = (long) dis[t] + graph[t][i];
            if (candidate < dis[i]) {
                dis[i] = (int) candidate;
            }
        }
    }
    return dis;
}

带堆优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/**
 * Single-source shortest paths using Dijkstra's algorithm with a priority queue.
 *
 * @param n     number of nodes
 * @param start index of the source node
 * @param graph n*n adjacency matrix, graph[u][v] being the weight of edge u -> v;
 *              weights are assumed to be non-negative, and a missing edge may be
 *              encoded with any large value up to Integer.MAX_VALUE
 * @return dist, where dist[v] is the shortest distance from start to v; a node that
 *         is never improved keeps the initial value Integer.MAX_VALUE / 2
 */
private int[] dijkstra(int n, int start, int[][] graph) {
    int[] dist = new int[n];
    Arrays.fill(dist, Integer.MAX_VALUE / 2);
    dist[start] = 0;
    boolean[] done = new boolean[n];
    // Queue entries are {node, distance at the time of insertion}, ordered by distance.
    PriorityQueue<int[]> pq = new PriorityQueue<>(Comparator.comparingInt(a -> a[1]));
    pq.add(new int[]{start, 0});
    while (!pq.isEmpty()) {
        int node = pq.poll()[0];
        if (done[node]) {
            // Outdated entry: the node was finalized through an earlier, shorter entry.
            continue;
        }
        done[node] = true;
        for (int i = 0; i < n; i++) {
            if (done[i]) {
                continue;
            }
            // Sum in 64 bits: dist[node] + Integer.MAX_VALUE would wrap to a negative int.
            long candidate = (long) dist[node] + graph[node][i];
            if (candidate < dist[i]) {
                dist[i] = (int) candidate;
                pq.add(new int[]{i, dist[i]});
            }
        }
    }
    return dist;
}

Floyd

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/**
 * All-pairs shortest paths with the Floyd-Warshall algorithm.
 *
 * @param n     number of nodes
 * @param graph n*n adjacency matrix, graph[u][v] being the weight of edge u -> v;
 *              a missing edge may be encoded with any large value up to
 *              Integer.MAX_VALUE. The matrix itself is not modified.
 * @return a new n*n matrix dis where dis[u][v] is the shortest distance from u to v;
 *         pairs with no improving path keep the value they had in graph
 */
public int[][] floyd(int n, int[][] graph) {
    // Work on a copy so the caller's matrix stays untouched.
    int[][] dis = new int[n][n];
    for (int i = 0; i < n; i++) {
        System.arraycopy(graph[i], 0, dis[i], 0, n);
    }
    for (int k = 0; k < n; k++) {         // intermediate node
        for (int i = 0; i < n; i++) {     // start node
            for (int j = 0; j < n; j++) { // end node
                // Sum in 64 bits: adding two "no edge" sentinels would wrap a 32-bit int.
                long via = (long) dis[i][k] + dis[k][j];
                if (via < dis[i][j]) {
                    dis[i][j] = (int) via;
                }
            }
        }
    }
    return dis;
}

快速幂

使用二分减治递归的思想进行快速幂的计算

Java(递归版)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/**
 * Computes x raised to the power n by repeated halving of the exponent.
 *
 * The exponent is widened to long before it is negated, because -n does not
 * fit in an int when n is Integer.MIN_VALUE.
 *
 * @param x the base
 * @param n the exponent, any int value including Integer.MIN_VALUE
 * @return x to the power n
 */
public double myPow(double x, int n) {
    long exponent = n;
    if (exponent < 0) {
        // x^(-e) equals (1/x)^e
        x = 1 / x;
        exponent = -exponent;
    }
    return powNonNegative(x, exponent);
}

/** Computes base^exponent for exponent >= 0 with O(log exponent) multiplications. */
private double powNonNegative(double base, long exponent) {
    if (exponent == 0) {
        return 1.0;
    }
    if (exponent == 1) {
        return base;
    }
    double half = powNonNegative(base, exponent / 2);
    double result = half * half;
    if ((exponent & 1) == 1) {
        // An odd exponent leaves one factor over after halving.
        result *= base;
    }
    return result;
}

注意:对 int 类型的 n 直接取相反数(n = -n)的写法无法通过 50. Pow(x, n) - 力扣(LeetCode),原因是某个用例的 n 为 int 类型的最小值,取相反数时会溢出。解决办法是先把指数提升为 long 再取相反数(例如重载一个 n 为 long 的函数)。

⭐Go(位运算版)

解析

例如我们想要计算 $x^{15}$,观察发现 $x^{15}=x^{1} \cdot x^{2} \cdot x^{4} \cdot x^{8}$,而 15 的二进制为 1111。即从低到高的第 $i$ 个(从 0 开始计数)二进制位若为 1,整体的结果就额外乘上 $x^{2^{i}}$。这有些类似海明码的思想。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// myPow returns x raised to the power n using binary exponentiation: for every
// set bit i of |n| (counted from the lowest bit) the result is multiplied by
// x^(2^i), which is obtained by squaring the base once per bit.
//
// A negative exponent is handled as (1/x)^|n|. The magnitude is computed in an
// unsigned integer because -n overflows when n is the smallest int.
func myPow(x float64, n int) float64 {
	if n == 0 {
		return 1.0
	}
	base := x
	var exp uint
	if n < 0 {
		base = 1 / x
		// -(n+1) always fits in an int; adding 1 afterwards yields |n| without overflow.
		exp = uint(-(n+1)) + 1
	} else {
		exp = uint(n)
	}
	ans := 1.0
	for ; exp > 0; exp >>= 1 {
		if exp&1 != 0 {
			ans *= base
		}
		base *= base
	}
	return ans
}

埃氏筛

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/**
 * Counts the primes strictly less than n with the sieve of Eratosthenes.
 *
 * @param n exclusive upper bound; any value below 3 (including negative values) yields 0
 * @return the number of primes p with 2 <= p < n
 */
public int countPrimes(int n) {
    if (n < 3) {
        // There is no prime below 2, and a negative n must not reach the array allocation.
        return 0;
    }
    boolean[] isComposite = new boolean[n];
    int count = 0;
    for (int i = 2; i < n; i++) {
        if (isComposite[i]) {
            continue;
        }
        // i was not crossed out by any smaller prime, so it is prime.
        count++;
        // Start at i * i: smaller multiples were crossed out by smaller primes.
        // The cursor is a long because i * i and j + i can exceed the int range.
        for (long j = (long) i * i; j < n; j += i) {
            isComposite[(int) j] = true;
        }
    }
    return count;
}

字符串哈希(子串问题)

使用 Rabin-Karp 算法检测字符串 s 中是否包含子串 sep(改编自 Go 内置的 bytealg package - internal/bytealg - Go Packages)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
// adopted from bytealgo.go/IndexRabinKarp
const PrimeRK = 16777619

func RabinKarp(s, sep string) bool {
if len(s) < len(sep) {
return false
}
if len(s) == len(sep) {
return s == sep
}
hashss, pow := hashStr(sep)
n := len(sep)
var h uint32
for i := 0; i < n; i++ {
h = h*PrimeRK + uint32(s[i])
}
if h == hashss && s[:n] == sep {
return true
}
for i := n; i < len(s); {
h *= PrimeRK
h += uint32(s[i])
h -= pow * uint32(s[i-n])
i++
if h == hashss && s[i-n:i] == sep {
return true
}
}
return false
}
func hashStr(sep string) (uint32, uint32) {
hash := uint32(0)
for i := 0; i < len(sep); i++ {
hash = hash*PrimeRK + uint32(sep[i])
}
var pow, sq uint32 = 1, PrimeRK
// Fast Exponentiation
for i := len(sep); i > 0; i >>= 1 {
if i&1 != 0 {
pow *= sq
}
sq *= sq
}
return hash, pow
}

子序列

LCS

时间复杂度 $O(mn)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// isSubseq reports whether text2 is a subsequence of text1, by checking that
// the longest common subsequence of the two strings is all of text2.
// Lengths are counted in characters (runes), matching longestCommonSubsequence.
func isSubseq(text1, text2 string) bool {
	if text1 == text2 {
		return true
	}
	return len([]rune(text2)) == longestCommonSubsequence(text1, text2)
}

// longestCommonSubsequence returns the length, in characters (runes), of the
// longest common subsequence of text1 and text2.
//
// dp[i][j] holds the answer for the first i characters of text1 and the first
// j characters of text2. Both strings are converted to rune slices first:
// ranging over a string yields byte offsets, which would skip table rows and
// columns whenever a character takes more than one byte.
func longestCommonSubsequence(text1 string, text2 string) int {
	a, b := []rune(text1), []rune(text2)
	dp := make([][]int, len(a)+1)
	for i := range dp {
		dp[i] = make([]int, len(b)+1)
	}
	for i, x := range a {
		for j, y := range b {
			// Best result when one of the two current characters is left out.
			best := dp[i][j+1]
			if dp[i+1][j] > best {
				best = dp[i+1][j]
			}
			// Matching characters extend the subsequence of both shorter prefixes.
			if x == y && dp[i][j]+1 > best {
				best = dp[i][j] + 1
			}
			dp[i+1][j+1] = best
		}
	}
	return dp[len(a)][len(b)]
}

⭐双指针

时间复杂度 $O(m+n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// isSubseq reports whether text2 is a subsequence of text1, i.e. whether text2
// can be obtained from text1 by deleting zero or more characters.
//
// Matching is done character by character (rune by rune) in a single pass over
// text1. Comparing raw bytes instead could pair up bytes that belong to
// different multi-byte characters. Inputs are assumed to be valid UTF-8.
func isSubseq(text1, text2 string) bool {
	// A subsequence can never be longer than the string it is taken from, and
	// one of equal length has to be the string itself.
	if len(text1) < len(text2) {
		return false
	}
	if len(text1) == len(text2) {
		return text1 == text2
	}
	want := []rune(text2)
	matched := 0 // number of leading characters of text2 found so far
	for _, r := range text1 {
		if matched == len(want) {
			break
		}
		if r == want[matched] {
			matched++
		}
	}
	return matched == len(want)
}

数据结构和算法模板
https://exapricity.tech/Algo-Template.html
作者
Peiyang He
发布于
2024年3月4日
许可协议