LSTM文本预测(Pytorch版)

任务:基于 flare 文本数据,建立 LSTM 模型,预测序列文字
1.完成数据预处理,将文字序列数据转化为可用于LSTM输入的数据
2.查看文字数据预处理后的数据结构,并进行数据分离操作
3.针对字符串输入(" flare is a teacher in ai industry. He obtained his phd in Australia."),预测其对应的后续字符
参考视频:吹爆!3小时搞懂!【RNN循环神经网络+时间序列LSTM深度学习模型】学不会UP主下跪!
部分参数与视频不同

pre.py

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from model import LSTM

# 加载数据
data = open('flare').read()
# 移除换行符
data = data.replace('\n','').replace('\r','')
# print(data)
# 字符去重
letters = list(set(data))
num_letters = len(letters)
# print(letters)
# print(len(letters))

# 建立字典
int_to_char = {a:b for a,b in enumerate(letters)}
# print(int_to_char)
char_to_int = {b:a for a,b in enumerate(letters)}
# print(char_to_int)
time_step = 10

# 滑动窗口提取数据
def extract_data(data, slide):
  x = []
  y = []
  for i in range(len(data) - slide):
    x.append([a for a in data[i : i + slide]])
    y.append(data[i+slide])
  return x,y

# 字符到数字的批量转化
def char_to_int_Data(x, y, chat_to_int):
  x_to_int = []
  y_to_int = []
  for i in range(len(x)):
    x_to_int.append([char_to_int[char] for char in x[i]])
    y_to_int.append([char_to_int[char] for char in y[i]])  
  return x_to_int, y_to_int

# 实现输入字符文章的批量处理,输入整个字符,滑动窗口大小,转化字典
def data_preprocessing(data, slide, num_letters, char_to_int):
  char_Data = extract_data(data, slide)  
  int_Data = char_to_int_Data(char_Data[0], char_Data[1], char_to_int)  
  Input = int_Data[0]
  Output = list(np.array(int_Data[1]).flatten())
  Input_RESHAPED = np.array(Input).reshape(len(Input), slide)
  new = np.random.randint(0, 10, size=[Input_RESHAPED.shape[0], Input_RESHAPED.shape[1], num_letters])  
  for i in range(Input_RESHAPED.shape[0]):
    for j in range(Input_RESHAPED.shape[1]):
      new[i, j, :] = torch.nn.functional.one_hot(torch.tensor(Input_RESHAPED[i, j], dtype=torch.long), num_classes = num_letters)  
  return new, Output
x,y = data_preprocessing(data, time_step, num_letters, char_to_int)
# print(y)

from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=10)
# print(x_train.shape, len(y_train))
y_train_category = torch.nn.functional.one_hot(torch.tensor(y_train, dtype=torch.long), num_letters)
# print(y_train_category)

# 将数据转换为 PyTorch 的 Tensor
x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
x_test_tensor = torch.tensor(x_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

model.py

import torch
from torch import nn

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers, dropout_prob=0.2):
        super(LSTM, self).__init__()
        
        # 定义LSTM层
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout_prob)
        
        # 定义Dropout层
        self.dropout = nn.Dropout(dropout_prob)  # Dropout层,用于在全连接层前丢弃部分神经元
        
        # 定义全连接层
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        # LSTM输出
        out, _ = self.lstm(x)
        
        # LSTM输出的最后一个时间步
        out = out[:, -1, :]
        
        # Dropout层
        out = self.dropout(out)
        
        # 全连接层输出
        out = self.fc(out)
        
        return out

train.py

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from model import LSTM
from pre import *

# 定义模型参数
input_size = num_letters  # 输入大小等于字母集的大小
hidden_size = 256         # 隐藏层大小
output_size = num_letters # 输出大小(预测下一个字符)
num_layers = 2            # LSTM层数

# 实例化模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTM(input_size, hidden_size, output_size, num_layers).to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss(reduction = 'mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 创建 DataLoader
train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 训练模型
num_epochs = 10
best_accuracy = 0.0  # 用于保存最好的模型
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
    
    # 你可以在每个 epoch 后验证模型并保存最佳模型
    model.eval()
    with torch.no_grad():
        x_test_tensor = torch.tensor(x_test, dtype=torch.float32).to(device)
        y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)
        
        outputs = model(x_test_tensor)
        _, predicted = torch.max(outputs, dim=1)
        correct = (predicted == y_test_tensor).sum().item()
        accuracy = correct / y_test_tensor.size(0)

        print(f'Epoch [{epoch+1}/{num_epochs}], Test Accuracy: {accuracy * 100:.2f}%')

        # 如果模型的准确率提升了,则保存模型
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), 'best_lstm_model.pth')
            print("Model saved!")

# 最后保存最终模型
torch.save(model.state_dict(), 'final_lstm_model.pth')

# # 测试模型
# model.eval()
# with torch.no_grad():
#     x_test_tensor = torch.tensor(x_test, dtype=torch.float32).to(device)  # 确保测试数据在设备上
#     y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)     # 确保测试标签在设备上
    
#     # 前向传播
#     outputs = model(x_test_tensor)
#     _, predicted = torch.max(outputs, dim=1)  # 获取预测类别的索引
    
#     # 计算准确率
#     correct = (predicted == y_test_tensor).sum().item()
#     accuracy = correct / y_test_tensor.size(0)
#     print(f'Test Accuracy: {accuracy * 100:.2f}%')

test.py

import torch
from model import LSTM
from pre import *  # 确保 'pre' 模块中包含了数据处理的相关代码
from sklearn.metrics import accuracy_score

# 定义设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 定义模型参数(与训练时的参数一致)
input_size = num_letters  # 输入大小等于字母表的大小
hidden_size = 256         # 隐藏层大小
output_size = num_letters # 输出大小(预测下一个字符)
num_layers = 2           # LSTM层数

# 实例化模型并加载训练好的参数
model = LSTM(input_size, hidden_size, output_size, num_layers).to(device)
model.load_state_dict(torch.load('best_lstm_model.pth'))  # 加载你保存的最佳模型
model.eval()  # 设置为评估模式

# 需要预测的新的字符串
new_string = "flare is a teacher in ai industry. He obtained his phd in Australia."

# 预处理输入数据:将新字符串转换为适合模型输入的张量形式
X_new, y_new = data_preprocessing(new_string, time_step, num_letters, char_to_int)  # 使用相同的预处理函数
X_new_tensor = torch.tensor(X_new, dtype=torch.float32).to(device)
y_new_tensor = torch.tensor(y_new, dtype=torch.long).to(device)  # 实际的标签

# 进行预测
with torch.no_grad():
    # 前向传播,获取模型的输出
    outputs = model(X_new_tensor)
    _, predicted_indices = torch.max(outputs, dim=1)  # 获取每个时间步的预测类别

# 将预测的索引转换回字符
predicted_chars = [int_to_char[idx.item()] for idx in predicted_indices]

# 将真实的标签转换回字符
true_chars = [int_to_char[idx] for idx in y_new]

# 计算准确率
correct_predictions = (predicted_indices == y_new_tensor).sum().item()
total_predictions = len(y_new_tensor)
accuracy = correct_predictions / total_predictions

# 打印预测结果与准确率
print(f"Accuracy on new string: {accuracy * 100:.2f}%")

# 打印详细的预测信息
for i in range(len(new_string) - time_step):
    print(f"Context: {new_string[i:i + time_step]} --> Predicted: {predicted_chars[i]}, Actual: {true_chars[i]}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/881405.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

opencv图像透视处理

引言 在图像处理与计算机视觉领域,透视变换(Perspective Transformation)是一种重要的图像校正技术,它允许我们根据图像中已知的四个点(通常是矩形的四个角)和目标位置的四个点,将图像从一个视…

软件安装攻略:EmEditor编辑器下载安装与使用

EmEditor是一款在Windows平台上运行的文字编辑程序。EmEditor以运作轻巧、敏捷而又功能强大、丰富著称,得到许多用户的好评。Windows内建的记事本程式由于功能太过单薄,所以有不少用户直接以EmEditor取代,emeditor是一个跨平台的文本编辑器&a…

聊城网站建设:企业如何打造高效官网

聊城网站建设:企业如何打造高效官网 在互联网飞速发展的今天,官方网站已成为企业展示形象、推广产品、与客户沟通的重要平台。尤其对于聊城地区的企业来说,建立一个高效的官网显得尤为重要。本文将分享一些关键步骤,帮助企业打造一…

MapReduce基本原理

目录 整体执行流程​ Map端执行流程 Reduce端执行流程 Shuffle执行流程 整体执行流程 八部曲 读取数据--> 定义map --> 分区 --> 排序 --> 规约 --> 分组 --> 定义reduce --> 输出数据 首先将文件进行切片(block)处理&#xff…

人工智能快速发展下的极端风险管理

文章目录 前言一、快速进步与高风险并存1、深度学习系统缺乏关键功能,其开发周期尚不明朗2、自主人工智能系统一旦导向不良目标,人类可能面临其失控风险 二、技术研发方向调整1、实现安全人工智能的基础性突破,确保人工智能可靠安全2、实现有…

shopro前端 短信登录只显示模板不能正常切换

删掉 换成下面的代码 // 打开授权弹框 export function showAuthModal(type smsLogin) {const modal $store(modal);setTimeout(() > {modal.$patch((state) > {state.auth type;});}, 100); }

Python酷库之旅-第三方库Pandas(123)

目录 一、用法精讲 546、pandas.DataFrame.ffill方法 546-1、语法 546-2、参数 546-3、功能 546-4、返回值 546-5、说明 546-6、用法 546-6-1、数据准备 546-6-2、代码示例 546-6-3、结果输出 547、pandas.DataFrame.fillna方法 547-1、语法 547-2、参数 547-3、…

AI+教育|拥抱AI智能科技,让课堂更生动高效

AI在教育领域的应用正逐渐成为现实,提供互动性强的学习体验,正在改变传统教育模式。AI不仅改变了传统的教学模式,还为教育提供了更多的可能性和解决方案。从个性化学习体验到自动化管理任务,AI正在全方位提升教育质量和效率。随着…

【OJ刷题】双指针问题6

这里是阿川的博客,祝您变得更强 ✨ 个人主页:在线OJ的阿川 💖文章专栏:OJ刷题入门到进阶 🌏代码仓库: 写在开头 现在您看到的是我的结论或想法,但在这背后凝结了大量的思考、经验和讨论 目录 1…

技术周总结 09.09~09.15周日(C# WinForm WPF 软件架构)

文章目录 一、09.09 周一1.1) 问题01: Windows桌面开发中,WPF和WinForm的区别和联系?联系:区别: 二、09.12 周四2.1)问题01:visual studio的相关快捷键有哪些?通用快捷键编辑导航调试窗口管理 2…

Python Selenium 自动化爬虫 + Charles Proxy 抓包

一、场景介绍 我们平常会遇到一些需要根据省、市、区查询信息的网站。 1、省市查询 比如这种,因为全国的省市比较多,手动查询工作量还是不小。 2、接口签名 有时候我们用python直接查询后台接口的话,会发现接口是加签名的。 而签名算法我…

细胞分裂检测系统源码分享

细胞分裂检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vis…

计算机人工智能前沿进展-大语言模型方向-2024-09-20

计算机人工智能前沿进展-大语言模型方向-2024-09-20 1. Multimodal Fusion with LLMs for Engagement Prediction in Natural Conversation Authors: Cheng Charles Ma, Kevin Hyekang Joo, Alexandria K. Vail, Sunreeta Bhattacharya, Alvaro Fern’andez Garc’ia, Kailan…

[数据集][目标检测]智慧交通铁轨裂缝检测数据集VOC+YOLO格式4类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):2709 标注数量(xml文件个数):2709 标注数量(txt文件个数):2709 标注…

独立站技能树/工具箱1.0 总纲篇丨出海笔记

正所谓要把一件事做到90分很难,但做到60分基本上照着SOP做到位都没问题,如果我们能把每件事都做到60分,那绝对比至少60%的人都强,除非你的对手不讲武德——那就是他很可能看了我这篇文章,不但每方面都超过及格线&#…

fiddler抓包06_抓取https请求(chrome)

课程大纲 首次安装Fiddler,抓https请求,除打开抓包功能(F12)还需要: ① Fiddler开启https抓包 ② Fiddler导出证书; ③ 浏览器导入证书。 否则,无法访问https网站(如下图&#xff0…

将sqlite3移植到arm开发板上:

一、下载源代码 sqlite3网址:https://www.sqlite.org/download.html 下载:sqlite-autoconf-3460100.tar.gz 二、解压 在Linux家目录下创建一个sqlite3文件夹,将压缩包复制到该文件夹下,再在该目录下打开一个终端,执行…

【Linux】简易日志系统

目录 一、概念 二、可变参数 三、日志系统 一、概念 一个正在运行的程序或系统就像一个哑巴,一旦开始运行我们很难知晓其内部的运行状态。 但有时在程序运行过程中,我们想知道其内部不同时刻的运行结果如何,这时一个日志系统可以有效的帮…

【路径规划】 红嘴蓝鹊优化器:一种用于2D/3D无人机路径规划和工程设计问题的新型元启发式算法

摘要 本文提出了一种新型元启发式算法——红嘴蓝鹊优化器(RBMO),用于解决2D和3D无人机路径规划以及复杂工程设计问题。RBMO灵感来源于红嘴蓝鹊的群体合作行为,包括搜索、追逐、捕猎和食物储藏。该算法通过模拟这些行为&#xff0…

prober found high clock drift,Linux服务器时间不能自动同步,导致服务器时间漂移解决办法。

文章目录 一、场景二、问题三、解决办法(一)给服务器添加访问网络能力(二)手动同步1. 检查有没有安装ntp2. 没有安装ntp则离线安装ntp2.1 下载安装包2.2 安装2.3 启动 ntp 3. 设置内部时钟源3.1 编辑/etc/ntp.conf3.1 重启ntp服务…