博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
PytorchZerotoAll学习笔记(三)--自动求导
阅读量:5203 次
发布时间:2019-06-13

本文共 1440 字,大约阅读时间需要 4 分钟。

Pytorch给我们提供了自动求导的函数,不用再自己再推导计算梯度的公式了

虽然有了自动求导的函数,但是这里我想给大家浅析一下:深度学习中的一个很重要的反向传播

references:

我们先来看看什么是chain- rule(链式法则)

             Z是由 y经过f函数得到的,y又是x经过g函数得到 

              

                                                 ,    

                                                              

                     正向传播的方向是从左往右,那么反向传播的便是从右到左,梯度是一级级往回传递的

                    我们知道一般输出的时候都要经过一个激活函数,常用的是relu。当前的结果要往后传,

                    那么,这个时候便是函数的复合,一个套一个(俄罗斯套娃)(正向传播)

                    反向传播,就像剥洋葱,一层一层,你会发现它是没有心的.........哈哈哈哈哈哈

 

tips:当前层的梯度的计算需要后一层计算的梯度的结果

 

我们再来看看代码

import torchfrom torch.autograd import Variablex_data = [1.0, 2.0, 3.0]y_data = [2.0, 4.0, 6.0] #这里的w,我们是用tensor来生成了,不再是一个python的变量,调用torch.Tensor  需要计算梯度,所以 requires——grad设置为truew = Variable(torch.Tensor([1.0]),  requires_grad=True)  # Any random value# our model forward passdef forward(x):    return x * w# Loss functiondef loss(x, y):    y_pred = forward(x)    return (y_pred - y) * (y_pred - y)# Before trainingprint("predict (before training)",  4, forward(4).data[0])# Training loopfor epoch in range(10):    for x_val, y_val in zip(x_data, y_data):        l = loss(x_val, y_val) # l.backward()调用这个函数就可以让程序自动求梯度啦,是不是很神奇!        l.backward() # 获取梯度的数值 使用 .data直接调用其属性即可        print("\tgrad: ", x_val, y_val, w.grad.data[0])        w.data = w.data - 0.01 * w.grad.data        #手动清零 这里我们迭代10轮,所以下次计算之前都要清零当前的梯度值Manually zero the gradients after updating weights        w.grad.data.zero_()    print("progress:", epoch, l.data[0])# After trainingprint("predict (after training)",  4, forward(4).data[0]) 今天就讲到这里啦,see you next time!

  

 

转载于:https://www.cnblogs.com/liu-Deeplearning/p/10279923.html

你可能感兴趣的文章
房天下爬虫
查看>>
通过beego快速创建一个Restful风格API项目及API文档自动化(转)
查看>>
Web开发安全之文件上传安全
查看>>
mongodb常用查询语句
查看>>
JAVA-面向对象编程(上册)一、二章总结
查看>>
解决DataSnap支持的Tcp长连接数受限的两种方法
查看>>
Synchronous/Asynchronous:任务的同步异步,以及asynchronous callback异步回调
查看>>
ASP.NET MVC5 高级编程-学习日记-第二章 控制器
查看>>
在HTML中使用JavaScript(浏览器对js的加载机制分析)
查看>>
获取字符串中出现最多的字符 (HashMap()储存)
查看>>
如何选择适合自己的云管理平台(一)
查看>>
Hibernate中inverse="true"的理解
查看>>
不同版本(2.3,2.4,2.5,3.0)的Servlet web.xml 头信息
查看>>
Java的String中的subString()方法
查看>>
selenium +chrome headless Adhoc模式渲染网页
查看>>
高级滤波
查看>>
使用arcpy添加grb2数据到镶嵌数据集中
查看>>
[转载] MySQL的四种事务隔离级别
查看>>
QT文件读写
查看>>
数组去重的方法
查看>>