当前位置: 首页 > news >正文

第三章 训练初步深入(3)

# 这版代码是将数据集的读取和处理分离,同时将文本和标签构建字典
# 字典的构建方法是将所有文本中的词汇和标签都加入字典,并给每个词汇和标签分配一个索引# max_len是设置的最大长度,超过这个长度的文本将被截断
# 如果文本长度小于max_len,则用0填充
# 导入numpy库,并将处理过的文本转换为矩阵后输出# 以下是原代码import os
import math
import random
import numpy as np#导入numpy库def read_file(filepath):with open(filepath, "r", encoding="utf-8") as f:all_lines = f.read().split("\n")# print(all_lines)all_text = []all_label = []for line in all_lines:# print(line)data_s = line.split()if len(data_s) != 2:continueelse:text, label = data_sall_label.append(label)all_text.append(text)assert len(all_text) == len(all_label), "text and label length not equal"return all_text, all_label# Dataset 是所有数据集的集合,DataLoader 是每次返回一个batch的迭代器
class Dataset:def __init__(self, all_text, all_label, batch_size, word_2_dict, label_2_dict):self.all_text = all_textself.all_label = all_labelself.batch_size = batch_sizeself.word_2_dict = word_2_dictself.label_2_dict = label_2_dictdef __iter__(self):dataloader = DataLoader(self)return dataloaderdef __getitem__(self, index):# 实现__getitem__方法,返回一个batch的数据text = self.all_text[index][:max_len]#截断文本label = self.all_label[index]#获取标签text_idx = [self.word_2_dict[w] for w in text]# 将文本中的词汇转换为索引label_idx = self.label_2_dict[label]#将标签转换为索引text_idx_p =text_idx+[0]*(max_len - len(text_idx))#用0填充文本索引,使其长度为max_lenreturn text_idx_p, label_idx#返回文本索引和标签索引class DataLoader:def __init__(self, dataset):self.dataset = datasetself.cursor = 0def __next__(self):if self.cursor >= len(self.dataset.all_text):raise StopIterationbatch_data = [self.dataset[i]for i in range(self.cursor,min(self.cursor + self.dataset.batch_size, len(self.dataset.all_text)),)]if batch_data:text_idx, label_idx = zip(*batch_data)else:raise StopIterationself.cursor += self.dataset.batch_sizereturn np.array(text_idx), np.array(label_idx)#np.array()是将列表转换为numpy数组,shape=(batch_size,max_len)def build_word_2_dict(all_text):#构建词汇到索引的字典word_2_dict = {"PAD": 0}#PAD表示padding,索引为0for text in all_text:#遍历所有文本for w in text:#遍历每一个文本中的词word_2_dict[w] = word_2_dict.get(w, len(word_2_dict))#get(w,len(word_2_dict))表示如果w在字典中存在,则返回w的索引,否则返回len(word_2_dict)return word_2_dict
#返回词汇到索引的字典def build_label_2_dict(all_label):#构建标签到索引的字典return {k: i for i, k in enumerate(set(all_label), start=0)}
#返回标签到索引的字典,set(all_label)的元素是不重复的,enumerate(set(all_label),start=0)返回一个字典,key是元素,value是从0开始的索引if __name__ == "__main__":filepath = os.path.join("D:/", "my code", "Python", "NLP basic", "data", "train2.txt")all_text, all_label = read_file(filepath)epoch = 1bitch_size = 6max_len = 20 # 设置最大长度,超过这个长度的文本将被截断word_2_dict = build_word_2_dict(all_text)label_2_dict = build_label_2_dict(all_label)# print(word_2_dict)# print(label_2_dict)train_dataset = Dataset(all_text, all_label, bitch_size, word_2_dict, label_2_dict)for e in range(epoch):train_dataset.cursor = 0print("Epoch:", e + 1, "/", epoch)for data in train_dataset:batch_text_idx, batch_label_idx = dataprint(batch_text_idx)print(batch_label_idx)

image

http://www.aitangshan.cn/news/478.html

相关文章:

  • 安装pandas
  • 奥林匹克小丛书小蓝本习题另解或加强(数论卷)(一)
  • 关于磁盘io性能的命令
  • 房屋防水是建筑工程中非常重要的一部分,通常需要根据不同的环境、建筑结构和使用需求来采取相应的防水措施。国家标准对防水工程的要求有详细规定,以下是常见的防水相关国家标准和要求:
  • Hulo 编程语言开发 —— 从源代码到 AST 的魔法转换
  • python中enumerate的作用
  • ly-容斥杂题选讲
  • 前向传播 反向传播
  • Attention 显存计算 推理训练复杂度
  • NLP随记
  • RL 随记
  • top命令详解
  • 2025杭电暑期(8) 最努力的活着 推式子
  • 从输入网址到看到页面:一段看不见的旅程
  • 牛客周赛109补题
  • stress命令详解
  • Nvidia Proprietary GPU Drivers
  • dd命令生成文件详解
  • 关于PVC排水管系统中存水弯设计的常见类型分类表格:
  • 一个好点子,但是我克制住了
  • 软考系统分析师每日学习卡 | [日期:2025-08-11] | [今日主题:数据库设计过程-概念结构设计阶段]
  • 2025年8月11日
  • 基于AOA算术优化的KNN数据聚类算法matlab仿真
  • strace命令
  • 基于最优转子磁链混合效率优化控制和铁损补偿的PMSM控制系统simulink建模与仿真
  • python中raise的用法
  • alias命令
  • 口播
  • nmap命令
  • CSP-J/S 2024 游记