union-find算法

问题描述

有若干数量的节点,如果A节点与B节点相连接,则表明A与B为同类节点,且满足传递性,求出有多少种类别的节点;针对大量节点的情况给出最优的解法

解决方案

算法第4版P136提到

将问题转化为类表示

  • 用一个列表表示所有的节点n,其值表示节点的类别,初始化时为每个节点赋不同的值
  • union函数将给定的两节点的标识符置为相同值,同时也要将这两个节点的同类节点置为该值,即将两类节点置为同一类
  • find函数获取节点的标识符,即类别
  • connected函数判断节点是否连通
  • count函数计算类别数

如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#其中id列表表示的是节点对应的标识符,即类别;
class UF(object):
#初始化分量
def __init__(self,n):
self.count=n
self.id=[]
for i in range(n):
self.id.append(i)
#将p以及和p标识相同的节点和q以及q标识相同的节点置为同样的值
def union(self,p,q):
pass
#查找p的标识符
def find(self,p):
pass
#判断p和q是否连通
def connected(self,p,q):
return self.find(p)==self.find(q)
#计算不同种类的个数
def count(self):
return self.count

优化的关键在于find函数和union函数

quick-find算法

find函数直接返回标识符,union通过该标识符对列表全部扫描,修改标识符

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#将p以及和p标识相同的节点和q以及q标识相同的节点置为同样的值
def union(self,p,q):
pid= self.find(p)
qid= self.find(q)
#如果相等,则不需要修改
if pid==qid:
return
#否则找出所有标识符为qid的,全部改为pid,也可以反之,此处没有规定,见下面的加权算法
for i in range(len(self.id)):
if self.id[i]==qid :
self.id[i]=pid
self.count=self.count-1
#查找p的标识符
def find(self,p):
return self.id[p]

试想一下,如果给出的节点量非常大,在每次执行union时都要进行整个列表的扫描,很显然是不可取的

quick-union算法

使用树,如果两个节点相连通,只需要将一节点的根节点作为到另一节点的根节点的子节点;不改变每个节点的标识符,只是修改其对应的根节点,这样每个节点存储的就是父节点的下标;只有根节点的下标等于其存储的值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
#获取根节点的标识符,只有根节点才满足下标等于其值,其它节点都存储的是父节点的下标
def find(self,p):
while self.id[p]!=p:
p=self.id[p]
return p
#
def union(self,p,q):
pid=self.find(p)
qid=self.find(q)
#同一根节点,则返回
if pid==qid:
return
self.id[qid]=pid#将下标为qid的节点的值置为pid,即其父类为pid节点,也可以反之
self.count=self.count-1#种类减1

union-find算法(加权quick-union算法)

上述quick-union算法的最坏的情况是:1号节点的父节点是2号,2号节点的父节点是3号,3号节点的父节点是4号,依次至最后,树的深度变为节点的个数,这样并没有降低查找的次数;针对于此,做如下改变:对于union函数每次在寻找到两个根节点后,判断其含有的节点个数,将小树的根节点添加到大树的根节点下面,就可以降低树的深度
这样需要为每个节点添加一个变量,代表以该节点为根节点组成树的节点数
完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class UF(object):
#初始化分量
def __init__(self,n):
self.count=n
self.id=[]
self.sz=[]
for i in range(n):
self.id.append(i)
self.sz.append(1)
#获取根节点的标识符,只有根节点才满足下标等于其值,其它节点都存储的是父节点的下标
def find(self,p):
while self.id[p]!=p:
p=self.id[p]
return p
def union(self,p,q):
pid=self.find(p)
qid=self.find(q)
#同一根节点,则返回
if pid==qid:
return
#具体变化在于此
if self.sz[p]>self.sz[q]:
self.id[q]=pid
self.sz[p]=self.sz[p]+self.sz[q]
else:
self.id[p]=qid
self.sz[q] = self.sz[p] + self.sz[q]
self.count=self.count-1#种类减1
#查找p的标识符
def find(self,p):
return self.id[p]
#判断p和q是否连通
def connected(self,p,q):
return self.find(p)==self.find(q)
#计算不同种类的个数
def count(self):
return self.count

如果觉得有帮助,给我打赏吧!