当前位置: 首页 > news >正文

用循环神经网络生成0^n 1^n形式的简单序列

题目详细如下:

微信图片_20251211231041_46_13

源代码
import torch
import torch.nn as nn
import torch.optim as optim
import random#数据准备与预处理
def generate_data(num_samples, max_n=5):data = []for _ in range(num_samples):n = random.randint(1, max_n)seq = '0' * n + '1' * n data.append(seq)return datadef encode_seq(seq, max_len):encoded = [0 if c == '0' else 1 for c in seq]padded = encoded + [2] * (max_len - len(encoded))  #2为填充符return torch.tensor(padded, dtype=torch.long)max_n = 5
max_seq_len = 2 * max_n
train_data = generate_data(1000, max_n)#定义LSTM模型
class SeqGenerate(nn.Module):def __init__(self, vocab_size=3, embed_dim=8, hidden_dim=16):super().__init__()self.embedding = nn.Embedding(vocab_size, embed_dim)self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)self.fc = nn.Linear(hidden_dim, vocab_size)def forward(self, x, hidden=None):# x形状:(batch_size, seq_len)x = self.embedding(x)x, hidden = self.lstm(x, hidden)logits = self.fc(x)return logits, hidden#训练模型
model = SeqGenerate()
criterion = nn.CrossEntropyLoss(ignore_index=2)
optimizer = optim.Adam(model.parameters(), lr=0.001)
#训练循环
epochs = 50
batch_size = 32
for epoch in range(epochs):total_loss = 0random.shuffle(train_data)for i in range(0, len(train_data), batch_size):batch = train_data[i: (i+batch_size)]inputs = [encode_seq(seq[:-1], max_seq_len-1) for seq in batch]labels = [encode_seq(seq[1:], max_seq_len-1) for seq in batch]inputs = torch.stack(inputs)labels = torch.stack(labels)optimizer.zero_grad()logits, _ = model(inputs)loss = criterion(logits.reshape(-1, 3), labels.reshape(-1))loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f"Epoch {epoch+1}, Loss:{total_loss / len(train_data):.4f}")#生成 0n 1n序列
def generate_seq(model, start_char='0', target_n=3):model.eval()seq = [start_char]hidden = Nonefor _ in range(target_n - 1):input_tensor = torch.tensor([[0]], dtype=torch.long)logits, hidden = model(input_tensor, hidden)next_char = torch.argmax(logits, dim=-1).item()seq.append('0' if next_char == 0 else '1')for _ in range(target_n):nput_tensor = torch.tensor([[1]], dtype=torch.long)logits, hidden = model(input_tensor, hidden)next_char = torch.argmax(logits, dim=-1).item()seq.append('0' if next_char == 0 else '1')return ''.join(seq)generate_seq = generate_seq(model, target_n=3)
print(f"生成的0^n 1^n序列:{generate_seq}")

输出结果:

屏幕截图 2025-12-11 230755

http://www.rkmt.cn/news/83704.html

相关文章:

  • AcWing 846:树的重心 ← 链式前向星 or 邻接表
  • 251211
  • Python自然语言处理的未来:技术栈与开发范式
  • 观察者模式
  • 2025年东莞优质的铝门窗批发选哪家,安全门窗/铝门窗/慕莎尼奥门窗/窗纱一体铝门窗/门窗/铝门窗品牌选哪家 - 品牌推荐师
  • 2025.12.11总结
  • 124_尚硅谷_闭包的基本介绍
  • One Year XTOOL D9S Update Service: Keep Diagnostics Up-to-Date for EU US Vehicles
  • 2025年数控车床品牌新格局,机械手集成能力排行揭晓,动力刀塔数控车/牙科配件数控车床/新能源数控车床/军工配件数控机床数控车床设计怎么选择 - 品牌推荐师
  • 如何确定arm固件的加载地址
  • 2025年国内靠谱的门窗源头厂家推荐,全屋门窗/环保门窗/复古门窗/极简门窗/欧式门窗/智能门窗/门窗直销厂家找哪家 - 品牌推荐师
  • 基于协同过滤推荐算法的求职招聘推荐系统u1ydn3f4(程序、源码、数据库、调试部署优秀的方案及开发环境)系统界面展示及获取方式置于文档末尾,可供参考。
  • 12.11笔记
  • 中国人工智能学会推荐国际学术会议和国际/国内期刊目录
  • 蓝桥杯-Python-题目整理2
  • 喵喵喵 XI
  • 深度学习方法在语音识别中的全面解析
  • 详解Adobe Experience Manager存储型XSS漏洞CVE-2025-64829
  • 中国自动化学会推荐学术会议、科技期刊目录(2024)发布
  • 国内直连?API源头供应?深度实测GrsAI的Sora2接口0.08/条视频它真的靠谱吗?
  • 在 Steam Deck 上開啓用戶級別的 SMB
  • 如何在 Steam Deck 上備份截圖
  • 【AI】前置篇 Ai Agent的全貌概览
  • 考陪诊师为什么选北京守嘉陪诊报名? - 品牌排行榜单
  • 【torch】torch.cat和直接相加的区别
  • 《综合项目实战-局域网内的沟通软件》
  • Java基础补缺2
  • Ai元人文:对岐金兰观察的深度回应——价值协商与数值优化的范式调和
  • 12/11
  • 深入解析:[特殊字符] 在 Windows 上设置 SQLite