• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

pytorchregister_hook以和register_forward_hook

武飞扬头像
宇宙小菜菜
帮助2

何为叶子节点和非叶子节点

在理解register_hook之前,首先得搞懂什么叶子节点和非叶子节。简单来说叶子节点是有梯度且独立得张量,例如a = torch.tensor(2.0,requires_grad=True),b= torch.tensor(3.0,requires_grad=True),非叶子节点是依赖其他张量而得到得张量如c = a b
判断是叶子节点还是非叶子节点可以使用 is_leaf来判断一个张量是叶子节点还是非叶子节点。

import torch
a = torch.tensor(2.0,requires_grad=True)
b = torch.tensor(3.0,requires_grad=True)
print(a.is_leaf)
print(b.is_leaf)
c = a  b 
print(c.is_leaf)

>>> True
>>> True
>>> False

中间张量 c 作为非叶子节点是没有梯度信息得。pytorch默认在梯度反向传播过程中不会记录中间变量梯度信息。而且叶子节点的梯度信息在反向传播流过程中是不允许我们修改的。只能通过print(a.grad)查看张量的梯度信息。
那么,如果我们想查看中间变量 c 以及想改变叶子节点反向传播过程中的梯度值,应该怎么办呢。这时候就要使用register_hook这个钩子函数了。通过一下两段代码看一下钩子函数的主要作用。

register_hook

a = torch.tensor(2.0,requires_grad=True)
b = torch.tensor(3.0,requires_grad=True)
print(a.grad)
print(b.grad)
c = a*b
print(c.grad)  # 由于c是叶子节点,所以他是不记录梯度信息得。前后打印梯度信息都为None

d = torch.tensor(4.0,requires_grad=True)
e = c * d
e.backward()
print(a.grad)
print(b.grad)
print(c.grad)

>>>输出
None
None
None
tensor(12.)
tensor(8.)
None
学新通

通过上面代码可以看出,c作为中间变量在反向传播过程中不记录梯度信息。c=a*b其中a的梯度就为b的值,b的梯度就是a的值。接下来对中间变量c 使用register_hook,这个函数传入的参数得是一个函数。

import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b

def c_hook(grad):
    print("c_hook",grad)
    return grad   2    # 什么也不返回的话用的是和之前一样的梯度,不对其进行变化。

# 在c中,钩子按照有序字典的方式存储,按照存储的前后一次调用
c.register_hook(c_hook)
c.register_hook(lambda grad: print("hello my grad is",grad))
c.retain_grad()   # 存储中间变量的梯度

print(a.grad)
print(b.grad)
print(c.grad)

c.backward()

print(a.grad)
print(b.grad)
print(c.grad)

>>>
None
None
None
c_hook tensor(1.)
hello my grad is tensor(3.)
tensor(9.)
tensor(6.)
tensor(3.)
学新通

为什么输出会是这样的结果呢,一个张量可以注册多个钩子函数,反向传播过程中按照注册的顺序依次运行。 c.register_hook(c_hook) c.register_hook(lambda grad:)
,这两个函数可以重写c的梯度,第一个函数传入的参数是c的梯度,自身对自身的梯度pytorch中默认为1。所以此时c_hook中传入的grad=1,这个函数返回值为grad 2=3,此时会重写中间变量c的梯度信息。第二个钩子函数传入的函数为匿名函数,这个匿名函数对c的梯度没有进行重写,使用的还是上一个钩子函数重写的值,此使打印信息就为3。最后通过c.retain_grad()记c的梯度信息。通过这个例子,我稍微懂了点register_hook这个钩子函数的作用,是不是本来不可修改的梯度信息值,通过这个函数修改了呢。

通过一下这个例子比较再来看一下registe_hook函数的作用。

import torch
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b


def c_hook(grad):
    print("c_hook",grad)
    return grad   2    # 什么也不返回的话用的是和之前一样的梯度,不对其进行变化。

# 在c中,钩子按照有序字典的方式存储,按照存储的前后一次调用
c.register_hook(c_hook)
c.register_hook(lambda grad: print("hello my grad is",grad))
c.retain_grad()   # 存储中间变量的梯度

d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad   100)  # 将使用100 grad代替本来返回得梯度值

e = c * d

print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(e.grad)


# e.retain_grad()
e.register_hook(lambda grad: grad * 2)
e.retain_grad()

e.backward()

print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(e.grad)

>>>输出
None
None
None
None
None
c_hook tensor(8.)
hello my grad is tensor(10.)
tensor(30.)
tensor(20.)
tensor(10.)
tensor(112.)
tensor(2.)
学新通

register_forward_hook

register_forward_hook register_forward_pre_hook这个函数主要使用在nn.Module网络中。
第一个函数看名称是用在网络forward之前,第二个是运行在forward之后,举例:

import torch
import torch.nn as nn


class SumNet(nn.Module):
    def __init__(self):
        super(SumNet, self).__init__()

    @staticmethod
    def forward(a, b, c):
        d = a   b   c

        print('forward():')
        print('    a:', a)
        print('    b:', b)
        print('    c:', c)
        print()
        print('    d:', d)
        print()

        return d


def forward_pre_hook(module, input_positional_args):
    a, b, c = input_positional_args
    new_input_positional_args = a   10, b,c 10

    print('forward_pre_hook():')
    print('    module:', module)
    print('    input_positional_args:', input_positional_args)
    print()
    print('    new_input_positional_args:', new_input_positional_args)
    print()

    return new_input_positional_args


def forward_hook(module, input_positional_args, output):
    new_output = output   100

    print('forward_hook():')
    print('    module:', module)
    print('    input_positional_args:', input_positional_args)
    print('    output:', output)
    print()
    print('    new_output:', new_output)
    print()

    return new_output


def main():
    sum_net = SumNet()
    sum_net.register_forward_pre_hook(forward_pre_hook)
    sum_net.register_forward_hook(forward_hook)

    a = torch.tensor(1.0, requires_grad=True)
    b = torch.tensor(2.0, requires_grad=True)
    c = torch.tensor(3.0, requires_grad=True)

    print('start')
    print()
    print('a:', a)
    print('b:', b)
    print('c:', c)
    print()
    
    print('before model')
    print()

    d = sum_net(a, b, c)   # 前向传播得时候钩子函数起作用了,先是forward_pre_hook,接下来是forward,接下来是forward_hook函数。

    print('after model')
    print()
    print('d:', d)


if __name__ == '__main__':
    main()
学新通

输出信息:

start

a: tensor(1., requires_grad=True)
b: tensor(2., requires_grad=True)
c: tensor(3., requires_grad=True)

before model

forward_pre_hook():
    module: SumNet()
    input_positional_args: (tensor(1., requires_grad=True), tensor(2., requires_grad=True), tensor(3., requires_grad=True))

    new_input_positional_args: (tensor(11., grad_fn=<AddBackward0>), tensor(2., requires_grad=True), tensor(13., grad_fn=<AddBackward0>))

forward():
    a: tensor(11., grad_fn=<AddBackward0>)
    b: tensor(2., requires_grad=True)
    c: tensor(13., grad_fn=<AddBackward0>)

    d: tensor(26., grad_fn=<AddBackward0>)

forward_hook():
    module: SumNet()
    input_positional_args: (tensor(11., grad_fn=<AddBackward0>), tensor(2., requires_grad=True), tensor(13., grad_fn=<AddBackward0>))
    output: tensor(26., grad_fn=<AddBackward0>)

    new_output: tensor(126., grad_fn=<AddBackward0>)

after model

d: tensor(126., grad_fn=<AddBackward0>)
学新通

分析以上为什么会输出这样的结果,前面提到register_forward_hook这个函数会在网络前向传播前运行,需要两个参数modul 和 input案例中输入为 tensor 1 2 3,经过这个函数给2 3 分别加了10,并且返回了一组新的值,这组值是要传入forward中,可以看出,forward函数打印的a b c 为传入的这组新值,而不是刚开始定义的1 2 3,forward函数运行过程中返回每层的输出会运行forward_hook函数。这个函数主要需要三个参数,module input output
以下从Lenet网络来使用这个函数:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)     
        out = F.max_pool2d(out, 2)      
        
        out = self.conv2(out)
        out = F.relu(out)  
        out = F.max_pool2d(out, 2)
        
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out
model = LeNet()

# 分别对model的第一个卷积层和最后一层使用了钩子函数,这样既可以取出对应层的输出。
def hook(model,input_,output):
    print("最后一层输出:",output.shape)

def conv_hook(model,input_,output):
    print("conv1后",input_[0].shape,output.shape)

model.register_forward_hook(hook)
model.conv1.register_forward_hook(conv_hook)


img = torch.randn([1,3,32,32])
out_put = model(img)

>>>
conv1后 torch.Size([1, 3, 32, 32]) torch.Size([1, 6, 28, 28])
最后一层输出: torch.Size([1, 10])
学新通

基于上可以看出给不同层使用钩子函数,可以提取出每一层的输出,并进行相应的处理。

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhiahebb
系列文章
更多 icon
同类精品
更多 icon
继续加载