尧图网站建设 尧图网络
  • 首页
  • 关于我们
  • 服务项目
  • 案例展示
  • 建站流程
  • 资讯中心
  • 联系我们
首页/资讯中心/详情

dynamic_rnn转nn.GRU详细记录

dynamic_rnn转nn.GRU详细记录
📅 发布时间:2026/6/21 18:56:22

(原文发表在知乎专栏上,时间为2020年8月13日)

今天在将一份tensorflow的代码转为pytorch时遇到的一点困难,经过多次debug以后终于弄清楚了这里应该是如何进行转换的,因此记录下来。

直接上代码吧,为了确保最终的结果是一致的,这里我将网络层的权重全部初始化为0。

import torch
import torch.nn as nn
import numpy as np
import tensorflow as tf
from tensorflow.keras import initializersinput = np.random.rand(3, 1, 5)
hidden = np.random.rand(3, 5)print("input: ", input.shape)
print(input)
print("hidden: ", hidden.shape)
print(hidden)print("="*20, ' tensorflow result ', "="*20)
# cell with zeros initializer
cell = tf.compat.v1.nn.rnn_cell.GRUCell(5, kernel_initializer=initializers.Zeros(), bias_initializer=initializers.Zeros())
tf_output, tf_state = tf.compat.v1.nn.dynamic_rnn(cell, input, initial_state=hidden)
print(tf_output)        # (batch size, time steps, features)
print(tf_state)         # (batch size, features) for the final time steps
print('\n')print("="*20, ' rnn cell result ', "="*20)
# rnn cell
pytorch_rnn_cell = nn.GRUCell(5, 5)
for k, v in pytorch_rnn_cell.state_dict().items():torch.nn.init.constant_(v, 0)
pytorch_input_cell = torch.from_numpy(input).permute(1, 0, 2).float()   # (time steps, batch size, features)
pytorch_hidden_cell = torch.from_numpy(hidden).float()                  # (batch size, features)
pytorch_output_cell = []
for i in range(1):pytorch_hidden_cell = pytorch_rnn_cell(pytorch_input_cell[i], pytorch_hidden_cell)pytorch_output_cell.append(pytorch_hidden_cell)
print(pytorch_output_cell)
print('\n')print("="*20, ' rnn result ', "="*20)
# rnn
pytorch_rnn = nn.GRU(5, 5)
for k, v in pytorch_rnn.state_dict().items():torch.nn.init.constant_(v, 0)
pytorch_input = torch.from_numpy(input).permute(1, 0, 2).float()        # (time steps, batch size, feature size)
pytorch_hidden = torch.from_numpy(hidden).unsqueeze(0).float()          # (time steps, batch size, hidden size)
pytorch_output, pytorch_state = pytorch_rnn(pytorch_input, pytorch_hidden)
print(pytorch_output, pytorch_output.shape)
print(pytorch_state, pytorch_state.shape)

最后的结果如下

input:  (3, 1, 5)
[[[0.98175333 0.59281082 0.47678967 0.70612923 0.73616147]][[0.8363702  0.85099391 0.75740424 0.30633335 0.20097122]][[0.60316062 0.21921029 0.16052985 0.25654177 0.40698399]]]
hidden:  (3, 5)
[[0.46976021 0.19681885 0.59240364 0.79540728 0.27608136][0.39461795 0.29340918 0.4515729  0.6921841  0.44068605][0.89315058 0.72514622 0.2925488  0.45433305 0.59910906]]
====================  tensorflow result  ====================
tf.Tensor(
[[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068]][[0.19730898 0.14670459 0.22578645 0.34609205 0.22034303]][[0.44657529 0.36257311 0.1462744  0.22716653 0.29955453]]], shape=(3, 1, 5), dtype=float64)
tf.Tensor(
[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068][0.19730898 0.14670459 0.22578645 0.34609205 0.22034303][0.44657529 0.36257311 0.1462744  0.22716653 0.29955453]], shape=(3, 5), dtype=float64)====================  rnn cell result  ====================
[tensor([[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]], grad_fn=<AddBackward0>)]====================  rnn result  ====================
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])Process finished with exit code 0

相关新闻

  • 2025 最新推荐海外仓服务平台榜单:覆盖欧美东南亚等核心市场,美国 / 英国 / 德国 / 法国海外仓/换标 / 维修 / 检测优质服务商权威测评
  • Agent Dart证书验证漏洞深度解析
  • 2025年北京集团法律顾问服务权威推荐榜单:私人法律顾问/高级法律顾问/社区法律顾问服务精选

最新新闻

  • i.MX 6SoloX引脚配置全解析:从BGA封装到PCB设计实战
  • VisualCppRedist AIO:终极VC++运行库一站式解决方案完全指南
  • 小众纯粮白酒推荐排行:2026纯粮好酒榜单,喝出地道粮食香 - 速递信息
  • 2026 抖音电商密文面单合规指南:一件代发下单、发货售后与违规检测全套实操 - 速递信息
  • 第16章:Ollama服务化架构——从本地工具到团队服务
  • 电动车托运1000公里多少钱?2026最新价格与省钱攻略 - 快递物流资讯

日新闻

  • Visual C++运行库修复终极指南:5分钟快速解决Windows软件启动错误
  • 手把手教你构建统计局地区经济数据爬虫:从环境搭建到数据持久化全指南
  • 2026多Agent深度解析:用AI团队替代单一模型,四种架构实战落地

周新闻

  • Visual C++运行库修复终极指南:5分钟快速解决Windows软件启动错误
  • 手把手教你构建统计局地区经济数据爬虫:从环境搭建到数据持久化全指南
  • 2026多Agent深度解析:用AI团队替代单一模型,四种架构实战落地

月新闻

  • 【总结】入门篇:50句话让你记住架构核心概念
  • WeChatMsg技术方案解析:实现Mac微信数据自主管理的完整解决方案
  • WeChatMsg:革新性微信数据备份方案,打造你的专属数字记忆库

关于尧图

  • 公司简介
  • 团队介绍
  • 企业文化
  • 荣誉资质

服务项目

  • 定制开发
  • 电商建站
  • UI 设计
  • 运维服务

快速链接

  • 案例展示
  • 建站流程
  • 常见问题
  • 资讯中心

联系方式

  • 📍北京市朝阳区互联网产业园 A 座 10 层
  • 📞400-888-8888
  • ✉️contact@rkmt.cn
  • 🕐周一至周日 9:00-21:00

© 2024 北京尧图网络科技有限公司 版权所有 | 京 ICP 备 XXXXXXXX 号