pytorch模型剪枝学习笔记

pytorch代码仓库

pytorch在19年11月份的时候合入了这部分剪枝的代码。pytorch提供一些直接可用的api,用户只需要传入需要剪枝的module实例和需要剪枝的参数名字,系统自动帮助完成剪枝操作,看起来接口挺简单。比如 def random_structured(module, name, amount, dim)

pytorch支持的几种类型的剪枝策略:

pytorch模型剪枝学习笔记

详细分析

  • pytorch提供了一个剪枝的抽象基类‘‘class BasePruningMethod(ABC)’,所有剪枝策略都需要继承该基类,并重载部分函数就可以了

  • 一般情况下需要重载init和compute_mask方法,call, apply_mask, apply, prune和remove不需要重载,例如官方提供的RandomUnstructured剪枝方法 pytorch模型剪枝学习笔记

  • 基类实现的6个方法: pytorch模型剪枝学习笔记

  • 剪枝的API接口,可以看到支持用户自定义的剪枝mask,接口为custom_from_mask pytorch模型剪枝学习笔记

  • API的实现,使用classmethod的方法,剪枝策略的实例化在框架内部完成,不需要用户实例化

  • 剪枝的大只过程:

    1. 根据用户选择的剪枝API生成对应的策略实例,此时会判断需要做剪枝操作的module上是否已经挂有前向回调函数,没有则生成新的,有了就在老的上面添加,并且生成PruningContainer。从这里可以看出,对于同一个module使用多个剪枝策略时,pytorch通过PruningContainer来对剪枝策略进行管理。PruningContainer本身也是继承自BasePruningMethod。同时设置前向计算的回调,便于后续训练时调用。
    2. 接着根据用户输入的module和name,找到对应的参数tensor。如果是第一次剪枝,那么需要生成_orig结尾的tensor,然后删除原始的module上的tensor。如name为bias,那么生成bias_orig存起来,然后删除module.bias属性。
    3. 获取defaultmask,然后调用method.computemask生成当前策略的mask值。生成的mask会被存在特定的缓存module.register_buffer(name + "_mask", mask)。这里的compute_mask可能是两种情况:如果只有一个策略,那么调用的时候对应剪枝策略的compute_mask方法,如果一个module有多个剪枝策略组合,那么调用的应该是PruningContainer的compute_mask pytorch模型剪枝学习笔记
    4. 执行剪枝,保存剪枝结果到module的属性,注册训练时的剪枝回调函数,剪枝完成。新的mask应用在orig的tensor上面生成新的tensor保存的对应的name属性 pytorch模型剪枝学习笔记
  • remove接口 pytorch还提供各类一个remove接口,目的是把之前的剪枝结果持久化,具体操作就是删除之前生成的跟剪枝相关的缓存或者是回调hook接口,设置被剪枝的name参数(如bias)为最后一次训练的值。 pytorch模型剪枝学习笔记

  • 自己写一个剪枝策略接口也是可以的: pytorch模型剪枝学习笔记

    1. 先写一个剪枝策略类继承BasePruningMethod
    2. 然后重载基类的compute_mask方法,写自己的计算mask方法

官方完整教程在这里