栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

基于 Pytorch 实现 Federated Learning 中的安全聚合(基于模型参数)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

基于 Pytorch 实现 Federated Learning 中的安全聚合(基于模型参数)

基于 Pytorch 实现 Federated Learning 中的安全聚合(基于模型参数)

最近看了一些关于 FL 的安全聚合的文章,也找了一些代码,但是发现他们都有一些共同点——全是基于 FedSGD 的(原版基于FedSGD 的 github :https://github.com/shanxuanchen/attacking_federate_learning)。但是现在用 FedSGD 的太少了,收敛速度还慢。因此我修改了两个比较经典的安全聚合算法:krum 和 trimmed_median 去适应 FedAVG。
话不多说,直接上代码:

Krum:

def krum(w, args):# csdn 第二姿态,
    distances = defaultdict(dict)
    non_malicious_count = int((args.num_users - args.atk_num) * args.frac)
    num = 0
    for k in w[0].keys():
        if num == 0:
            for i in range(len(w)):
                for j in range(i):
                    distances[i][j] = distances[j][i] = np.linalg.norm(w[i][k].cpu().numpy() - w[j][k].cpu().numpy())
            num = 1
        else:
            for i in range(len(w)):
                for j in range(i):
                    distances[j][i] += np.linalg.norm(w[i][k].cpu().numpy() - w[j][k].cpu().numpy())
                    distances[i][j] += distances[j][i]
    minimal_error = 1e20
    for user in distances.keys():
        errors = sorted(distances[user].values())
        current_error = sum(errors[:non_malicious_count])
        if current_error < minimal_error:
            minimal_error = current_error
            minimal_error_index = user
    return w[minimal_error_index]

Trimmed_median:

def trimmed_mean(w, args): # csdn 第二姿态,
    number_to_consider = int((args.num_users - args.atk_num) * args.frac) - 1
    print(number_to_consider)
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        tmp = []
        for i in range(len(w)):
            tmp.append(w[i][k].cpu().numpy()) # get the weight of k-layer which in each client
        tmp = np.array(tmp)
        med = np.median(tmp,axis=0)
        new_tmp = []
        for i in range(len(tmp)):# cal each client weights - median
            new_tmp.append(tmp[i]-med)
        new_tmp = np.array(new_tmp)
        good_vals = np.argsort(abs(new_tmp),axis=0)[:number_to_consider]
        good_vals = np.take_along_axis(new_tmp, good_vals, axis=0)
        k_weight = np.array(np.mean(good_vals) + med)
        w_avg[k] = torch.from_numpy(k_weight).to(args.device)
    return w_avg

如果有不明白的参数可以继续在评论区交流!!!

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/738296.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号