Building a Multi-Tier Caching System for LLM API Responses

It was a Monday morning, and my coffee hadn't even kicked in when the alert hit. Our LLM API costs for the previous week had spiked by an alarming 40%, coinciding with a noticeable slowdown in our content generation endpoints. Users were complaining about longer wait times, and my wallet was feeling the pinch as LLM invocations kept multiplying. We were hitting the APIs for every single request, even for prompts that were effectively identical. This wasn't sustainable; I needed a robust caching strategy, fast.

My goal was ambitious: drastically reduce LLM API calls, slash costs, and improve response times without compromising the freshness of our generated content too much. I knew a single-layer cache wouldn't cut it. To truly optimize for both speed and cost, especially in a serverless environment like Cloud Run, a multi-tier approach was essential. I needed something that could provide lightning-fast responses for frequently accessed data, shared state across distributed instances, and a reliable fallback for less common, but still repeatable, queries.

The Caching Imperative: Why LLMs Demand Multi-Tier Solutions

Large Language Models are powerful, but their power comes at a cost – both in terms of monetary expense per token and the inherent latency of complex inference. Every time our application sends a prompt to an LLM provider, we incur both a financial transaction and network overhead, followed by the model's processing time. For applications with repetitive query patterns, this quickly becomes inefficient.

My initial thought was a simple cache. But diving deeper, I realized the nuances of our serverless architecture (Cloud Run instances spinning up and down, multiple instances serving traffic) and the varying nature of LLM responses (some highly static, some needing more frequent updates) demanded something more sophisticated. A single caching layer would either be too slow (if it was a remote cache) or too ephemeral (if it was local). I needed a hierarchy that prioritized speed for hot data, shared data across instances, and offered persistence for long-term cost savings.

Tier 1: The Blazing Fast In-Memory Cache (functools.lru_cache)

The first line of defense, and the fastest possible cache, is an in-memory cache. In our Python-based Cloud Run services, functools.lru_cache was the obvious choice. It's a built-in decorator that memoizes function results, storing them directly in the application's memory. For repeated calls with identical arguments, it returns the cached result almost instantaneously, bypassing network calls and even local computation.

The beauty of lru_cache is its simplicity and efficiency. Since each Cloud Run instance is an isolated process, this cache is local to that specific instance. While that means it isn't shared across instances, it's incredibly effective for requests that hit the same instance multiple times within its lifecycle. It also pairs well with the work I described in Cloud Run Cold Start Optimization: From Seconds to Milliseconds, since frequently accessed functions warm up quickly within an instance.

Implementation (Python)


import functools
import hashlib
import json
from typing import Dict, Any

# Max size for the in-memory cache
LRU_CACHE_MAX_SIZE = 1024

def _generate_cache_key(prompt: str, model_config: Dict[str, Any]) -> str:
    """Generates a consistent hash key for caching based on prompt and model config."""
    # Ensure model_config is hashable by converting to a sorted JSON string
    config_str = json.dumps(model_config, sort_keys=True)
    return hashlib.sha256(f"{prompt}-{config_str}".encode('utf-8')).hexdigest()

# functools.lru_cache can't be populated manually (it exposes no public
# "insert" API), and the multi-tier lookup below needs to write entries
# fetched from Redis/GCS back into the in-memory tier. A small
# OrderedDict-based LRU with the same eviction policy covers both needs.
from collections import OrderedDict
from typing import Optional

_lru_store: "OrderedDict[str, Any]" = OrderedDict()

def get_llm_response_from_lru_cache(cache_key: str) -> Optional[Any]:
    """Returns the cached response for cache_key, or None on a miss."""
    if cache_key in _lru_store:
        _lru_store.move_to_end(cache_key)  # mark as most recently used
        return _lru_store[cache_key]
    return None

def _lru_set(cache_key: str, value: Any) -> None:
    """Inserts a value, evicting the least recently used entry when full."""
    _lru_store[cache_key] = value
    _lru_store.move_to_end(cache_key)
    if len(_lru_store) > LRU_CACHE_MAX_SIZE:
        _lru_store.popitem(last=False)

# Expose the setter so callers can write:
#   get_llm_response_from_lru_cache.cache_set(key, value)
get_llm_response_from_lru_cache.cache_set = _lru_set

# Wrapping the LLM call directly with functools.lru_cache also works.
# Note: lru_cache requires hashable arguments, so the model config is
# passed as a sorted JSON string rather than a dict:
# @functools.lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
# def call_llm_api_cached(prompt: str, config_str: str) -> Dict[str, Any]:
#     # Repeated calls with the same (prompt, config_str) return the cached
#     # result without re-executing the LLM API call below.
#     print(f"INFO: Calling actual LLM API for prompt: {prompt[:50]}...")
#     import time
#     time.sleep(2)  # Simulate LLM API latency
#     return {"generated_text": f"Response for '{prompt}' with config {config_str}",
#             "source": "LLM_API"}

# Example usage (conceptual):
# config_str = json.dumps({"model": "gemini-pro", "temperature": 0.7}, sort_keys=True)
#
# # First call with these args hits the LLM API
# # response_1 = call_llm_api_cached("Write a short poem about a cat.", config_str)
#
# # Second call with the same args is served from the LRU cache
# # response_2 = call_llm_api_cached("Write a short poem about a cat.", config_str)

I set maxsize to 1024, a reasonable number for our typical request volume per instance, balancing memory usage and cache hit potential. For generating cache keys, it's crucial to consider all input parameters that influence the LLM's response, including the prompt, model name, temperature, top_p, etc. I serialize these into a consistent hash to ensure accurate cache hits.
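As a quick sanity check on that property: reordering keys in the config dict must not change the key, while changing any value must. The helper is repeated here so the snippet runs standalone:

```python
import hashlib
import json
from typing import Any, Dict

def _generate_cache_key(prompt: str, model_config: Dict[str, Any]) -> str:
    """Same helper as above, repeated so this snippet runs standalone."""
    config_str = json.dumps(model_config, sort_keys=True)
    return hashlib.sha256(f"{prompt}-{config_str}".encode('utf-8')).hexdigest()

k1 = _generate_cache_key("Write a haiku about rain.", {"model": "gemini-pro", "temperature": 0.7})
k2 = _generate_cache_key("Write a haiku about rain.", {"temperature": 0.7, "model": "gemini-pro"})
k3 = _generate_cache_key("Write a haiku about rain.", {"model": "gemini-pro", "temperature": 0.9})

assert k1 == k2  # dict key order is irrelevant thanks to sort_keys=True
assert k1 != k3  # any change to the config (here, temperature) changes the key
```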

Tier 2: The Shared, Distributed Cache (Redis/Memorystore)

While lru_cache is fast, it's ephemeral and not shared across different Cloud Run instances. This means if two different instances receive the same request, both would hit the LLM API. To achieve a higher overall cache hit rate and share state across our horizontally scaled services, a distributed cache was indispensable. I chose Google Cloud Memorystore for Redis, a fully managed, in-memory data store service.

Redis provides sub-millisecond latency for cache lookups and is perfect for storing LLM responses with a moderate Time-To-Live (TTL). It's shared across all instances of our service, meaning a cache hit on one instance benefits all others. This significantly reduced redundant LLM calls across our entire service, leading to substantial cost savings.

To connect securely and efficiently from Cloud Run, I ensured our Cloud Run service had appropriate VPC Access Connector configured to reach the private IP of the Memorystore instance. This keeps network traffic private and fast.
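For reference, the setup looks roughly like this. The connector name, service name, IP addresses, and region below are hypothetical placeholders, and the connector's range must not overlap anything else in your VPC:

```shell
# Create a Serverless VPC Access connector in the same region as the service
gcloud compute networks vpc-access connectors create llm-cache-connector \
    --region=us-central1 \
    --network=default \
    --range=10.8.0.0/28

# Deploy the Cloud Run service through the connector, pointing it at the
# Memorystore instance's private IP
gcloud run deploy content-service \
    --image=gcr.io/my-project/content-service \
    --vpc-connector=llm-cache-connector \
    --set-env-vars=REDIS_HOST=10.0.0.3,REDIS_PORT=6379
```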

Implementation (Python with redis-py)


import redis
import os
import json
import hashlib
from typing import Dict, Any, Optional

# Redis connection details from environment variables
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
REDIS_DB = int(os.getenv("REDIS_DB", 0))

# Initialize Redis client (using connection pool for efficiency)
_redis_client = None

def get_redis_client():
    global _redis_client
    if _redis_client is None:
        try:
            _redis_client = redis.Redis(
                host=REDIS_HOST, 
                port=REDIS_PORT, 
                db=REDIS_DB, 
                decode_responses=True, # Decodes responses to Python strings
                socket_connect_timeout=5,
                socket_timeout=5 # Read timeout
            )
            _redis_client.ping() # Test connection
            print("INFO: Successfully connected to Redis.")
        except redis.exceptions.ConnectionError as e:
            print(f"ERROR: Could not connect to Redis: {e}")
            _redis_client = None # Reset client if connection failed
    return _redis_client

def get_llm_response_from_redis(cache_key: str) -> Optional[Dict[str, Any]]:
    """Retrieves an LLM response from Redis."""
    client = get_redis_client()
    if not client:
        return None

    try:
        cached_response_str = client.get(cache_key)
        if cached_response_str:
            print(f"DEBUG: Redis cache hit for key: {cache_key}")
            return json.loads(cached_response_str)
        print(f"DEBUG: Redis cache miss for key: {cache_key}")
    except redis.exceptions.RedisError as e:
        print(f"ERROR: Error accessing Redis: {e}")
    return None

def set_llm_response_in_redis(cache_key: str, response: Dict[str, Any], ttl_seconds: int = 3600):
    """Stores an LLM response in Redis with a TTL."""
    client = get_redis_client()
    if not client:
        return

    try:
        client.setex(cache_key, ttl_seconds, json.dumps(response))
        print(f"DEBUG: Stored in Redis cache for key: {cache_key} with TTL {ttl_seconds}s")
    except redis.exceptions.RedisError as e:
        print(f"ERROR: Error writing to Redis: {e}")

# The overall caching function combines LRU and Redis
def get_or_set_llm_response(prompt: str, model_config: Dict[str, Any], llm_api_call_func) -> Dict[str, Any]:
    cache_key = _generate_cache_key(prompt, model_config)

    # 1. Check in-memory LRU cache
    lru_response = get_llm_response_from_lru_cache(cache_key)
    if lru_response:
        print("INFO: Served from LRU cache.")
        return lru_response

    # 2. Check distributed Redis cache
    redis_response = get_llm_response_from_redis(cache_key)
    if redis_response:
        print("INFO: Served from Redis cache.")
        # Populate LRU cache for subsequent local hits
        get_llm_response_from_lru_cache.cache_set(cache_key, redis_response) # Manually set LRU cache
        return redis_response

    # 3. Cache miss: Call LLM API
    print("INFO: Cache miss. Calling LLM API...")
    llm_response = llm_api_call_func(prompt, model_config) # This is your actual LLM call
    llm_response["source"] = "LLM_API"

    # Store in both caches
    set_llm_response_in_redis(cache_key, llm_response, ttl_seconds=3600) # 1 hour TTL
    get_llm_response_from_lru_cache.cache_set(cache_key, llm_response) # Manually set LRU cache
    
    return llm_response

# Note: functools.lru_cache exposes no public method for manually inserting
# entries, so a cache_set hook like the one used above has to come from a
# cache that supports explicit writes (a hand-rolled LRU, or a library such
# as cachetools.LRUCache). The alternative is to wrap the _actual_ LLM call
# with lru_cache and call that wrapped function from get_or_set_llm_response,
# letting the decorator handle population implicitly.

I use decode_responses=True for convenience, so Redis returns Python strings directly. For production, I also implemented robust error handling and connection pooling to ensure resilience against Redis unavailability and efficient resource usage. The TTL (Time-To-Live) for Redis entries is critical. For our typical blog post generation, 1 hour felt like a good balance between freshness and cache hit rates. Some more static content might get a 24-hour TTL.
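To keep those decisions in one place, the TTL policy can be as simple as a lookup table. The content types and durations below are illustrative, not our production config:

```python
from typing import Dict

# Hypothetical content types and durations; the right values depend on how
# quickly each kind of generated content goes stale.
TTL_BY_CONTENT_TYPE: Dict[str, int] = {
    "blog_post": 3600,          # 1 hour: freshness matters
    "meta_description": 86400,  # 24 hours: effectively static once generated
    "product_summary": 86400,
}

DEFAULT_TTL_SECONDS = 3600

def ttl_for(content_type: str) -> int:
    """TTL in seconds for a content type, falling back to the default."""
    return TTL_BY_CONTENT_TYPE.get(content_type, DEFAULT_TTL_SECONDS)

print(ttl_for("meta_description"))  # 86400
print(ttl_for("unknown_type"))      # 3600
```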

Tier 3: The Persistent, Long-Term Cache (Google Cloud Storage)

Even with Redis, some LLM responses might be very expensive, rarely accessed, but still valuable for long-term reference or audit. Also, Redis has a cost associated with its memory and operations. For truly long-lived, less frequently accessed responses, I added a persistent storage layer using Google Cloud Storage (GCS).

GCS acts as an archival cache. It's slower to retrieve from compared to Redis, but significantly cheaper for large volumes of data and offers indefinite storage. I use it for responses that could have a TTL of days, weeks, or even months. This tier is particularly useful for content that, once generated, is unlikely to change but might be requested again after the Redis TTL expires.
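Since GCS objects have no native TTL, expiry at this tier is handled with bucket lifecycle rules rather than in application code. A sketch using gsutil, with an arbitrary 90-day retention as the example:

```shell
# Delete cached objects 90 days after creation (age is an example value)
cat > lifecycle.json <<'EOF'
{
  "rule": [
    {"action": {"type": "Delete"}, "condition": {"age": 90}}
  ]
}
EOF
gsutil lifecycle set lifecycle.json gs://autoblogger-llm-cache
```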

Implementation (Python with google-cloud-storage)


from google.cloud import storage
import os
import json
from typing import Dict, Any, Optional

GCS_BUCKET_NAME = os.getenv("GCS_CACHE_BUCKET", "autoblogger-llm-cache")

_gcs_client = None

def get_gcs_client():
    global _gcs_client
    if _gcs_client is None:
        _gcs_client = storage.Client()
    return _gcs_client

def get_llm_response_from_gcs(cache_key: str) -> Optional[Dict[str, Any]]:
    """Retrieves an LLM response from Google Cloud Storage."""
    client = get_gcs_client()
    bucket = client.bucket(GCS_BUCKET_NAME)
    blob = bucket.blob(f"{cache_key}.json") # Store as JSON file

    try:
        if blob.exists():
            print(f"DEBUG: GCS cache hit for key: {cache_key}")
            content = blob.download_as_text()
            return json.loads(content)
        print(f"DEBUG: GCS cache miss for key: {cache_key}")
    except Exception as e:
        print(f"ERROR: Error accessing GCS for key {cache_key}: {e}")
    return None

def set_llm_response_in_gcs(cache_key: str, response: Dict[str, Any]):
    """Stores an LLM response in Google Cloud Storage."""
    client = get_gcs_client()
    bucket = client.bucket(GCS_BUCKET_NAME)
    blob = bucket.blob(f"{cache_key}.json")

    try:
        blob.upload_from_string(json.dumps(response), content_type="application/json")
        print(f"DEBUG: Stored in GCS cache for key: {cache_key}")
    except Exception as e:
        print(f"ERROR: Error writing to GCS for key {cache_key}: {e}")

# Refined overall caching function with GCS
def get_or_set_llm_response_multi_tier(prompt: str, model_config: Dict[str, Any], llm_api_call_func) -> Dict[str, Any]:
    cache_key = _generate_cache_key(prompt, model_config)

    # 1. Check in-memory LRU cache
    lru_response = get_llm_response_from_lru_cache(cache_key)
    if lru_response:
        print("INFO: Served from LRU cache.")
        return lru_response

    # 2. Check distributed Redis cache
    redis_response = get_llm_response_from_redis(cache_key)
    if redis_response:
        print("INFO: Served from Redis cache.")
        get_llm_response_from_lru_cache.cache_set(cache_key, redis_response)
        return redis_response

    # 3. Check persistent GCS cache
    gcs_response = get_llm_response_from_gcs(cache_key)
    if gcs_response:
        print("INFO: Served from GCS cache.")
        # Populate Redis and LRU for faster subsequent access
        set_llm_response_in_redis(cache_key, gcs_response, ttl_seconds=3600)
        get_llm_response_from_lru_cache.cache_set(cache_key, gcs_response)
        return gcs_response

    # 4. Cache miss: Call LLM API
    print("INFO: Cache miss across all tiers. Calling LLM API...")
    llm_response = llm_api_call_func(prompt, model_config)
    llm_response["source"] = "LLM_API"

    # Store in all caches
    set_llm_response_in_redis(cache_key, llm_response, ttl_seconds=3600)
    set_llm_response_in_gcs(cache_key, llm_response) # GCS has no inherent TTL, managed by object lifecycle rules
    get_llm_response_from_lru_cache.cache_set(cache_key, llm_response)
    
    return llm_response

# Example LLM API call simulation
def simulate_llm_api_call(prompt: str, model_config: Dict[str, Any]) -> Dict[str, Any]:
    import time
    time.sleep(2) # Simulate network and processing delay
    return {"generated_text": f"Generated content for '{prompt[:50]}...' with {model_config['model']}", "tokens_used": 150}

# Example usage:
# prompt_example = "Explain quantum entanglement in simple terms."
# config_example = {"model": "gemini-1.5-pro", "temperature": 0.5}
#
# print("\n--- First Request ---")
# response_tier_1 = get_or_set_llm_response_multi_tier(prompt_example, config_example, simulate_llm_api_call)
# print(f"Response (1): {response_tier_1['generated_text'][:70]}... Source: {response_tier_1['source']}")
#
# print("\n--- Second Request (same instance, LRU hit expected) ---")
# response_tier_2 = get_or_set_llm_response_multi_tier(prompt_example, config_example, simulate_llm_api_call)
# print(f"Response (2): {response_tier_2['generated_text'][:70]}... Source: {response_tier_2['source']}")
#
# # To simulate a new instance or LRU eviction, you'd need to restart/reset.
# # For Redis/GCS, the cache remains.

What I Learned / The Challenge

Building this multi-tier caching system wasn't without its challenges. The primary learning points revolved around:

  1. Cache Invalidation is Hard: Deciding on appropriate TTLs for each tier was a balancing act. Too short, and we lose hit rates; too long, and we risk serving stale data. For critical, rapidly changing information, I opted for shorter TTLs or even explicit invalidation mechanisms (e.g., triggering a Redis DEL command when source data changes). For our content generation, where "freshness" often means "good enough for a few hours," longer TTLs worked well.
  2. Key Generation Consistency: Ensuring that the cache key generation function was deterministic and comprehensive was paramount. Any slight variation in prompt or model configuration (like a different temperature setting) must result in a different key, otherwise, we risk serving incorrect cached responses.
  3. Error Handling and Fallbacks: What happens if Redis is down? Or GCS is experiencing issues? My caching logic includes robust try-except blocks to gracefully handle cache failures and fall back to the next tier or, ultimately, the LLM API. The goal is to improve performance, not to break the application.
  4. Monitoring Cache Hit Rates: Without clear metrics on cache hit rates for each tier, it's impossible to know if the system is effective. I instrumented our code to log hits and misses for LRU, Redis, and GCS, pushing these metrics to Cloud Monitoring. This visibility helps me fine-tune TTLs and identify areas for further optimization.
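For point 4, the instrumentation doesn't need to be fancy. A minimal sketch of per-tier hit/miss counting follows; the names are illustrative, and in production these counts are pushed to Cloud Monitoring rather than held in process memory:

```python
from collections import Counter

_cache_stats: Counter = Counter()

def record_cache_event(tier: str, hit: bool) -> None:
    """Counts a hit or miss for the given tier ('lru', 'redis', 'gcs')."""
    _cache_stats[f"{tier}_{'hit' if hit else 'miss'}"] += 1

def hit_rate(tier: str) -> float:
    """Fraction of lookups at this tier that were hits (0.0 if none yet)."""
    hits = _cache_stats[f"{tier}_hit"]
    misses = _cache_stats[f"{tier}_miss"]
    total = hits + misses
    return hits / total if total else 0.0

record_cache_event("redis", True)
record_cache_event("redis", False)
print(f"redis hit rate: {hit_rate('redis'):.0%}")  # redis hit rate: 50%
```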

Metrics and Impact

The results were transformative:

  • Cost Reduction: Within the first month, our LLM API costs dropped by approximately 65%. For certain high-traffic endpoints with repetitive prompts (like generating meta descriptions for similar product categories), the reduction was even more dramatic, reaching over 80%.
  • Latency Improvement: Average response times for cached requests plummeted from an average of 2.5 seconds (for direct LLM calls) to under 50 milliseconds for Redis hits, and even faster for LRU hits (often <10ms). Even GCS hits, while slower than Redis, still returned in ~200-300ms, a vast improvement over hitting the LLM API.
  • Increased Throughput: By offloading a significant portion of requests from the LLM APIs, our services could handle a much higher volume of concurrent user requests without scaling LLM API usage proportionally. This directly translated to better user experience and reduced operational load on the LLM providers.

This multi-tier approach proved to be a critical architectural decision, allowing us to scale our LLM-powered features aggressively while keeping costs and performance in check. It’s a testament to the power of thoughtful system design.

Moving Forward

While the current multi-tier caching system has yielded significant benefits, my work isn't done. I'm actively exploring the integration of semantic caching, which uses vector embeddings to identify and serve responses for *semantically similar* prompts, not just exact matches. This would unlock even greater cache hit rates, especially for nuanced user queries. Additionally, I'm looking into dynamic TTLs based on content type or historical access patterns, and exploring options for automatic cache warm-up for critical paths. The journey of optimizing LLM interactions is continuous, and I'm excited for the next set of challenges.
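To make the semantic-caching idea concrete, here is a toy sketch of similarity-based lookup over stored prompt embeddings. It is pure Python with an arbitrary 0.92 threshold; a real implementation would use an embedding model and a vector index rather than a linear scan:

```python
import math
from typing import Any, List, Optional, Sequence, Tuple

def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0

def semantic_lookup(
    query_vec: Sequence[float],
    store: List[Tuple[Sequence[float], Any]],
    threshold: float = 0.92,
) -> Optional[Any]:
    """Returns the cached response whose prompt embedding is most similar
    to the query, if the best similarity clears the threshold."""
    best_score, best_response = 0.0, None
    for vec, response in store:
        score = cosine(query_vec, vec)
        if score > best_score:
            best_score, best_response = score, response
    return best_response if best_score >= threshold else None

# An exact-direction match clears the threshold; an ambiguous one doesn't.
store = [([1.0, 0.0], "poem_about_cats"), ([0.0, 1.0], "poem_about_dogs")]
print(semantic_lookup([1.0, 0.0], store))   # poem_about_cats
print(semantic_lookup([0.7, 0.7], store))   # None
```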
