(19)国家知识产权局
(12)发明 专利申请
(10)申请公布号
(43)申请公布日
(21)申请 号 202111592650.6
(22)申请日 2021.12.23
(71)申请人 北京百度网讯科技有限公司
地址 100085 北京市海淀区上地十街10号
百度大厦二层
(72)发明人 刘吉 余孙婕 窦德景 周吉文
(74)专利代理 机构 北京清亦华知识产权代理事
务所(普通 合伙) 11201
专利代理师 杜月
(51)Int.Cl.
G06N 20/00(2019.01)
G06F 21/62(2013.01)
G06N 3/04(2006.01)
G06N 3/08(2006.01)
G06K 9/62(2022.01)G06V 10/764(2022.01)
G06V 10/82(2022.01)
(54)发明名称
联邦学习模型的生成方法及其装置
(57)摘要
本公开提供了联邦学习模型的生成方法及
其装置, 涉及人工智能技术领域中的深度学习和
联邦学习技术领域。 具体实现方案为: 获取图片,
获取图片的分类结果, 以及根据图片和分类结果
对待训练的联邦学习模型进行训练, 以生成训练
好的联邦学习模型, 其中, 待训练的联邦学习模
型为待剪枝的联邦学习模型经过剪枝处理后得
到的, 在剪枝处理过程中待剪枝的联邦学习模型
中卷积层的剪枝率是根据模型精度自动调整的。
从而加速模型训练、 有效地减少资源占用, 生成
能够更好的适应资源有 限的边缘化使用场景的
模型, 且根据待剪枝的联邦学习模 型的精度自动
调整模型中卷积层的剪枝率, 无需人工选择参
数, 实现了自适应 剪枝。
权利要求书3页 说明书9页 附图6页
CN 114492831 A
2022.05.13
CN 114492831 A
1.一种联邦学习模型的生成方法, 包括:
获取图片;
获取所述图片的分类结果; 以及
根据所述图片和所述分类结果对待训练的联邦学习 模型进行训练, 以生成训练好的联
邦学习模型, 其中, 所述待训练的联邦学习模型为待剪枝的联邦学习模型经过剪枝处理后
得到的, 在剪枝处理过程中所述待剪枝的联邦学习模型中卷积层的剪枝率是根据模型精度
自动调整的。
2.根据权利要求1所述的生成方法, 还 包括:
获取客户端发送的模型 更新梯度;
根据所述模型 更新梯度更新所述待剪枝的联邦学习模型;
响应于当前轮为回滚轮, 则计算更新后的待剪枝的联邦学习模型的所述模型精度;
响应于所述模型精度低于最近一 次剪枝后的待剪枝的联邦学习 模型的模型精度, 则确
定最近一次剪枝不 合理;
响应于剪枝未完成, 则将所述更新后的待剪枝的联邦学习 模型回滚至最近一 次剪枝前
的待剪枝的联邦学习模型, 并降低最近一次剪枝对应的卷积层的剪枝率, 将所述最近一次
剪枝前的待剪枝的联邦学习模型发送至所述客户端, 以供所述客户端根据接收到的待剪枝
的联邦学习模型重新 生成所述模型 更新梯度;
响应于剪枝已完成, 则将所述更新后的待剪枝的联邦学习 模型确定为所述待训练的联
邦学习模型。
3.根据权利要求2所述的生成方法, 其中, 所述降低 最近一次剪枝对应的卷积层的剪枝
率, 包括:
将所述最近一次剪枝对应的卷积层的剪枝率降低一半。
4.根据权利要求2所述的生成方法, 还 包括:
响应于降低后的最近一 次剪枝对应的卷积层的剪枝率低于预设的剪枝率阈值, 则将所
述降低后的最近一次剪枝对应的卷积层的剪枝率确定为所述剪枝率阈值。
5.根据权利要求2所述的生成方法, 还 包括:
响应于所述模型精度等于或者高于最近一次剪枝后的待剪枝的联邦学习模型的模型
精度, 则确定最近一次剪枝合理, 并将所述更新后的待剪枝的联邦学习模型发送至所述客
户端, 以供 所述客户端根据接收到的待剪枝的联邦学习模型重新 生成所述模型 更新梯度。
6.根据权利要求2所述的生成方法, 还 包括:
响应于当前轮不为 回滚轮且当前轮不为剪枝轮, 则将所述更新后的待剪枝的联邦学习
模型发送至所述客户端, 以供所述客户端根据接收到的待剪枝的联邦学习模型重新生成所
述模型更新梯度。
7.根据权利要求2所述的生成方法, 还 包括:
响应于当前轮不为 回滚轮且当前轮为剪枝轮, 则根据当前轮对应的卷积层的剪枝率对
所述更新后的待剪枝的联邦学习模型进行剪枝, 并将剪枝后的待剪枝的联邦学习模型发送
至所述客户端, 以供所述客户端根据接收到的待剪枝的联邦学习模型重新生成所述模型更
新梯度。
8.一种联邦学习模型的生成装置, 包括:权 利 要 求 书 1/3 页
2
CN 114492831 A
2第一获取模块, 用于获取图片;
第二获取模块, 用于获取 所述图片的分类结果; 以及
训练模块, 用于根据所述图片和所述分类结果对待训练的联邦学习模型进行训练, 以
生成训练好的联邦学习模型, 其中, 所述待训练的联邦学习模型为待剪枝的联邦学习模型
经过剪枝处理后得到的, 在剪枝处理过程中所述待剪枝的联邦学习模型中卷积层的剪枝率
是根据模型精度自动调整的。
9.根据权利要求8所述的生成装置, 还 包括:
第三获取模块, 用于获取客户端发送的模型 更新梯度;
更新模块, 用于根据所述模型 更新梯度更新所述待剪枝的联邦学习模型;
计算模块, 用于响应于当前轮为回滚轮, 则计算更新后的待剪枝的联邦学习模型的所
述模型精度;
第一确定模块, 用于响应于所述模型精度低于最近一 次剪枝后的待剪枝的联邦学习 模
型的模型精度, 则确定最近一次剪枝不 合理;
回滚模块, 用于响应于剪枝未完成, 则将所述更新后的待剪枝的联邦学习模型回滚至
最近一次剪枝前 的待剪枝的联邦学习模型, 并降低最近一次剪枝对应的卷积层的剪枝率,
将所述最近一次剪枝 前的待剪枝的联邦学习模型发送至所述客户端, 以供所述客户端根据
接收到的待剪枝的联邦学习模型重新 生成所述模型 更新梯度;
第二确定模块, 用于响应于剪枝已完成, 则将所述更新后的待剪枝的联邦学习模型确
定为所述待训练的联邦学习模型。
10.根据权利要求9所述的生成装置, 其中, 所述回滚 模块, 包括:
降低单元, 用于将所述 最近一次剪枝对应的卷积层的剪枝率降低一半。
11.根据权利要求9所述的生成装置, 还 包括:
第三确定模块, 用于响应于降低后的最近一 次剪枝对应的卷积层的剪枝率低于预设的
剪枝率阈值, 则将所述降低后的最近一次剪枝对应的卷积层的剪枝率确定为所述剪枝率阈
值。
12.根据权利要求9所述的生成装置, 还 包括:
第四确定模块, 用于响应于所述模型精度等于或者高于最近一 次剪枝后的待剪枝的联
邦学习模型 的模型精度, 则确定最近一次剪枝合理, 并将所述更新后的待剪枝的联邦学习
模型发送至所述客户端, 以供所述客户端根据接收到的待剪枝的联邦学习模型重新生成所
述模型更新梯度。
13.根据权利要求9所述的生成装置, 还 包括:
发送模块, 用于响应于当前轮不为回滚轮且当前轮不为剪枝轮, 则将所述更新后的待
剪枝的联邦学习模型发送至所述客户端, 以供所述客户端根据接收到的待剪枝的联邦学习
模型重新 生成所述模型 更新梯度。
14.根据权利要求9所述的生成装置, 还 包括:
剪枝模块, 用于响应于当前轮不为回滚轮且当前轮为剪枝轮, 则根据当前轮对应的卷
积层的剪枝率对所述更新后的待剪枝的联邦学习模型进行剪枝, 并将剪枝后的待剪枝的联
邦学习模型发送至所述客户端, 以供所述客户端根据接收到的待剪枝的联邦学习模型重新
生成所述模型 更新梯度。权 利 要 求 书 2/3 页
3
CN 114492831 A
3
专利 联邦学习模型的生成方法及其装置
文档预览
中文文档
19 页
50 下载
1000 浏览
0 评论
309 收藏
3.0分
温馨提示:本文档共19页,可预览 3 页,如浏览全部内容或当前文档出现乱码,可开通会员下载原始文档
本文档由 人生无常 于 2024-03-18 23:18:11上传分享