1. Why Do We Need LSTM?
In a traditional recurrent neural network (RNN), gradients tend to vanish or explode during backpropagation as the number of time steps grows, which makes it hard for the network to learn long-term dependencies. LSTM (Long Short-Term Memory) addresses this by introducing gating mechanisms and a memory cell, and it has become one of the mainstream models for sequence data such as text, speech, and time series.
2. The Core Structure of LSTM
The core of an LSTM is the cell state (memory) together with three gates:
- Forget gate: decides which information to discard.
- Input gate: decides which new information to write into the cell state.
- Output gate: decides which information to expose as the output.
2.1 Mathematical Formulas
Let:
- $x_t$: the input at the current time step
- $h_{t-1}$: the hidden state from the previous time step
- $C_{t-1}$: the cell state from the previous time step
- $\sigma$: the sigmoid activation function
- $\tanh$: the hyperbolic tangent activation function
- $\odot$: element-wise multiplication
Forget gate:

$$
f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)
$$

Input gate:

$$
i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \\
\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)
$$

Cell state update:

$$
C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t
$$

Output gate:

$$
o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \\
h_t = o_t \odot \tanh(C_t)
$$
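To make the formulas concrete, here is a minimal sketch of a single LSTM time step written directly from the equations above. The function name, weight shapes, and argument layout are illustrative assumptions; in practice you would use `nn.LSTM`, as in the next section.

```python
import torch

def lstm_step(x_t, h_prev, c_prev, W_f, b_f, W_i, b_i, W_C, b_C, W_o, b_o):
    """One LSTM time step following the equations above.
    x_t: (batch, input_dim); h_prev, c_prev: (batch, hidden_dim);
    each W_* has shape (hidden_dim, hidden_dim + input_dim)."""
    concat = torch.cat([h_prev, x_t], dim=1)        # [h_{t-1}, x_t]
    f_t = torch.sigmoid(concat @ W_f.T + b_f)       # forget gate
    i_t = torch.sigmoid(concat @ W_i.T + b_i)       # input gate
    c_tilde = torch.tanh(concat @ W_C.T + b_C)      # candidate memory
    c_t = f_t * c_prev + i_t * c_tilde              # cell state update
    o_t = torch.sigmoid(concat @ W_o.T + b_o)       # output gate
    h_t = o_t * torch.tanh(c_t)                     # new hidden state
    return h_t, c_t
```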
3. Hands-On: LSTM Text Classification with PyTorch
We use sentiment classification of IMDB movie reviews as an example to show how to build a text classifier with an LSTM.
3.1 Environment Setup
```bash
pip install torch torchvision torchtext
```
3.2 Data Loading and Preprocessing
```python
import torch
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader

# 1. Load the data
train_iter = IMDB(split='train')
tokenizer = get_tokenizer('basic_english')

# 2. Build the vocabulary
def yield_tokens(data_iter):
    for label, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# 3. Convert text to token indices
def text_pipeline(text):
    return [vocab[token] for token in tokenizer(text)]

# Note: depending on the torchtext version, IMDB labels may be the strings
# 'pos'/'neg' or the integers 1/2; adjust this mapping to match your version.
def label_pipeline(label):
    return 1 if label == 'pos' else 0

# 4. Build the DataLoader
def collate_batch(batch):
    label_list, text_list, lengths = [], [], []
    for label, text in batch:
        label_list.append(label_pipeline(label))
        processed_text = torch.tensor(text_pipeline(text), dtype=torch.int64)
        text_list.append(processed_text)
        lengths.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    lengths = torch.tensor(lengths)
    # Pad all sequences in the batch to the same length
    padded_text = torch.nn.utils.rnn.pad_sequence(text_list, batch_first=True)
    return padded_text, label_list, lengths

# The iterator above was consumed while building the vocabulary, so re-create it
train_iter = IMDB(split='train')
train_loader = DataLoader(list(train_iter), batch_size=32, shuffle=True, collate_fn=collate_batch)
```
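As a quick, purely illustrative sanity check, you can pull one batch from the loader and inspect the tensor shapes:

```python
texts, labels, lengths = next(iter(train_loader))
print(texts.shape)    # (batch_size, longest sequence in this batch)
print(labels.shape)   # (batch_size,)
print(lengths[:5])    # original (unpadded) lengths of the first few samples
```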
3.3 Defining the LSTM Model
```python
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, lengths):
        embedded = self.embedding(x)
        # pack_padded_sequence lets the LSTM skip padding positions;
        # it requires the lengths as a CPU tensor
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (hidden, _) = self.lstm(packed)
        # hidden[-1] is the final hidden state of the top LSTM layer
        out = self.dropout(hidden[-1])
        return self.fc(out)

# Hyperparameters
VOCAB_SIZE = len(vocab)
EMBED_DIM = 128
HIDDEN_DIM = 256
NUM_LAYERS = 2
NUM_CLASSES = 2

model = LSTMClassifier(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS, NUM_CLASSES)
```
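Before training, a small shape check with a dummy batch (the values below are arbitrary) is an easy way to confirm the model is wired correctly:

```python
dummy_text = torch.randint(0, VOCAB_SIZE, (4, 20))   # 4 samples, 20 tokens each
dummy_lengths = torch.tensor([20, 18, 15, 9])
with torch.no_grad():
    logits = model(dummy_text, dummy_lengths)
print(logits.shape)   # expected: torch.Size([4, 2])
```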
3.4 Training and Evaluation
```python
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train(model, dataloader):
    model.train()
    total_loss = 0
    for texts, labels, lengths in dataloader:
        # lengths stay on the CPU: pack_padded_sequence needs a CPU tensor
        texts, labels = texts.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(texts, lengths)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# Train for 5 epochs
for epoch in range(5):
    loss = train(model, train_loader)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
```
4. Common Issues and Optimization Tips
| Issue | Remedy |
|---|---|
| Overfitting | Dropout, L2 regularization, early stopping (see the sketch below) |
| Slow training | Use a GPU, mixed-precision training |
| Long sequences | Truncated BPTT, hierarchical LSTMs |
| Very large vocabulary | Subword tokenization (BPE, WordPiece) |
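To illustrate the first row (Dropout is already in the model), here is a sketch of L2 regularization via Adam's weight_decay combined with simple early stopping on the training loss. The weight_decay value, patience, and checkpoint filename are arbitrary choices; in practice you would monitor a held-out validation set instead.

```python
# L2 regularization: weight_decay adds a penalty on parameter magnitudes
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Simple early stopping: stop when the loss has not improved for `patience` epochs
best_loss, patience, bad_epochs = float('inf'), 2, 0
for epoch in range(20):
    loss = train(model, train_loader)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
    if loss < best_loss:
        best_loss, bad_epochs = loss, 0
        torch.save(model.state_dict(), 'best_lstm.pt')   # keep the best checkpoint
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            print("Early stopping")
            break
```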
5. Summary
Through its gating mechanism, LSTM effectively addresses the long-term dependency problem of plain RNNs and is widely used in NLP, time-series forecasting, and related areas. This post walked from the underlying theory to a working LSTM text classifier, step by step. I hope you found it helpful!