pytorch模型剪枝学习笔记
pytorch在19年11月份的时候合入了这部分剪枝的代码。pytorch提供一些直接可用的api,用户只需要传入需要剪枝的module实例和需要剪枝的参数名字,系统自动帮助完成剪枝操作,看起来接口挺简单。比如 def random_structured(module, name, amount, dim)
pytorch支持的几种类型的剪枝策略:
详细分析
-
pytorch提供了一个剪枝的抽象基类‘‘class BasePruningMethod(ABC)’,所有剪枝策略都需要继承该基类,并重载部分函数就可以了
-
一般情况下需要重载init和compute_mask方法,call, apply_mask, apply, prune和remove不需要重载,例如官方提供的RandomUnstructured剪枝方法
-
基类实现的6个方法:
-
剪枝的API接口,可以看到支持用户自定义的剪枝mask,接口为custom_from_mask
-
API的实现,使用classmethod的方法,剪枝策略的实例化在框架内部完成,不需要用户实例化
-
剪枝的大只过程:
- 根据用户选择的剪枝API生成对应的策略实例,此时会判断需要做剪枝操作的module上是否已经挂有前向回调函数,没有则生成新的,有了就在老的上面添加,并且生成PruningContainer。从这里可以看出,对于同一个module使用多个剪枝策略时,pytorch通过PruningContainer来对剪枝策略进行管理。PruningContainer本身也是继承自BasePruningMethod。同时设置前向计算的回调,便于后续训练时调用。
- 接着根据用户输入的module和name,找到对应的参数tensor。如果是第一次剪枝,那么需要生成_orig结尾的tensor,然后删除原始的module上的tensor。如name为bias,那么生成bias_orig存起来,然后删除module.bias属性。
- 获取defaultmask,然后调用method.computemask生成当前策略的mask值。生成的mask会被存在特定的缓存module.register_buffer(name + "_mask", mask)。这里的compute_mask可能是两种情况:如果只有一个策略,那么调用的时候对应剪枝策略的compute_mask方法,如果一个module有多个剪枝策略组合,那么调用的应该是PruningContainer的compute_mask
- 执行剪枝,保存剪枝结果到module的属性,注册训练时的剪枝回调函数,剪枝完成。新的mask应用在orig的tensor上面生成新的tensor保存的对应的name属性
-
remove接口 pytorch还提供各类一个remove接口,目的是把之前的剪枝结果持久化,具体操作就是删除之前生成的跟剪枝相关的缓存或者是回调hook接口,设置被剪枝的name参数(如bias)为最后一次训练的值。
-
自己写一个剪枝策略接口也是可以的:
- 先写一个剪枝策略类继承BasePruningMethod
- 然后重载基类的compute_mask方法,写自己的计算mask方法