深入理解 LSTM:从原理到实战

afcppe 发布于 12 天前 9 次阅读



1. 为什么需要 LSTM?

在传统的循环神经网络(RNN)中,随着时间步的增加,梯度在反向传播时会出现梯度消失梯度爆炸的问题,导致网络难以学习长期依赖关系。LSTM(Long Short-Term Memory)通过引入门控机制记忆单元,有效地解决了这一问题,成为处理序列数据(如文本、语音、时间序列)的主流模型之一。


2. LSTM 的核心结构

LSTM 的核心是记忆单元(Cell State)三个门控机制

  • 遗忘门(Forget Gate):决定丢弃哪些信息。
  • 输入门(Input Gate):决定更新哪些新信息。
  • 输出门(Output Gate):决定输出哪些信息。

2.1 数学公式

设:

  • ( x_t ):当前输入
  • ( h_{t-1} ):上一时刻的隐藏状态
  • ( C_{t-1} ):上一时刻的记忆单元
  • ( \sigma ):Sigmoid 激活函数
  • ( \tanh ):双曲正切激活函数
  • ( \odot ):逐元素乘法

遗忘门

[
f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)
]

输入门

[
i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \
\tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C)
]

更新记忆单元

[
C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t
]

输出门

[
o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \
h_t = o_t \odot \tanh(C_t)
]


3. 实战:用 PyTorch 实现 LSTM 文本分类

我们以 IMDB 电影评论情感分类为例,展示如何用 LSTM 实现文本分类。

3.1 环境准备

pip install torch torchvision torchtext

3.2 数据加载与预处理

import torch
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader

# 1. 加载数据
train_iter = IMDB(split='train')
tokenizer = get_tokenizer('basic_english')

# 2. 构建词汇表
def yield_tokens(data_iter):
    for label, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# 3. 文本转索引
def text_pipeline(text):
    return [vocab[token] for token in tokenizer(text)]

def label_pipeline(label):
    return 1 if label == 'pos' else 0

# 4. 构建 DataLoader
def collate_batch(batch):
    label_list, text_list, lengths = [], [], []
    for label, text in batch:
        label_list.append(label_pipeline(label))
        processed_text = torch.tensor(text_pipeline(text), dtype=torch.int64)
        text_list.append(processed_text)
        lengths.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    lengths = torch.tensor(lengths)
    padded_text = torch.nn.utils.rnn.pad_sequence(text_list, batch_first=True)
    return padded_text, label_list, lengths

train_iter = IMDB(split='train')
train_loader = DataLoader(list(train_iter), batch_size=32, shuffle=True, collate_fn=collate_batch)

3.3 定义 LSTM 模型

import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, lengths):
        embedded = self.embedding(x)
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        _, (hidden, _) = self.lstm(packed)
        out = self.dropout(hidden[-1])
        return self.fc(out)

# 超参数
VOCAB_SIZE = len(vocab)
EMBED_DIM = 128
HIDDEN_DIM = 256
NUM_LAYERS = 2
NUM_CLASSES = 2

model = LSTMClassifier(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS, NUM_CLASSES)

3.4 训练与评估

import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train(model, dataloader):
    model.train()
    total_loss = 0
    for texts, labels, lengths in dataloader:
        texts, labels, lengths = texts.to(device), labels.to(device), lengths.to(device)
        optimizer.zero_grad()
        outputs = model(texts, lengths)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# 训练 5 个 epoch
for epoch in range(5):
    loss = train(model, train_loader)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")

4. 常见问题与优化技巧

问题解决方案
过拟合使用 Dropout、L2 正则化、早停
训练慢使用 GPU、混合精度训练
长序列使用截断 BPTT、分层 LSTM
词汇表过大使用子词分词(BPE、WordPiece)

5. 总结

LSTM 通过门控机制有效解决了 RNN 的长期依赖问题,广泛应用于 NLP、时间序列预测等领域。本文从原理到实战,带你一步步实现了一个 LSTM 文本分类模型。希望对你有所帮助!


6. 参考资料

此作者没有提供个人介绍。
最后更新于 2025-10-13