Building Custom Memory Classes in LangChain for Conversation Management

While LangChain provides several built-in memory implementations, you may encounter scenarios where you need a custom memory type tailored to your specific application requirements. This guide demonstrates how to implement a custom memory class and integrate it with LangChain's ConversationChain.

Prerequisites

For this implementation, you'll need the following dependencies:

from langchain import OpenAI, ConversationChain
from langchain.schema import BaseMemory
from pydantic import BaseModel
from typing import List, Dict, Any

Example 1: Entity-Based Memory Using spaCy

This custom memory class uses spaCy for named entity recognition (NER) to extract and store entities in a simple dictionary structure. During conversations, the system identifies entities from input text and includes relevant information in the context window.

Note: This implementation is intentionally simplified for demonstration purposes and may not be suitable for production environments without additional robustness.

First, ensure spaCy and its language model are installed:

# !pip install spacy
# !python -m spacy download en_core_web_sm

import spacy

nlp = spacy.load("en_core_web_sm")

The following class extends BaseMemory to create a custom entity memory system:

class EntityExtractorMemory(BaseMemory, BaseModel):
    """Memory class that stores information about extracted entities."""
    
    entity_store: dict = {}
    context_key: str = "entity_context"
    
    def clear(self) -> None:
        self.entity_store = {}
    
    @property
    def memory_variables(self) -> List[str]:
        """Return variables that should be injected into prompts."""
        return [self.context_key]
    
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Load stored entities for context injection."""
        # Process input text through spaCy pipeline
        input_key = list(inputs.keys())[0]
        doc = nlp(inputs[input_key])
        
        # Retrieve known information for detected entities
        known_entities = [
            self.entity_store[str(entity)] 
            for entity in doc.ents 
            if str(entity) in self.entity_store
        ]
        
        # Return formatted entity information
        return {self.context_key: "\n".join(known_entities)}
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Persist conversation context to memory."""
        input_key = list(inputs.keys())[0]
        text = inputs[input_key]
        doc = nlp(text)
        
        # Store or append information for each detected entity
        for entity in doc.ents:
            entity_text = str(entity)
            if entity_text in self.entity_store:
                self.entity_store[entity_text] += f"\n{text}"
            else:
                self.entity_store[entity_text] = text
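The sketch below exercises the same store-and-lookup logic without spaCy, so the data flow can be checked in isolation. `fake_entities` is a hypothetical stand-in for `doc.ents` that treats capitalized words as entities. Real NER is much more selective, so this only illustrates how `entity_store` grows and how it is queried.

```python
from typing import Dict, List


def fake_entities(text: str) -> List[str]:
    """Hypothetical stand-in for spaCy NER: every capitalized word counts as an entity."""
    return [word.strip(".,?!") for word in text.split() if word[:1].isupper()]


class DictEntityStore:
    """Plain-Python mirror of the entity_store logic in EntityExtractorMemory."""

    def __init__(self) -> None:
        self.entity_store: Dict[str, str] = {}

    def save(self, text: str) -> None:
        # Store the sentence under every entity it mentions, appending if seen before
        for entity in fake_entities(text):
            if entity in self.entity_store:
                self.entity_store[entity] += f"\n{text}"
            else:
                self.entity_store[entity] = text

    def load(self, text: str) -> str:
        # Return stored sentences only for entities that appear in the new input
        known = [self.entity_store[e] for e in fake_entities(text) if e in self.entity_store]
        return "\n".join(known)


store = DictEntityStore()
print(repr(store.load("Sarah enjoys studying artificial intelligence")))  # ''
store.save("Sarah enjoys studying artificial intelligence")
print(store.load("What subject did Sarah enjoy studying most?"))
```

Appending every sentence under each entity keeps the logic simple, but the stored text grows without bound, which is one reason the class above is labeled a demonstration.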

Integrating with ConversationChain

Define a prompt template to incorporate entity context:

from langchain.prompts.prompt import PromptTemplate

TEMPLATE = """You are an informative AI assistant engaged in a friendly conversation. 
When relevant, use the entity information provided below to provide contextually aware responses.

Entity Information:
{entity_context}

Conversation:
Human: {input}
AI:"""

prompt = PromptTemplate(
    input_variables=["entity_context", "input"], 
    template=TEMPLATE
)
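ConversationChain validates, roughly, that the prompt's input variables are exactly the memory's `memory_variables` plus the chain's input key (`"input"` by default). The sketch below reproduces that check in plain Python using `string.Formatter().parse`, so a mismatched placeholder name can be caught early. `prompt_variables` and `check_prompt` are hypothetical helpers, not LangChain functions.

```python
from string import Formatter
from typing import List, Set


def prompt_variables(template: str) -> Set[str]:
    """Hypothetical helper: collect {placeholder} names from a str.format-style template."""
    # parse() yields (literal_text, field_name, format_spec, conversion); field_name is None
    # for pure literal chunks, including escaped {{ and }}
    return {field for _, field, _, _ in Formatter().parse(template) if field}


def check_prompt(template: str, memory_variables: List[str], input_key: str = "input") -> None:
    """Raise if the template's placeholders differ from memory variables + input key."""
    expected = set(memory_variables) | {input_key}
    found = prompt_variables(template)
    if found != expected:
        raise ValueError(f"prompt has {sorted(found)}, expected {sorted(expected)}")


template = "Entity Information:\n{entity_context}\n\nHuman: {input}\nAI:"
check_prompt(template, ["entity_context"])  # matches EntityExtractorMemory, so no error

try:
    check_prompt(template, ["history"])  # wrong memory key
except ValueError as err:
    print(err)
```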

Initialize the conversation chain:

llm = OpenAI(temperature=0)
conversation = ConversationChain(
    llm=llm, 
    prompt=prompt, 
    verbose=True, 
    memory=EntityExtractorMemory()
)

Testing the Implementation

First interaction (no prior entity data):

conversation.predict(input="Sarah enjoys studying artificial intelligence")

The verbose output shows that the Entity Information section is empty on this first turn, so the model responds based only on the current input.

Second interaction (entity information available):

conversation.predict(
    input="What subject did Sarah enjoy studying most?"
)

The response shows that the system correctly retrieves stored information about Sarah and incorporates it into the context window.

Example 2: Recent-K JSON Field Memory

This custom implementation extracts a specific field from JSON responses and keeps a rolling window of the most recent turns. It is useful when you need to track particular data points across multiple interactions.

Consider a scenario where LLM outputs follow this JSON structure:

{
  "current_step": "step_name",
  "result_info": {
    "step_items": [...],
    "completion_flag": true/false,
    "next_step": "upcoming_step"
  },
  "response_text": "assistant message content"
}
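Because this structure is JSON (note the lowercase `true`/`false`, which Python's `eval` rejects with a `NameError`), parsing it with `json.loads` is both safer and more reliable than evaluating it as Python. The sketch below pulls one field out of a raw model reply and returns `None` when the reply is not a JSON object or lacks the field. `extract_field` is a hypothetical helper.

```python
import json
from typing import Any, Optional


def extract_field(raw: str, field: str = "response_text") -> Optional[Any]:
    """Hypothetical helper: return raw[field] if raw is a JSON object containing it, else None."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None  # the model replied with non-JSON text
    if not isinstance(data, dict):
        return None  # valid JSON, but not an object (e.g. a list)
    return data.get(field)


reply = '{"current_step": "intro", "result_info": {"completion_flag": true}, "response_text": "Hello!"}'
print(extract_field(reply))                # Hello!
print(extract_field("Sure, here you go"))  # None
```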

The following memory class captures the "response_text" field from recent interactions:

import json
from typing import Any, Dict, List

from langchain.schema import BaseMemory

class ConversationWindowMemory(BaseMemory):
    """Memory that preserves one JSON response field from the last K turns."""
    
    conversation_buffer: list = []
    output_key: str = "conversation_history"
    user_label: str = "User"
    assistant_label: str = "Assistant"
    max_turns: int = 5
    
    def clear(self) -> None:
        self.conversation_buffer = []
    
    @property
    def memory_variables(self) -> List[str]:
        """Specify which variables to include in prompts."""
        return [self.output_key]
    
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Extract and format recent response fields."""
        history_lines = []
        
        # Process the most recent K conversations
        for interaction in self.conversation_buffer[-self.max_turns:]:
            # Parse the reply as JSON; eval() is unsafe and fails on JSON's true/false
            try:
                response_data = json.loads(interaction["output"]["response"])
            except json.JSONDecodeError:
                continue  # skip turns where the model did not return valid JSON
            
            if isinstance(response_data, dict) and "response_text" in response_data:
                history_lines.append(f"{self.user_label}: {interaction['input']['input']}")
                history_lines.append(f"{self.assistant_label}: {response_data['response_text']}")
        
        return {self.output_key: "\n".join(history_lines)}
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """Store conversation turns, maintaining window size."""
        turn_record = {
            "input": inputs,
            "output": outputs
        }
        
        self.conversation_buffer.append(turn_record)
        
        # Remove oldest turn if buffer exceeds K
        if len(self.conversation_buffer) > self.max_turns:
            self.conversation_buffer.pop(0)
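The trimming in `save_context` keeps at most `max_turns` records. An equivalent design, sketched here outside LangChain, uses `collections.deque(maxlen=k)`, which discards the oldest item automatically on append. `RollingWindow` is a hypothetical plain-Python class for illustration, not a drop-in replacement for the pydantic field above, and the assistant replies in the usage lines are placeholders.

```python
from collections import deque
from typing import Deque, Dict, List


class RollingWindow:
    """Hypothetical plain-Python model of ConversationWindowMemory's buffer."""

    def __init__(self, max_turns: int) -> None:
        # A deque with maxlen drops the oldest entry when a new one is appended
        self.buffer: Deque[Dict[str, str]] = deque(maxlen=max_turns)

    def save(self, user_text: str, assistant_text: str) -> None:
        self.buffer.append({"user": user_text, "assistant": assistant_text})

    def render(self, user_label: str = "User", assistant_label: str = "Assistant") -> str:
        # Format the retained turns the same way load_memory_variables does
        lines: List[str] = []
        for turn in self.buffer:
            lines.append(f"{user_label}: {turn['user']}")
            lines.append(f"{assistant_label}: {turn['assistant']}")
        return "\n".join(lines)


window = RollingWindow(max_turns=2)
window.save("Let's begin", "Welcome!")                      # placeholder reply
window.save("My name is John", "Nice to meet you, John.")   # placeholder reply
window.save("I'm working on a project", "Tell me more.")    # placeholder reply
print(window.render("Participant", "Facilitator"))          # "Let's begin" is gone
```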

Configuration and Usage

Create a prompt template suited to your application:

from langchain.prompts.prompt import PromptTemplate

CUSTOM_TEMPLATE = """
You are a skilled meeting facilitator providing structured guidance.

【Response Format】
{{
  "step_name": "current phase",
  "analysis": {{
    "end_target": "participant's stated goal",
    "reasoning": "feasibility assessment",
    "status": "achieved/not achieved",
    "type": "meeting classification",
    "proceed": "yes/no",
    "next_phase": "upcoming step name"
  }},
  "response_text": "facilitator response content"
}}

【Interaction History】
{conversation_history}

【Current Input】
Participant: {input}
Assistant Response:"""

prompt = PromptTemplate(
    input_variables=["conversation_history", "input"], 
    template=CUSTOM_TEMPLATE
)
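By default, `PromptTemplate` uses Python's `str.format`-style placeholders, so any literal brace in the template, such as the JSON example under 【Response Format】, must be doubled (`{{` and `}}`). Otherwise the JSON is parsed as a placeholder and formatting fails. The snippet below checks this with `str.format` directly, which follows the same brace rules; the shortened template is an illustration, not the full prompt.

```python
# Doubled braces render as literal braces; single braces mark variables
template = """【Response Format】
{{
  "response_text": "facilitator response content"
}}

【Interaction History】
{conversation_history}

Participant: {input}
Assistant Response:"""

rendered = template.format(conversation_history="", input="Let's begin")
print(rendered)

# Unescaped braces are read as a field name, so formatting raises an error
broken = '{\n  "response_text": "..."\n}\n{input}'
try:
    broken.format(input="hi")
except (KeyError, ValueError) as err:
    print("format failed:", type(err).__name__)
```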

Initialize with custom parameters:

from langchain.chat_models import ChatOpenAI

chat = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")  # gpt-3.5-turbo is a chat model
chain = ConversationChain(
    llm=chat,
    prompt=prompt,
    memory=ConversationWindowMemory(
        assistant_label="Facilitator",
        user_label="Participant",
        max_turns=2
    ),
    verbose=True,
)

Verification of Rolling Window Behavior

Initial interaction stores the first response field:

chain.predict(input="Let's begin")

Subsequent interactions accumulate:

chain.predict(input="My name is John")
chain.predict(input="I'm working on a project")

Once the third turn is saved, the buffer exceeds max_turns=2 and the oldest turn ("Let's begin") is dropped, so subsequent prompts include only the two most recent exchanges. This demonstrates the rolling window behavior.

Key Implementation Patterns

Custom memory classes in LangChain must implement the following members:

  • memory_variables: Property returning a list of keys that the memory contributes to the prompt.
  • load_memory_variables: Method that retrieves stored information and returns it as a dictionary mapping keys to values.
  • save_context: Method that receives input/output pairs and persists them to the memory store.
  • clear: Method that resets the memory to its initial state.
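To make this contract concrete without depending on a particular LangChain version, the sketch below states the four members as a `typing.Protocol` and implements them in a tiny memory that remembers only the previous exchange. `MemoryLike` and `LastMessageMemory` are hypothetical names; a real implementation would subclass `BaseMemory` as in the examples above. The `"input"` and `"response"` keys are assumed to be ConversationChain's default input and output keys.

```python
from typing import Any, Dict, List, Protocol


class MemoryLike(Protocol):
    """Hypothetical description of the four members a custom memory provides."""

    @property
    def memory_variables(self) -> List[str]: ...

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: ...

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: ...

    def clear(self) -> None: ...


class LastMessageMemory:
    """Hypothetical memory that keeps only the previous human/AI exchange."""

    def __init__(self, memory_key: str = "last_exchange") -> None:
        self.memory_key = memory_key
        self.last = ""

    @property
    def memory_variables(self) -> List[str]:
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {self.memory_key: self.last}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        # Assumes ConversationChain's default "input" and "response" keys
        self.last = f"Human: {inputs['input']}\nAI: {outputs['response']}"

    def clear(self) -> None:
        self.last = ""


memory: MemoryLike = LastMessageMemory()
memory.save_context({"input": "Hi"}, {"response": "Hello!"})
print(memory.load_memory_variables({"input": "How are you?"}))
```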

By subclassing BaseMemory and implementing these methods, you can create memory systems tailored to any conversation management requirements.

Tags: LangChain python custom-memory conversation-chain spacy

Posted on Wed, 27 May 2026 21:32:54 +0000 by websitesca