当前位置: 首页 > news >正文

dynamic_rnn转nn.GRU详细记录

(原文发表在知乎专栏上,时间为2020年8月13日)

今天在将一份tensorflow的代码转为pytorch时遇到的一点困难,经过多次debug以后终于弄清楚了这里应该是如何进行转换的,因此记录下来。

直接上代码吧,为了确保最终的结果是一致的,这里我将网络层的权重全部初始化为0。

import torch
import torch.nn as nn
import numpy as np
import tensorflow as tf
from tensorflow.keras import initializersinput = np.random.rand(3, 1, 5)
hidden = np.random.rand(3, 5)print("input: ", input.shape)
print(input)
print("hidden: ", hidden.shape)
print(hidden)print("="*20, ' tensorflow result ', "="*20)
# cell with zeros initializer
cell = tf.compat.v1.nn.rnn_cell.GRUCell(5, kernel_initializer=initializers.Zeros(), bias_initializer=initializers.Zeros())
tf_output, tf_state = tf.compat.v1.nn.dynamic_rnn(cell, input, initial_state=hidden)
print(tf_output)        # (batch size, time steps, features)
print(tf_state)         # (batch size, features) for the final time steps
print('\n')print("="*20, ' rnn cell result ', "="*20)
# rnn cell
pytorch_rnn_cell = nn.GRUCell(5, 5)
for k, v in pytorch_rnn_cell.state_dict().items():torch.nn.init.constant_(v, 0)
pytorch_input_cell = torch.from_numpy(input).permute(1, 0, 2).float()   # (time steps, batch size, features)
pytorch_hidden_cell = torch.from_numpy(hidden).float()                  # (batch size, features)
pytorch_output_cell = []
for i in range(1):pytorch_hidden_cell = pytorch_rnn_cell(pytorch_input_cell[i], pytorch_hidden_cell)pytorch_output_cell.append(pytorch_hidden_cell)
print(pytorch_output_cell)
print('\n')print("="*20, ' rnn result ', "="*20)
# rnn
pytorch_rnn = nn.GRU(5, 5)
for k, v in pytorch_rnn.state_dict().items():torch.nn.init.constant_(v, 0)
pytorch_input = torch.from_numpy(input).permute(1, 0, 2).float()        # (time steps, batch size, feature size)
pytorch_hidden = torch.from_numpy(hidden).unsqueeze(0).float()          # (time steps, batch size, hidden size)
pytorch_output, pytorch_state = pytorch_rnn(pytorch_input, pytorch_hidden)
print(pytorch_output, pytorch_output.shape)
print(pytorch_state, pytorch_state.shape)

最后的结果如下

input:  (3, 1, 5)
[[[0.98175333 0.59281082 0.47678967 0.70612923 0.73616147]][[0.8363702  0.85099391 0.75740424 0.30633335 0.20097122]][[0.60316062 0.21921029 0.16052985 0.25654177 0.40698399]]]
hidden:  (3, 5)
[[0.46976021 0.19681885 0.59240364 0.79540728 0.27608136][0.39461795 0.29340918 0.4515729  0.6921841  0.44068605][0.89315058 0.72514622 0.2925488  0.45433305 0.59910906]]
====================  tensorflow result  ====================
tf.Tensor(
[[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068]][[0.19730898 0.14670459 0.22578645 0.34609205 0.22034303]][[0.44657529 0.36257311 0.1462744  0.22716653 0.29955453]]], shape=(3, 1, 5), dtype=float64)
tf.Tensor(
[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068][0.19730898 0.14670459 0.22578645 0.34609205 0.22034303][0.44657529 0.36257311 0.1462744  0.22716653 0.29955453]], shape=(3, 5), dtype=float64)====================  rnn cell result  ====================
[tensor([[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]], grad_fn=<AddBackward0>)]====================  rnn result  ====================
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])Process finished with exit code 0
http://www.rkmt.cn/news/55283.html

相关文章:

  • 2025 最新推荐海外仓服务平台榜单:覆盖欧美东南亚等核心市场,美国 / 英国 / 德国 / 法国海外仓/换标 / 维修 / 检测优质服务商权威测评
  • Agent Dart证书验证漏洞深度解析
  • 2025年北京集团法律顾问服务权威推荐榜单:私人法律顾问/高级法律顾问/社区法律顾问服务精选
  • 2025年合肥外呼系哪家好--外呼系统推荐
  • 2025年四川搭建网站维护服务权威推荐:四川网站搭建平台/四川企业网站开发/四川企业官网搭建公司源头机构精选
  • 《浙商》杂志|协作方能共赢,湘湖论剑网易专场对接会描绘AI人机共生新蓝图
  • GESP C++ 二级真题 (2025.09) 知识点精讲
  • ESP32 C3使用ESP32-BLE-Keyboard的问题每次都要添加才能使用
  • MLGO微算法科技时空卷积与双重注意机制驱动的脑信号多任务分类算法
  • 2025耐高锰酸钠富辛环氧涂料加工厂综合评估:高性能厚浆环氧涂料涂料/乙烯基防水防腐涂/乙烯基玻璃鳞片涂料专业供应商推荐
  • 2025长沙考公面试机构测评:这5家实力最强,比较好的长沙考公面试口碑排行优选品牌推荐与解析
  • 小白也能看懂的RLHF:基础篇 - AI
  • 数据结构-线段树
  • 第十一章 泛型算法
  • 实用指南:链表-双向链表【node3】
  • 2025年复合涤纶布优质厂家权威推荐榜单:涂层涤纶布/阻燃涤纶布/防水涤纶布源头厂家精选
  • List相关知识点
  • 【山东省物联网协会主办,IEEE出版】2025年智慧物联与电子信息工程国际学术会议(IoTEIE 2025)
  • vxe-table 如何实现拖拽行数据排序,并对拖拽后进行提示框二次确认是否允许拖拽
  • SOLID原则在React中的应用实践
  • 绘图工具
  • 2025 年 11 月离心机厂家推荐排行榜,台式低速大容量离心机,血液离心机,台式低速离心机,台式指针式离心机,台式离心机,小高速离心机,低速微电脑控制离心机,六乘五十毫升离心机,高速离心机公司推荐
  • 深入解析:BERT,GPT,ELMO模型对比
  • 2025年颗粒活性炭订制厂家权威推荐榜单:活性炭过滤/煤质活性炭/粉末活性炭源头厂家精选
  • 已有ERP和MES,为什么还需要质量管理系统(QMS)?
  • SBD3D60V1H-ASEMI可直接替代安世PMEG6010CEJ
  • 重庆一对一辅导机构精选推荐,2025合规家教机构口碑排名已公布,附师资实力测评
  • 2025 年 11 月开关柜供应厂家推荐排行榜,高压开关柜,低压开关柜,配电开关柜,智能开关柜公司推荐
  • 重庆一对一家教机构口碑推荐,2025辅导机构最新排名出炉,带详细选课攻略
  • 2025 年 11 月轴承厂家推荐排行榜,瓦房店轴承,深沟球轴承,调心滚子轴承,圆锥滚子轴承源头厂家实力解析与选购指南