设计一个支持在平均 时间复杂度 O(1) 下, 执行以下操作的数据结构。
注意: 允许出现重复元素。
集合包含以下三个功能:
A.insert(val):向集合中插入元素 val。
B.remove(val):当 val 存在时,从集合中移除一个 val。
C.getRandom:从现有集合中随机获取一个元素。每个元素被返回的概率应该与其在集合中的数量呈线性相关。
集合功能示例:
二.实现集合类型定义为 RanomizedSet,实现了上述三个方法即插入元素,移除元素以及随机获取与元素个数线性相关的元素且时间复杂度为 O(1)。这里保存集合元素采用了 list,判断元素是否存在则使用了 dict 辅助,所以虽然该数据结构上述操作的时间复杂度为 O(1),但是为此空间复杂度也随之增加。
#!/usr/bin/python
# -*- coding: UTF-8 -*-
from random import choice
# defaultdict(list)
class RandomizedSet:
def __init__(self):
"""
Initialize your data structure here.
"""
self.dict = {}
self.list = []
def insert(self, val: int) -> bool:
"""
Inserts a value to the set. Returns true if the set did not already contain the specified element.
"""
self.list.append(val)
if val in self.dict:
self.dict[val].append(len(self.list) - 1)
return False
else:
self.dict[val] = []
self.dict[val].append(len(self.list) - 1)
return True
def remove(self, val: int) -> bool:
"""
Removes a value from the set. Returns true if the set contained the specified element.
"""
if val in self.dict and self.dict[val] != []:
# 交换最后一个元素与待删除元素
last_idx = len(self.list) - 1
val_idx = self.dict[val][0]
# 相同则无需交换
if last_idx != val_idx:
self.list[val_idx], self.list[last_idx] = self.list[last_idx], self.list[val_idx]
self.dict[self.list[val_idx]].remove(last_idx)
self.dict[self.list[val_idx]].append(val_idx)
self.list.pop()
self.dict[val].remove(val_idx)
# 如果对应 val 的位置列表为空则删除
if not self.dict[val]:
del self.dict[val]
return True
return False
def getRandom(self) -> int:
"""
Get a random element from the set.
"""
return choice(self.list)
三.详解
1.insert(val)
def insert(self, val: int) -> bool:
"""
Inserts a value to the set. Returns true if the set did not already contain the specified element.
"""
self.list.append(val)
if val in self.dict:
self.dict[val].append(len(self.list) - 1)
return False
else:
self.dict[val] = []
self.dict[val].append(len(self.list) - 1)
return True
向集合中插入元素,由于要求中提到允许出现重复的元素,所以每个元素 val 的位置索引使用 [] 记录,这里原题目中给出的代码有问题,无法添加重复元素。List 中添加新元素 val 后,存储 val 在列表中对应的位置,这里通过 dict 保证一个元素的位置存储在一个列表中。返回值通过 dict 判断,如果 dict 存在则返回 False,否则返回 True。
2.remove(val)def remove(self, val: int) -> bool:
"""
Removes a value from the set. Returns true if the set contained the specified element.
"""
if val in self.dict and self.dict[val] != []:
# 交换最后一个元素与待删除元素
last_idx = len(self.list) - 1
val_idx = self.dict[val][0]
# 相同则无需交换
if last_idx != val:
self.list[val_idx], self.list[last_idx] = self.list[last_idx], self.list[val_idx]
self.dict[self.list[val_idx]].remove(last_idx)
self.dict[self.list[val_idx]].append(val_idx)
self.list.pop()
self.dict[val].remove(val_idx)
# 如果对应 val 的位置列表为空则删除
if not self.dict[val]:
del self.dict[val]
return True
return False
移除元素中的元素,这里是这个 collection 最复杂的地方,这里的思路类似于冒泡排序一样。
A.获取最后一个元素的值和要删除值的索引
获取最后一个元素的位置和要删除元素的位置
last_idx = len(self.list) - 1 val_idx = self.dict[val][0]
B.删除值位置的元素替换为列表最后一个值,并修改其在 dict 的索引位置
如果两个索引一致则无需交换位置直接 pop 删除即可;交换后的新值索引由 last_idx 变换到 val_idx,所以 remove last_idx,添加 val_idx。最后列表 pop 弹出元素并且删除 val 对应的索引。
if last_idx != val_idx:
self.list[val_idx], self.list[last_idx] = self.list[last_idx], self.list[val_idx]
self.dict[self.list[val_idx]].remove(last_idx)
self.dict[self.list[val_idx]].append(val_idx)
self.list.pop()
self.dict[val].remove(val_idx)
C.清除空列表
如果对应 val 值的索引列表为空,则无需为该值记录索引故删除。
if not self.dict[val]:
del self.dict[val]
D.返回值
包含元素并移除返回 True,否则返回 False。
3.getRandom(val)def getRandom(self) -> int:
"""
Get a random element from the set.
"""
return choice(self.list)
这里 choice 函数来自 random 库,需要 from random import choice 引入,这里如果想自己实现也很简单,只需要随机 [0, len(list)-1] 的数字并从 list 取值即可,随机到每个索引的概率为 1/n,n 为列表长度,所以 p(val) = 1/n * num,num 为该数字在集合中的个数,因此 p(val) 满足概率与其在集合中的数量呈线性相关。
def getRandomV2(self) -> int:
"""
Get a random element from the set.
"""
return self.list[random.randint(0, len(self.list) - 1)]
四.测试
1.添加、移除元素
if __name__ == '__main__':
selfSet = RandomizedSet()
print("添加元素...")
selfSet.insert(1)
selfSet.insert(1)
selfSet.insert(2)
selfSet.insert(3)
selfSet.insert(3)
print(selfSet.dict)
print(selfSet.list)
print("移除元素...")
selfSet.remove(1)
selfSet.remove(2)
selfSet.remove(3)
print(selfSet.dict)
print(selfSet.list)
2.成比例采样
print("添加元素... 1,1,2,3")
selfSet.insert(1)
selfSet.insert(1)
selfSet.insert(2)
selfSet.insert(3)
print("Random Origin")
countMap = {}
for i in range(100):
random_num = selfSet.getRandom()
if random_num not in countMap:
countMap[random_num] = 0
countMap[random_num] += 1
print(countMap)
print("Random V1")
countMap = {}
for i in range(100):
random_num = selfSet.getRandomV2()
if random_num not in countMap:
countMap[random_num] = 0
countMap[random_num] += 1
print(countMap)



