python __call__

  1. python中的__call__

在python中一切皆为对象,而对象可以分为可被调用不可被调用。可被调用的对象意思是可以在该对象后进行()操作。比如

class Student(object):
    def __init__(self, name):
        self.name = name
        print(self.name)
        
def fun():
    print("a function")
s = Student("name")
fun()

print(callable(Student))
print(callable(fun))
print(callable(s))

输出结果
xiao ming
a function
True
True
False

可以看到上例中函数fun和类Student后面都可以加一个括号,所以他们为可被调用对象,而类C的实例c却不能在其后加()操作。

想要一个类的实例成为可调用对象只需要在类中实现__call__函数即可。
修改上述Student类代码为:

class Student(object):
    def __init__(self, name):
        self.name = name
        print(self.name)
        
    def __call__(self):
        print(self.name + " in call function")

s = Student("xiao ming")
s()

print(callable(Student))
print(callable(s))

输出
xiao ming
xiao ming in call function
True
True

可以看到使用s()后,相当于调用了s.call() (注:call前后有两个下划线,不知道为什么这编辑器打不出来。。。)。并且实例s也成为了可调用对象。

  1. pytorch

pytorch中网络的基类torch.nn.Module类实现了__call__函数

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        # nn.Module子类的函数必须在构造函数中执行父类的构造函数
        # 下式等价于nn.Module.__init__(self)
        super(Net, self).__init__()
        
        # 卷积层 '1'表示输入图片为单通道, '6'表示输出通道数,'5'表示卷积核为5*5
        self.conv1 = nn.Conv2d(1, 6, 5) 
        # 卷积层
        self.conv2 = nn.Conv2d(6, 16, 5) 
        # 仿射层/全连接层,y = Wx + b
        self.fc1   = nn.Linear(16*5*5, 120) 
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x): 
        # 卷积 -> ** -> 池化 
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2) 
        # reshape,‘-1’表示自适应
        x = x.view(x.size()[0], -1) #x.size()[0]是batch的大小,所以x = [batch_size, 400]
        x = F.relu(self.fc1(x))  #x_size = batch*120
        x = F.relu(self.fc2(x))  #x_size = batch*84
        x = self.fc3(x)             #x_size = batch*10
        return x


net = Net()
input = t.randn(1, 1, 32, 32)
out = net(input)

上面的Net类继承了torch.nn.Module,后者为所有网络中的基类。Net类定义了两个成员函数,一个为构造函数,另一个为重写了Modual类的前向传播forward函数。forward函数把输入x输入网络,返回前向传播的计算结果。

类外定义了实例化了一个网络net,随机生成了一个4个维度的input来模仿输入图片,并把作为参数调用net(input)。注意net为网络的一个实例,并不是一个类,也不是一个函数。父类torch.nn.Module类中实现__call__函数时调用了forward函数。
python __call__

这样做方便了使用者,只需要

out = net(input)

即可运用forward函数计算出前向传播的结果。