Text Matching with LSTM in PyTorch

Text matching aims to determine whether two input sequences are semantically related or similar. It is commonly used in applications such as question answering, duplicate detection, and information retrieval.

A typical approach encodes each sentence independently with a recurrent neural network such as an LSTM. The same weights are usually shared by both inputs. The approach then compares the two final representations to produce a similarity score or classification label.

Model Architecture

  • Input: Two sequences of token indices.
  • Embedding Layer: Converts tokens into dense vectors.
  • LSTM Encoder: Processes each sequence to generate contextualized representations. Bidirectional variants can be used for richer context.
  • Similarity Function: Computes a similarity metric (e.g., cosine similarity) between the encoded vectors.
  • Classifier: Maps the similarity score to a binary or multi-class output.

Implementation

import torch
import torch.nn as nn

class TextMatchingLSTM(nn.Module):
    """Siamese LSTM that classifies whether two token sequences match.

    Both inputs go through the *same* embedding table and the *same* LSTM.
    The cosine similarity of the two sentence vectors is mapped to
    ``num_classes`` logits by a single linear layer.

    Args:
        vocab_size: Number of real tokens. The embedding table has
            ``vocab_size + 1`` rows because index 0 is reserved for padding.
        embed_dim: Size of each token embedding.
        hidden_dim: LSTM hidden-state size.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_classes):
        super().__init__()
        # Index 0 is the padding token: its embedding is kept at zero and
        # receives no gradient.
        self.embed = nn.Embedding(vocab_size + 1, embed_dim, padding_idx=0)
        # batch_first=False: inputs are (seq_len, batch) index tensors.
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=False)
        # dim=1 is the hidden dimension of a (batch, hidden) sentence vector.
        self.similarity = nn.CosineSimilarity(dim=1, eps=1e-6)
        # One input feature: the scalar similarity score.
        self.classifier = nn.Linear(1, num_classes)

    def _last_valid_output(self, out, seq):
        """Return the LSTM output at the last non-padding step of each column.

        Taking ``out[-1]`` would, for right-padded batches, return the state
        after the LSTM has read all trailing padding tokens. For left-padded
        or unpadded input, the last non-padding step *is* ``seq_len - 1``,
        so the result equals ``out[-1]``.

        NOTE(review): a genuine token mapped to index 0 (e.g. an OOV token
        when the vocabulary lookup defaults to 0) at the very end of a
        sequence is treated as padding here.

        Args:
            out: LSTM outputs of shape (seq_len, batch, hidden).
            seq: Token indices of shape (seq_len, batch).

        Returns:
            Tensor of shape (batch, hidden).
        """
        seq_len = seq.size(0)
        # 1-based position of every step, zeroed where the token is padding.
        # The column-wise max is therefore the 1-based index of the last
        # real token.
        positions = torch.arange(1, seq_len + 1, device=seq.device).unsqueeze(1)
        not_pad = (seq != self.embed.padding_idx).long()
        last = (positions * not_pad).max(dim=0).values - 1
        # An all-padding column yields -1; fall back to step 0.
        last = last.clamp(min=0)
        index = last.view(1, -1, 1).expand(1, out.size(1), out.size(2))
        return out.gather(0, index).squeeze(0)

    def forward(self, seq_a, seq_b):
        """Score a batch of sequence pairs.

        Args:
            seq_a: Token indices of shape (seq_len_a, batch).
            seq_b: Token indices of shape (seq_len_b, batch).

        Returns:
            Logits of shape (batch, num_classes).
        """
        # Embed both sequences with the shared table.
        emb_a = self.embed(seq_a)
        emb_b = self.embed(seq_b)

        # Encode with the shared LSTM.
        out_a, _ = self.lstm(emb_a)
        out_b, _ = self.lstm(emb_b)

        # Sentence vector = output at the last real (non-padding) token.
        rep_a = self._last_valid_output(out_a, seq_a)
        rep_b = self._last_valid_output(out_b, seq_b)

        # (batch,) similarity -> (batch, 1) feature for the classifier.
        sim = self.similarity(rep_a, rep_b).unsqueeze(1)

        return self.classifier(sim)

Training Loop

# Build the model. vocab_size, embedding_dim, hidden_dim, num_layers,
# output_dim, learning_rate, num_epochs, device and train_loader are
# expected to be defined earlier.
model = TextMatchingLSTM(
    vocab_size=vocab_size,
    embed_dim=embedding_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    num_classes=output_dim,
)
# Move the model before constructing the optimizer so the optimizer is
# built over the parameters in their final location.
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

best_acc = 0.0

for epoch in range(num_epochs):
    model.train()
    total_correct = 0
    total_samples = 0

    for sent_a, sent_b, labels in train_loader:
        # The loader yields (batch, seq_len); the model expects (seq_len, batch).
        sent_a = sent_a.long().to(device).transpose(0, 1)
        sent_b = sent_b.long().to(device).transpose(0, 1)
        # CrossEntropyLoss requires int64 class indices as targets.
        labels = labels.long().to(device)

        optimizer.zero_grad()
        outputs = model(sent_a, sent_b)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)

    if total_samples == 0:
        # Fail clearly instead of with a ZeroDivisionError on the next line.
        raise ValueError("train_loader yielded no samples")
    epoch_acc = total_correct / total_samples

    # NOTE(review): the checkpoint is selected by *training* accuracy; a
    # held-out validation accuracy is a more reliable selection criterion.
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(model.state_dict(), 'best_model.pth')

Inference Example

# Example pair:
#   sentence_a: "I don't like chopped-chili fish head, but I do like fish head"
#   sentence_b: "I like potatoes, but I don't like sweet potatoes"
sentence_a = "我不爱吃剁椒鱼头,但是我爱吃鱼头"
sentence_b = "我爱吃土豆,但是不爱吃地瓜"

# Tokenize character by character (iterating a str yields characters) and
# convert to indices, producing shape (1, n_chars).
# NOTE(review): unknown characters map to 0, which is also the model's
# padding index; confirm whether word2idx provides a dedicated <unk> id.
seq_a = torch.tensor([[word2idx.get(w, 0) for w in sentence_a]])
seq_b = torch.tensor([[word2idx.get(w, 0) for w in sentence_b]])

# Pad to a fixed length and transpose; presumably yields (max_len, 1), the
# (seq_len, batch) layout the model expects. Confirm against pad_and_transpose.
seq_a = pad_and_transpose(seq_a, max_len=20)
seq_b = pad_and_transpose(seq_b, max_len=20)

# Load the best checkpoint. map_location='cpu' lets this work on CPU-only
# machines; load_state_dict copies the weights into the model's existing
# parameters, wherever they live.
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()

# The model may still be on the training device (e.g. a GPU), so the
# inputs must be moved there too.
model_device = next(model.parameters()).device
seq_a = seq_a.to(model_device)
seq_b = seq_b.to(model_device)

with torch.no_grad():
    logits = model(seq_a, seq_b)
    # logits is (batch, num_classes); take the predicted class of the single pair.
    pred = logits.argmax(dim=1)[0].item()

label_map = {0: "Not Matched", 1: "Matched"}
print(f"Prediction: {label_map[pred]}")

Tags: pytorch LSTM Text Matching Natural Language Processing Deep Learning

Posted on Thu, 04 Jun 2026 17:23:38 +0000 by ggseven