While LangChain provides several built-in memory implementations, you may encounter scenarios where you need a custom memory type tailored to your specific application requirements. This guide demonstrates how to implement a custom memory class and integrate it with LangChain's ConversationChain.
Prerequisites
For this implementation, you'll need the following imports:
```python
from langchain import OpenAI, ConversationChain
from langchain.schema import BaseMemory
from pydantic import BaseModel
from typing import List, Dict, Any
```
Example 1: Entity-Based Memory Using spaCy
This custom memory class uses spaCy for named entity recognition (NER) to extract and store entities in a simple dictionary structure. During conversations, the system identifies entities from input text and includes relevant information in the context window.
Note: This implementation is intentionally simplified for demonstration purposes and may not be suitable for production environments without additional robustness.
First, make sure spaCy and its English language model are installed:
```python
# !pip install spacy
# !python -m spacy download en_core_web_sm
import spacy

nlp = spacy.load("en_core_web_sm")
```
The following class extends BaseMemory to create a custom entity memory system:
```python
class EntityExtractorMemory(BaseMemory, BaseModel):
    """Memory class that stores information about extracted entities."""

    # Maps entity text -> everything the user has said about that entity
    entity_store: dict = {}
    # Prompt variable that receives the entity information
    context_key: str = "entity_context"

    def clear(self) -> None:
        self.entity_store = {}

    @property
    def memory_variables(self) -> List[str]:
        """Return variables that should be injected into prompts."""
        return [self.context_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Load stored entities for context injection."""
        # Process input text through the spaCy pipeline
        # (assumes the first input key holds the user's message)
        input_key = list(inputs.keys())[0]
        doc = nlp(inputs[input_key])
        # Retrieve known information for detected entities
        known_entities = [
            self.entity_store[str(entity)]
            for entity in doc.ents
            if str(entity) in self.entity_store
        ]
        # Return formatted entity information
        return {self.context_key: "\n".join(known_entities)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Persist conversation context to memory."""
        input_key = list(inputs.keys())[0]
        text = inputs[input_key]
        doc = nlp(text)
        # Store or append information for each detected entity
        for entity in doc.ents:
            entity_text = str(entity)
            if entity_text in self.entity_store:
                self.entity_store[entity_text] += f"\n{text}"
            else:
                self.entity_store[entity_text] = text
```
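To see the store-and-append behavior without installing spaCy or calling an LLM, here is a dependency-free sketch that mirrors `load_memory_variables` and `save_context`. The `naive_entities` function is a hypothetical stand-in for spaCy's NER (it simply treats every capitalized word as an entity), and `SimpleEntityStore` is an illustrative class, not part of LangChain.

```python
import re
from typing import Dict, List


def naive_entities(text: str) -> List[str]:
    # Hypothetical stand-in for spaCy's doc.ents: every capitalized word.
    # Far cruder than real NER; used only so this sketch has no dependencies.
    return re.findall(r"\b[A-Z][a-z]+\b", text)


class SimpleEntityStore:
    """Mirrors the store/append logic of EntityExtractorMemory."""

    def __init__(self) -> None:
        self.entity_store: Dict[str, str] = {}

    def load(self, text: str) -> str:
        # Like load_memory_variables: join stored notes for known entities
        return "\n".join(
            self.entity_store[e] for e in naive_entities(text) if e in self.entity_store
        )

    def save(self, text: str) -> None:
        # Like save_context: file the raw input under each detected entity
        for e in naive_entities(text):
            if e in self.entity_store:
                self.entity_store[e] += f"\n{text}"
            else:
                self.entity_store[e] = text


store = SimpleEntityStore()
print(repr(store.load("Sarah enjoys studying artificial intelligence")))  # ''
store.save("Sarah enjoys studying artificial intelligence")
print(store.load("What subject did Sarah enjoy studying most?"))
# Sarah enjoys studying artificial intelligence
```

The second lookup finds "Sarah" in the store, which is exactly the information the real class injects into the prompt on the follow-up question.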
Integrating with ConversationChain
Define a prompt template to incorporate entity context:
```python
from langchain.prompts.prompt import PromptTemplate

TEMPLATE = """You are an informative AI assistant engaged in a friendly conversation.
When relevant, use the entity information provided below to provide contextually aware responses.

Entity Information:
{entity_context}

Conversation:
Human: {input}
AI:"""

prompt = PromptTemplate(
    input_variables=["entity_context", "input"],
    template=TEMPLATE,
)
```
Initialize the conversation chain:
```python
llm = OpenAI(temperature=0)
conversation = ConversationChain(
    llm=llm,
    prompt=prompt,
    verbose=True,
    memory=EntityExtractorMemory(),
)
```
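In the LangChain versions this guide targets, `ConversationChain` checks at construction time that the prompt's input variables are exactly the memory's `memory_variables` plus the chain's input key (`"input"` by default), and rejects the configuration otherwise. The dependency-free sketch below mirrors that rule with Python's `string.Formatter`; `template_variables` and `check_conversation_prompt` are hypothetical helpers, not LangChain functions.

```python
from string import Formatter
from typing import List


def template_variables(template: str) -> List[str]:
    # Placeholder names as str.format parses them; literal text and
    # doubled braces ({{ }}) come back with field_name None and are skipped.
    return [field for _, field, _, _ in Formatter().parse(template) if field]


def check_conversation_prompt(
    template: str, memory_vars: List[str], input_key: str = "input"
) -> None:
    """Hypothetical helper mirroring ConversationChain's prompt/memory consistency rule."""
    if input_key in memory_vars:
        raise ValueError(f"memory must not also provide the input key {input_key!r}")
    expected = set(memory_vars) | {input_key}
    found = set(template_variables(template))
    if found != expected:
        raise ValueError(
            f"prompt uses {sorted(found)} but chain supplies {sorted(expected)}"
        )


TEMPLATE = """Entity Information:
{entity_context}
Conversation:
Human: {input}
AI:"""

print(template_variables(TEMPLATE))  # ['entity_context', 'input']
check_conversation_prompt(TEMPLATE, ["entity_context"])  # passes silently
```

Passing `["history"]` as the memory variables instead raises `ValueError`, which is the kind of mismatch to look for if the chain refuses to build with a custom memory class.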
Testing the Implementation
First interaction (no prior entity data):
```python
conversation.predict(input="Sarah enjoys studying artificial intelligence")
```
With `verbose=True`, the printed prompt shows an empty Entity Information section because nothing has been stored yet, so the model responds based on the current input alone.
Second interaction (entity information available):
```python
conversation.predict(
    input="What subject did Sarah enjoy studying most?"
)
```
This time the verbose prompt includes the earlier sentence about Sarah under Entity Information, so the model can answer using the stored context.
Example 2: Recent-K JSON Field Memory
This custom implementation extracts a specific field from JSON-formatted responses and keeps a rolling window of the most recent K turns. It is useful when you need to track particular data points across multiple interactions.
Consider a scenario where LLM outputs follow this JSON structure:
```
{
  "current_step": "step_name",
  "result_info": {
    "step_items": [...],
    "completion_flag": true/false,
    "next_step": "upcoming_step"
  },
  "response_text": "assistant message content"
}
```
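Because the reply is meant to be JSON, the field can be pulled out with the standard `json` module. A minimal sketch, assuming the model returns a single top-level JSON object; `extract_field` is a hypothetical helper. `eval` is not a safe substitute here: JSON's `true`/`false` are not Python names, and evaluating model output executes it as code.

```python
import json
from typing import Optional


def extract_field(raw_output: str, field: str = "response_text") -> Optional[str]:
    """Return one top-level string field from a JSON reply, or None if absent/unparseable."""
    try:
        data = json.loads(raw_output)
    except json.JSONDecodeError:
        return None  # the model did not return valid JSON
    if not isinstance(data, dict):
        return None  # e.g. the model returned a JSON list
    value = data.get(field)
    return value if isinstance(value, str) else None


sample = (
    '{"current_step": "intro", "result_info": {"completion_flag": false}, '
    '"response_text": "Hello!"}'
)
print(extract_field(sample))  # Hello!
```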
The following memory class captures the "response_text" field from recent interactions:
```python
import json
from typing import Any, Dict, List

from langchain.schema import BaseMemory


class ConversationWindowMemory(BaseMemory):
    """Memory that preserves the "response_text" field of the last K JSON replies."""

    conversation_buffer: list = []
    output_key: str = "conversation_history"
    user_label: str = "User"
    assistant_label: str = "Assistant"
    max_turns: int = 5

    def clear(self) -> None:
        self.conversation_buffer = []

    @property
    def memory_variables(self) -> List[str]:
        """Specify which variables to include in prompts."""
        return [self.output_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Extract and format recent response fields."""
        formatted_history = ""
        # Process the most recent K conversation turns
        for interaction in self.conversation_buffer[-self.max_turns:]:
            # ConversationChain stores the reply under "response"; parse it with
            # json.loads, which (unlike eval) accepts true/false and runs no code
            try:
                response_data = json.loads(interaction["output"]["response"])
            except json.JSONDecodeError:
                continue  # skip turns where the model did not return valid JSON
            if isinstance(response_data, dict) and "response_text" in response_data:
                user_line = f"{self.user_label}: " + interaction["input"]["input"]
                assistant_line = f"{self.assistant_label}: " + str(response_data["response_text"])
                formatted_history += "\n" + "\n".join([user_line, assistant_line])
        return {self.output_key: formatted_history}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """Store conversation turns, maintaining window size."""
        turn_record = {
            "input": inputs,
            "output": outputs,
        }
        self.conversation_buffer.append(turn_record)
        # Remove the oldest turn once the buffer exceeds K
        if len(self.conversation_buffer) > self.max_turns:
            self.conversation_buffer.pop(0)
```
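The manual length check plus `pop(0)` is one way to enforce the window. The standard library's `collections.deque` with `maxlen` expresses the same policy directly, since appending to a full deque discards the oldest item. This is a plain-Python sketch of the buffering alone, not a drop-in field for the pydantic-based memory class, which may need extra configuration to hold a deque; the placeholder turn records are illustrative.

```python
from collections import deque
from typing import Any, Deque, Dict

# deque(maxlen=K) drops the oldest entry automatically on append,
# replacing the len() check and pop(0) in save_context.
max_turns = 2
buffer: Deque[Dict[str, Any]] = deque(maxlen=max_turns)

for text in ["Let's begin", "My name is John", "I'm working on a project"]:
    # Placeholder record shaped like the ones save_context stores
    buffer.append({"input": {"input": text}, "output": {"response": "{}"}})

print([turn["input"]["input"] for turn in buffer])
# ['My name is John', "I'm working on a project"]
```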
Configuration and Usage
Create a prompt template suited to your application:
```python
from langchain.prompts.prompt import PromptTemplate

# Literal JSON braces are doubled ({{ }}) so PromptTemplate does not treat
# them as input variables; the reply field is named "response_text" to match
# what ConversationWindowMemory reads.
CUSTOM_TEMPLATE = """
You are a skilled meeting facilitator providing structured guidance.

【Response Format】
{{
  "step_name": "current phase",
  "analysis": {{
    "end_target": "participant's stated goal",
    "reasoning": "feasibility assessment",
    "status": "achieved/not achieved",
    "type": "meeting classification",
    "proceed": "yes/no",
    "next_phase": "upcoming step name"
  }},
  "response_text": "facilitator response content"
}}

【Interaction History】
{conversation_history}

【Current Input】
Participant: {input}
Assistant Response:"""

prompt = PromptTemplate(
    input_variables=["conversation_history", "input"],
    template=CUSTOM_TEMPLATE,
)
```
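PromptTemplate's default `f-string` format parses the template like Python's `str.format`, so a literal `{` or `}` in the response-format example has to be written as `{{` or `}}`; otherwise the JSON is read as a placeholder and template validation or formatting fails. The plain-Python sketch below shows the difference using `str.format` directly; the two short templates are illustrative.

```python
# PromptTemplate's default "f-string" format follows str.format rules:
# single braces mark placeholders, doubled braces produce literal braces.
broken = 'Reply as {"response_text": "..."}\nParticipant: {input}'
escaped = 'Reply as {{"response_text": "..."}}\nParticipant: {input}'

try:
    broken.format(input="Let's begin")
except KeyError as err:
    # The JSON key itself was parsed as a placeholder name
    print("unescaped braces were read as a placeholder:", err)

print(escaped.format(input="Let's begin"))
```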
Initialize with custom parameters:
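```python
from langchain.chat_models import ChatOpenAI

# gpt-3.5-turbo is a chat model, so use the chat wrapper rather than OpenAI
chat = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")
chain = ConversationChain(
    llm=chat,
    prompt=prompt,
    memory=ConversationWindowMemory(
        assistant_label="Facilitator",
        user_label="Participant",
        max_turns=2,
    ),
    verbose=True,
)
```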
Verification of Rolling Window Behavior
Initial interaction stores the first response field:
```python
chain.predict(input="Let's begin")
```
Subsequent interactions accumulate:
```python
chain.predict(input="My name is John")
chain.predict(input="I'm working on a project")
```
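With `max_turns=2`, the third call's prompt still shows the first two turns, because the chain loads memory before the model call and saves the new turn only afterward. Once the third turn is saved, the buffer exceeds the limit and the oldest turn ("Let's begin") is dropped, so later prompts contain only the two most recent turns.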
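The ordering can be checked without an LLM. This sketch assumes the chain loads memory before the model call and saves the turn afterward (the order in which `ConversationChain` invokes `load_memory_variables` and `save_context`), then replays the three inputs above with `max_turns=2`. The placeholder reply stands in for the model and is hypothetical.

```python
from typing import Dict, List

max_turns = 2
buffer: List[Dict[str, str]] = []
history_seen: List[List[str]] = []

for text in ["Let's begin", "My name is John", "I'm working on a project"]:
    # 1. Load step: the prompt sees at most the last K stored turns
    history_seen.append([turn["input"] for turn in buffer[-max_turns:]])
    # 2. The LLM would be called here; a placeholder reply stands in for it
    reply = '{"response_text": "ok"}'
    # 3. Save step: append the new turn, then trim the window to K
    buffer.append({"input": text, "output": reply})
    if len(buffer) > max_turns:
        buffer.pop(0)

for i, seen in enumerate(history_seen, start=1):
    print(f"call {i} prompt includes: {seen}")
print("stored after call 3:", [turn["input"] for turn in buffer])
```

The third call's prompt includes "Let's begin" and "My name is John", while the stored buffer afterward holds only "My name is John" and "I'm working on a project".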
Key Implementation Patterns
Custom memory classes in LangChain must implement the following methods:
- memory_variables: Property returning a list of keys that the memory contributes to the prompt.
- load_memory_variables: Method that retrieves stored information and returns it as a dictionary mapping keys to values.
- save_context: Method that receives input/output pairs and persists them to the memory store.
- clear: Method that resets the memory to its initial state.
By subclassing BaseMemory and implementing these methods, you can create memory systems tailored to any conversation management requirements.
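The four members can be summarized as an abstract interface. In this sketch `MemoryContract` is a hypothetical stand-in built with Python's `abc` module, not LangChain's `BaseMemory`, and `TranscriptMemory` assumes ConversationChain's default keys (`"input"` for the user message, `"response"` for the model reply).

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class MemoryContract(ABC):
    """Hypothetical stand-in for BaseMemory's four required members."""

    @property
    @abstractmethod
    def memory_variables(self) -> List[str]: ...

    @abstractmethod
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: ...

    @abstractmethod
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: ...

    @abstractmethod
    def clear(self) -> None: ...


class TranscriptMemory(MemoryContract):
    """Smallest useful implementation: keep every turn as plain text."""

    def __init__(self, key: str = "history") -> None:
        self.key = key
        self.lines: List[str] = []

    @property
    def memory_variables(self) -> List[str]:
        return [self.key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {self.key: "\n".join(self.lines)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        # Assumes ConversationChain's default "input" and "response" keys
        self.lines.append(f"Human: {inputs['input']}")
        self.lines.append(f"AI: {outputs['response']}")

    def clear(self) -> None:
        self.lines = []


memory = TranscriptMemory()
memory.save_context({"input": "Hi"}, {"response": "Hello!"})
print(memory.load_memory_variables({"input": "next"}))
# {'history': 'Human: Hi\nAI: Hello!'}
```

Leaving out any of the four members makes the subclass abstract, so instantiating it raises `TypeError`; the same pattern is what forces a custom LangChain memory class to define all of them.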