Text matching aims to determine whether two input sequences are semantically related or similar. This is commonly used in applications like question answering, duplicate detection, and information retrieval.
A typical approach involves encoding each sentence independently with a recurrent neural network such as an LSTM, then comparing the final representations to produce a similarity score or classification label.
Model Architecture
- Input: Two sequences of token indices.
- Embedding Layer: Converts tokens into dense vectors.
- LSTM Encoder: Processes each sequence to generate contextualized representations. Bidirectional variants can be used for richer context.
- Similarity Function: Computes a similarity metric (e.g., cosine similarity) between the encoded vectors.
- Classifier: Maps the similarity score to a binary or multi-class output.
Implementation
import torch
import torch.nn as nn
class TextMatchingLSTM(nn.Module):
    """Siamese LSTM that classifies a sentence pair from its cosine similarity.

    Both sequences are embedded and encoded by the same LSTM. Each sequence
    is summarised by the LSTM output at its last non-padding position, the
    two summaries are compared with cosine similarity, and a linear layer
    maps that single score to ``num_classes`` logits.

    Args:
        vocab_size: Number of vocabulary entries; the embedding table holds
            ``vocab_size + 1`` rows, with index 0 used as the padding index.
        embed_dim: Size of each token embedding.
        hidden_dim: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_classes):
        super().__init__()
        self.embed = nn.Embedding(vocab_size + 1, embed_dim, padding_idx=0)
        # batch_first=False: inputs must be laid out as (seq_len, batch).
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=False)
        self.similarity = nn.CosineSimilarity(dim=1, eps=1e-6)
        # The classifier only ever sees the scalar similarity score.
        self.classifier = nn.Linear(1, num_classes)

    def _last_valid_state(self, outputs, tokens):
        """Return the LSTM output at each sequence's last non-padding step.

        Args:
            outputs: LSTM outputs of shape ``(seq_len, batch, hidden_dim)``.
            tokens: Token indices of shape ``(seq_len, batch)``.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``. A sequence made up
            entirely of padding falls back to timestep 0.
        """
        seq_len, batch_size = tokens.size(0), tokens.size(1)
        positions = torch.arange(seq_len, device=tokens.device).unsqueeze(1)
        is_token = tokens != self.embed.padding_idx
        # Padding positions contribute 0, so the column-wise max is the index
        # of the last real token. This is correct for right-padded batches
        # and unchanged (seq_len - 1) for unpadded or left-padded ones.
        last_idx = (positions * is_token).max(dim=0).values
        batch_idx = torch.arange(batch_size, device=tokens.device)
        return outputs[last_idx, batch_idx]

    def forward(self, seq_a, seq_b):
        """Score a batch of sequence pairs.

        Args:
            seq_a: Long tensor of token indices, shape ``(seq_len_a, batch)``.
            seq_b: Long tensor of token indices, shape ``(seq_len_b, batch)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        # Embed both sequences
        emb_a = self.embed(seq_a)
        emb_b = self.embed(seq_b)
        # Encode with shared LSTM
        out_a, _ = self.lstm(emb_a)
        out_b, _ = self.lstm(emb_b)
        # Use the last real timestep rather than out[-1], which would read
        # the state after trailing padding tokens in a right-padded batch.
        rep_a = self._last_valid_state(out_a, seq_a)
        rep_b = self._last_valid_state(out_b, seq_b)
        # Compute similarity; unsqueeze so the linear layer gets (batch, 1)
        sim = self.similarity(rep_a, rep_b).unsqueeze(1)
        # Classify
        logits = self.classifier(sim)
        return logits
Training Loop
# Build the model, optimizer and loss. The hyperparameters (vocab_size,
# embedding_dim, hidden_dim, num_layers, output_dim, learning_rate,
# num_epochs), `device` and `train_loader` are expected to be defined
# before this point.
model = TextMatchingLSTM(
    vocab_size=vocab_size,
    embed_dim=embedding_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    num_classes=output_dim,
)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
best_acc = 0.0
model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_correct = 0
    total_samples = 0
    for batch in train_loader:
        sent_a, sent_b, labels = batch
        # The model's LSTM uses batch_first=False, so swap the first two
        # axes (assumes the loader yields (batch, seq_len) -- TODO confirm).
        sent_a = sent_a.long().to(device).transpose(0, 1)
        sent_b = sent_b.long().to(device).transpose(0, 1)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(sent_a, sent_b)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Running accuracy over the batches seen in this epoch.
        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)

    if total_samples == 0:
        # An empty loader would otherwise raise ZeroDivisionError below;
        # there is nothing to evaluate or checkpoint for this epoch.
        continue

    epoch_acc = total_correct / total_samples
    # NOTE: the checkpoint is selected on *training* accuracy, since no
    # validation data is used in this loop.
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(model.state_dict(), 'best_model.pth')
Inference Example
sentence_a = "我不爱吃剁椒鱼头,但是我爱吃鱼头"
sentence_b = "我爱吃土豆,但是不爱吃地瓜"

# Tokenize and convert to indices. Iterating a str yields single characters,
# so `word2idx` is looked up per character; characters missing from it map
# to 0, which is also the embedding's padding index.
seq_a = torch.tensor([[word2idx.get(w, 0) for w in sentence_a]])
seq_b = torch.tensor([[word2idx.get(w, 0) for w in sentence_b]])

# Load model weights. map_location='cpu' only controls where the checkpoint
# tensors are deserialized; load_state_dict copies them into the model's
# existing parameters, so the model stays on whatever device it is on.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()
model_device = next(model.parameters()).device

# Pad and transpose (presumably to a fixed-length (seq_len, batch) layout --
# confirm against pad_and_transpose), then move the inputs to the model's
# device so a GPU-resident model does not fail on CPU inputs.
seq_a = pad_and_transpose(seq_a, max_len=20).to(model_device)
seq_b = pad_and_transpose(seq_b, max_len=20).to(model_device)

# Predict
with torch.no_grad():
    logits = model(seq_a, seq_b)
    # Take the argmax over the class axis; [0] selects the single example.
    pred = logits.argmax(dim=1)[0].item()

label_map = {0: "Not Matched", 1: "Matched"}
print(f"Prediction: {label_map[pred]}")