Disk-based caching with asset management in Python
A disk-based caching mechanism in Python that supports storing and retrieving data in native formats such as Parquet, Arrow, NumPy, and JSON, with a Pickle fallback.
The Python DiskCache library is my go-to solution for disk-based caching. It uses a SQLite database to store small data and stores larger data in pickle files on disk. It has expiration and eviction policies and a convenient @memoize decorator for caching function results.
I use Polars and NumPy for data processing, and there are more efficient native formats for storing that data, such as Parquet, Arrow, and NumPy's .npy. Below is an extension to DiskCache that adds support for these formats through a new @asset decorator. The Cache class uses duck typing to determine the type of data returned by the decorated function and stores it in the appropriate format.
import polars as pl
from explore.cache import Cache

cache = Cache(directory="./cache")

@cache.asset()
def filter_weather_stations(df: pl.DataFrame) -> pl.DataFrame:
    return df.filter(pl.col("country_code") == "US")
The @asset decorator automatically detects that the returned data is a Polars DataFrame and stores it in .parquet format. PyArrow tables are stored as .arrow files in the Arrow IPC format, and NumPy arrays are stored in .npy format. If the data is a dictionary or list, it is stored as JSON. Any other data type falls back to Pickle. The Cache class can be extended to support more formats.
This is a complete reimplementation of the DiskCache core Cache class that adds the asset decorator, along with bug fixes and hardening against edge cases such as race conditions, file name collisions, and thread-safety issues. The complete implementation is at the end of this note.
Table of Contents:
- Features
- Basic Usage
- Expiration Examples
- Memoization Decorator
- Asset Handling with Native Formats
- Extensible Asset Handlers
- Clearing Cache by Name or Function
- Storage Backends
- Eviction Policies
- Pickle Fallback
- Advanced Configuration
- Implementation code
Features
- Disk and SQLite Storage: Hybrid storage using SQLite database for metadata and filesystem for large assets
- Expiration Support: Set TTL (time-to-live) for cached items with automatic cleanup
- Memoization: Function result caching with the @memoize decorator
- Asset Handling: Store data in native formats (JSON, Parquet, Arrow, NumPy, etc.)
- Extensible Handlers: Custom asset format handlers with automatic fallback to pickle
- Key Management: Clear specific keys, function caches, or all items
- Eviction Policies: LRU, LFU, least-recently-stored, or none
- Transaction Safety: Thread and process-safe operations with SQLite WAL mode
Basic Usage
Create a cache instance and store/retrieve data:
from explore.cache import Cache
# Create cache with default settings
cache = Cache()
# Store with expiration (3600 seconds = 1 hour)
cache.set("user:123", {"name": "John", "email": "[email protected]"}, expire=3600)
# Retrieve data
user_data = cache.get("user:123")
print(user_data) # {'name': 'John', 'email': 'john@example.com'}
# Check if key exists
if "user:123" in cache:
print("User data is cached")
# Dictionary-style access
cache["session:abc"] = "session_data"
session = cache["session:abc"]
Expiration Examples
import time
# Set with 5 second expiration
cache.set("temp_key", "temporary_value", expire=5)
print(cache.get("temp_key")) # "temporary_value"
time.sleep(6)
print(cache.get("temp_key")) # None (expired)
# Touch to extend expiration
cache.set("extend_key", "value", expire=10)
cache.touch("extend_key", expire=3600) # Extend to 1 hour
Memoization Decorator
Cache expensive function results automatically:
@cache.memoize(expire=3600, typed=True)
def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)
# First call computes and caches
result1 = fibonacci(100)
# Subsequent calls return cached result
result2 = fibonacci(100) # Much faster
# Access cache key for manual operations
key = fibonacci.__cache_key__(100)
cached_value = cache[key]
# Clear specific function's cache
cache.clear(fibonacci)
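Because the decorator uses functools.wraps, the undecorated function stays reachable through __wrapped__, and expire=0 skips writes while lookups still happen. A short sketch (current_time is only illustrative):
# Bypass the cache and call the original function directly
uncached_result = fibonacci.__wrapped__(30)

# expire=0: results are never written, but cache lookups still occur
@cache.memoize(expire=0)
def current_time():
    return time.time()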
Asset Handling with Native Formats
Store data in native formats instead of pickle:
import pandas as pd
import polars as pl
import numpy as np
import pyarrow as pa
# Pandas DataFrame stored as Parquet
@cache.asset(expire=3600)
def get_sales_data():
    return pd.DataFrame({'sales': [100, 200, 300], 'region': ['A', 'B', 'C']})

sales = get_sales_data()  # Stored as .parquet file

# Polars DataFrame stored as Parquet
@cache.asset(expire=3600)
def get_analytics_data():
    return pl.DataFrame({'metric': [1.1, 2.2, 3.3], 'date': ['2023-01', '2023-02', '2023-03']})

analytics = get_analytics_data()

# NumPy array stored as .npy file
@cache.asset(expire=3600)
def get_model_weights():
    return np.random.random((1000, 100))

weights = get_model_weights()

# PyArrow Table stored as Arrow IPC format
@cache.asset(expire=3600, format="arrow")
def get_arrow_table():
    return pa.table({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']})

table = get_arrow_table()

# JSON for dicts/lists (when JSON-serializable)
@cache.asset(expire=3600)
def get_config():
    return {"database": "postgres", "port": 5432, "features": ["auth", "cache"]}

config = get_config()  # Stored as .json file

# Force a specific format
@cache.asset(expire=3600, format="parquet")
def get_mixed_data():
    # Force the Parquet handler regardless of auto-detection
    return pl.DataFrame({'values': [1, 2, 3]})
Extensible Asset Handlers
Create custom handlers for new data formats:
class CSVHandler(AssetHandler):
    format_name = "csv"

    def can_handle(self, data):
        return hasattr(data, 'to_csv')  # Pandas-like objects

    def save(self, data, path):
        data.to_csv(path, index=False)

    def load(self, path):
        import pandas as pd
        return pd.read_csv(path)

    def is_available(self):
        try:
            import pandas
            return True
        except ImportError:
            return False
# Register custom handler
cache.register_asset_handler(CSVHandler())
# Now DataFrames can be stored as CSV
@cache.asset(format="csv")
def get_report():
    return pd.DataFrame({'report': ['Q1', 'Q2'], 'revenue': [100000, 120000]})
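Handler order matters: can_handle is checked in list order, so registering with prepend=True gives a custom handler priority over the built-in ones. A sketch (get_daily_report is just an example function):
# prepend=True puts the handler first in line, so pandas DataFrames
# are stored as CSV even without format="csv"
cache.register_asset_handler(CSVHandler(), prepend=True)

@cache.asset(expire=3600)
def get_daily_report():
    return pd.DataFrame({'day': ['Mon', 'Tue'], 'visits': [120, 95]})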
Clearing Cache by Name or Function
# Clear all cache
cache.clear()
# Clear specific key
cache.clear("user:123")
# Clear specific function's cached results
cache.clear(expensive_function)
# Clear multiple items
cache.clear("key1", "key2", my_function)
# Clear with list
cache.clear(["session:abc", "temp:xyz"])
Storage Backends
# SQLite database for metadata and small values
cache = Cache(
directory="/path/to/cache",
dbname="my_cache.db",
sqlite_timeout=60,
min_file_size=32768 # Store values < 32KB in SQLite, larger as files
)
# Large files stored on disk with appropriate extensions
# Small values stored as BLOBs in SQLite database
# Automatic file/database decision based on size
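Values large enough to be stored as files can be streamed back with read instead of loading the bytes into memory. A sketch (the blob key and payload are made up):
# Anything at or above min_file_size is written to a file on disk
cache.set("blob:archive", b"\x00" * 100_000)

# read() returns an open binary file handle for file-backed values
with cache.read("blob:archive") as handle:
    first_bytes = handle.read(16)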
Eviction Policies
# Least Recently Used (LRU)
cache = Cache(eviction_policy="least-recently-used", size_limit=1024*1024*1024) # 1GB
# Least Frequently Used (LFU)
cache = Cache(eviction_policy="least-frequently-used")
# Least Recently Stored
cache = Cache(eviction_policy="least-recently-stored")
# No eviction
cache = Cache(eviction_policy="none")
Pickle Fallback
When no native format handler matches the data or its dependencies aren’t available, the cache automatically falls back to pickle:
# Custom object without specific handler
class CustomData:
    def __init__(self, value):
        self.value = value

@cache.asset(expire=3600)
def get_custom_object():
    return CustomData("important_data")
# Automatically falls back to pickle storage
obj = get_custom_object()
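The fallback can also be forced with format="pickle", even for data that a native handler would otherwise claim (get_state_snapshot is only illustrative):
# Force pickle even though a plain dict would normally be stored as JSON
@cache.asset(expire=3600, format="pickle")
def get_state_snapshot():
    return {"cursor": 42, "pending": [1, 2, 3]}

snapshot = get_state_snapshot()  # Stored as a .pickle file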
Advanced Configuration
cache = Cache(
directory="/var/cache/myapp",
dbname="cache.db",
eviction_policy="least-recently-used",
size_limit=2**30, # 1GB total cache size
cull_limit=100, # Remove 100 items when culling
min_file_size=2**15, # 32KB threshold for file vs database storage
sqlite_cache_size=8192, # SQLite page cache
sqlite_mmap_size=2**26, # 64MB memory mapping
asset_handlers=[CustomHandler(), JSONHandler(), PickleHandler()]
)
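For multi-statement work against the underlying SQLite database, the transact context manager wraps everything in a single transaction, as its docstring shows; a sketch (the Audit table is just an example):
# Run several statements atomically against the cache's SQLite database
with cache.transact(retry=True) as (con, _):
    con.execute("CREATE TABLE IF NOT EXISTS Audit (id INTEGER PRIMARY KEY, note TEXT)")
    con.execute("INSERT INTO Audit (note) VALUES (?)", ("cache warmed",))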
Implementation code
This implementation uses only the Python standard library for disk-based caching with native format support; the format libraries (Polars, Pandas, NumPy, PyArrow) are imported lazily by their handlers. It includes the Cache class, the AssetHandler base class, and asset handlers for JSON, Parquet, Arrow, NumPy, and Pickle.
import os
import io
import time
import errno
import shutil
import codecs
import pickle
import sqlite3
import hashlib
import tempfile
import threading
import contextlib
import typing as t
import pickletools
from pathlib import Path
from functools import partial, wraps
# Type variables for preserving function signatures
P = t.ParamSpec('P')
T = t.TypeVar('T')
EVICTION_POLICY = {
"none": {"init": None, "get": None, "cull": None},
"least-recently-stored": {
"init": "CREATE INDEX IF NOT EXISTS Cache_store_time ON Cache (store_time)",
"get": None,
"cull": "SELECT {fields} FROM Cache ORDER BY store_time LIMIT ?",
},
"least-recently-used": {
"init": "CREATE INDEX IF NOT EXISTS Cache_access_time ON Cache (access_time)",
"get": "access_time = {now}",
"cull": "SELECT {fields} FROM Cache ORDER BY access_time LIMIT ?",
},
"least-frequently-used": {
"init": "CREATE INDEX IF NOT EXISTS Cache_access_count ON Cache (access_count)",
"get": "access_count = access_count + 1",
"cull": "SELECT {fields} FROM Cache ORDER BY access_count LIMIT ?",
},
}
MODE_NONE = 0
MODE_RAW = 1
MODE_BINARY = 2
MODE_TEXT = 3
MODE_PICKLE = 4
class AssetHandler:
"""Base class for asset format handlers.
Subclasses should override format_name, can_handle, save, and load methods.
"""
format_name: str = "base"
def can_handle(self, data: t.Any) -> bool:
"""Check if this handler can process the given data.
Args:
data: The data to check
Returns:
bool: True if this handler can process the data
"""
raise NotImplementedError
def save(self, data: t.Any, path: str) -> None:
"""Save data to file.
Args:
data: The data to save
path: The file path to save to
"""
raise NotImplementedError
def load(self, path: str) -> t.Any:
"""Load data from file.
Args:
path: The file path to load from
Returns:
The loaded data
"""
raise NotImplementedError
def is_available(self) -> bool:
"""Check if required dependencies are available.
Returns:
bool: True if all dependencies are available
"""
return True
class PickleHandler(AssetHandler):
"""Default pickle handler for any Python object."""
format_name = "pickle"
def can_handle(self, data: t.Any) -> bool:
return True # Can handle anything
def save(self, data: t.Any, path: str) -> None:
with open(path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
def load(self, path: str) -> t.Any:
with open(path, "rb") as f:
return pickle.load(f)
class JSONHandler(AssetHandler):
"""JSON handler for dicts and lists."""
format_name = "json"
def can_handle(self, data: t.Any) -> bool:
if not isinstance(data, (dict, list)):
return False
try:
import json
json.dumps(data)
return True
except (ImportError, TypeError, ValueError):
return False
def save(self, data: t.Any, path: str) -> None:
import json
with open(path, "w", encoding="utf-8") as f:
json.dump(data, f, default=str, sort_keys=False, separators=(",", ": "))
def load(self, path: str) -> t.Any:
import json
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
def is_available(self) -> bool:
try:
import json # noqa: F401
return True
except ImportError:
return False
class PolarsHandler(AssetHandler):
"""Polars DataFrame handler."""
format_name = "parquet"
def can_handle(self, data: t.Any) -> bool:
try:
return (
hasattr(data, "__class__")
and data.__class__.__module__ == "polars.dataframe.frame"
and data.__class__.__name__ == "DataFrame"
)
except Exception:
return False
def save(self, data: t.Any, path: str) -> None:
data.write_parquet(path)
def load(self, path: str) -> t.Any:
import polars as pl
return pl.read_parquet(path)
def is_available(self) -> bool:
try:
import polars # noqa: F401
return True
except ImportError:
return False
class PandasHandler(AssetHandler):
"""Pandas DataFrame handler."""
format_name = "parquet"
def can_handle(self, data: t.Any) -> bool:
try:
return (
hasattr(data, "__class__")
and data.__class__.__module__ == "pandas.core.frame"
and data.__class__.__name__ == "DataFrame"
)
except Exception:
return False
def save(self, data: t.Any, path: str) -> None:
data.to_parquet(path)
def load(self, path: str) -> t.Any:
import pandas as pd
return pd.read_parquet(path)
def is_available(self) -> bool:
try:
import pandas # noqa: F401
return True
except ImportError:
return False
class NumpyHandler(AssetHandler):
"""NumPy array handler."""
format_name = "numpy"
def can_handle(self, data: t.Any) -> bool:
try:
return (
hasattr(data, "__class__")
and data.__class__.__module__ == "numpy"
and data.__class__.__name__ == "ndarray"
)
except Exception:
return False
def save(self, data: t.Any, path: str) -> None:
import numpy as np
# Remove .npy extension if present since np.save adds it automatically
if path.endswith('.npy'):
path = path[:-4]
np.save(path, data)
def load(self, path: str) -> t.Any:
import numpy as np
# np.save automatically adds .npy, so we need to handle both cases
if not path.endswith('.npy') and not os.path.exists(path):
path = path + '.npy'
return np.load(path)
def is_available(self) -> bool:
try:
import numpy # noqa: F401
return True
except ImportError:
return False
class ArrowHandler(AssetHandler):
"""PyArrow Table handler."""
format_name = "arrow"
def can_handle(self, data: t.Any) -> bool:
try:
# Check for PyArrow Table
return (
hasattr(data, "__class__")
and data.__class__.__module__ == "pyarrow.lib"
and data.__class__.__name__ == "Table"
)
except Exception:
return False
def save(self, data: t.Any, path: str) -> None:
import pyarrow as pa
# Save as Arrow IPC format (Arrow file)
with pa.OSFile(path, "wb") as sink:
with pa.ipc.new_file(sink, data.schema) as writer:
writer.write_table(data)
def load(self, path: str) -> t.Any:
import pyarrow as pa
# Read Arrow IPC format file
with pa.memory_map(path) as source:
return pa.ipc.open_file(source).read_all()
def is_available(self) -> bool:
try:
import pyarrow # noqa: F401
return True
except ImportError:
return False
def delete(path: str | Path) -> None:
try:
if os.path.isdir(path):
shutil.rmtree(path, ignore_errors=True)
else:
if os.path.exists(path):
os.remove(path)
except Exception:
pass
class _Constant(tuple):
"""Pretty display of immutable constant."""
def __new__(cls, name):
return tuple.__new__(cls, (name,))
def __repr__(self):
return "%s" % self[0]
ENOVAL = _Constant("ENOVAL")
UNKNOWN = _Constant("UNKNOWN")
def sqlite_execute_with_retry(conn: sqlite3.Connection, statement: str, parameters: t.Iterable = ()) -> sqlite3.Cursor:
"""Execute a SQL statement with retry on database lock.
Re-try the statement if the error is "database is locked" for up to 60 seconds.
Args:
conn: SQLite connection.
statement: SQL statement.
parameters: SQL statement parameters.
Returns:
sqlite3.Cursor: a cursor object
Raises:
sqlite3.OperationalError: if the error is not "database is locked".
TimeoutError: if the database is locked for more than 60 seconds.
"""
start = time.time()
while True:
try:
return conn.execute(statement, parameters) # type: ignore[no-untyped-call]
except sqlite3.OperationalError as exc:
if str(exc) != "database is locked":
raise # re-raise the original exception
diff = time.time() - start
if diff > 60:
raise TimeoutError("SQLite database is locked for more than 60 seconds") from None
time.sleep(0.001)
def full_name(func):
"""Return full name of `func` by adding the module and function name."""
return func.__module__ + "." + func.__qualname__
def args_to_key(base: tuple, args: tuple, kwargs: dict, typed: bool | None, ignore: tuple | None) -> tuple:
"""Create cache key out of function arguments.
Args:
base: base of key
args: function arguments
kwargs: function keyword arguments
typed: include types in cache key
ignore: positional or keyword args to ignore
Returns:
cache key tuple
"""
ignore = ignore or ()
args = tuple(arg for index, arg in enumerate(args) if index not in ignore)
key = base + args + (None,)
if kwargs:
kwargs = {key: val for key, val in kwargs.items() if key not in ignore}
sorted_items = sorted(kwargs.items())
for item in sorted_items:
key += item
if typed:
key += tuple(type(arg) for arg in args)
if kwargs:
key += tuple(type(value) for _, value in sorted_items)
return key
class Cache:
"""Disk and file based caching.
This class provides a cache that stores data on disk using SQLite as the backend.
It supports various eviction policies, file storage, and transaction management.
It can store data in different formats using asset handlers, allowing for efficient
storage and retrieval of data assets such as JSON, Polars DataFrames, Pandas DataFrames,
NumPy arrays, and Python objects using Pickle.
Args:
directory: Directory to store cache files.
sqlite_timeout: SQLite connection timeout in seconds.
dbname: SQLite database name.
pickle_protocol: Pickle protocol version.
eviction_policy: Eviction policy for the cache. Options are "none", "least-recently-stored",
"least-recently-used", "least-frequently-used". Default is "least-recently-stored".
cull_limit: Number of items to cull when the cache is full.
size_limit: Maximum cache size in bytes.
min_file_size: Minimum file size in bytes before storing as file. Otherwise store as blob in database.
sqlite_mmap_size: SQLite mmap size in bytes.
sqlite_cache_size: SQLite cache size in pages.
asset_handlers: List of AssetHandler instances for handling different data formats.
"""
def __init__(
self,
directory: str | None = None,
sqlite_timeout: int = 60,
dbname: str = "cache.db",
pickle_protocol: int = pickle.HIGHEST_PROTOCOL,
eviction_policy: t.Literal[
"none",
"least-recently-stored",
"least-recently-used",
"least-frequently-used",
] = "least-recently-stored",
cull_limit: int = 10,
size_limit: int = 2**30, # 1GB
min_file_size: int = 2**15, # 32KB
sqlite_mmap_size: int = 2**26, # 64MB
sqlite_cache_size: int = 2**13, # 8,192 pages
asset_handlers: t.List[AssetHandler] | None = None,
) -> None:
if directory is None:
directory = tempfile.mkdtemp(prefix="webcache-")
directory = str(directory)
directory = os.path.expanduser(directory)
directory = os.path.expandvars(directory)
self.directory = directory
self.dbname = dbname
self.sqlite_timeout = sqlite_timeout
self._local = threading.local()
self._txn_id = None
self.pickle_protocol = pickle_protocol
self.eviction_policy = eviction_policy
self.cull_limit = cull_limit
self.size_limit = size_limit
self.min_file_size = min_file_size
self.sqlite_mmap_size = sqlite_mmap_size
self.sqlite_cache_size = sqlite_cache_size
# Initialize asset handlers
if asset_handlers is None:
# Default handlers
self.asset_handlers = [
JSONHandler(),
ArrowHandler(),
PolarsHandler(),
PandasHandler(),
NumpyHandler(),
PickleHandler(), # Pickle last as fallback
]
else:
self.asset_handlers = asset_handlers
# Ensure PickleHandler is always available as fallback
if not any(isinstance(h, PickleHandler) for h in self.asset_handlers):
self.asset_handlers.append(PickleHandler())
if not os.path.isdir(directory):
try:
os.makedirs(directory, 0o755)
except OSError as error:
if error.errno != errno.EEXIST:
raise EnvironmentError(
error.errno,
'Cache directory "%s" does not exist and could not be created' % self.directory,
) from None
con = self.connect()
# Set sqlite wal journal mode, auto vacuum, and mmap size, synchronous, cache size
con.execute("PRAGMA journal_mode = WAL")
con.execute("PRAGMA auto_vacuum = FULL")
con.execute("PRAGMA mmap_size = %d" % self.sqlite_mmap_size)
con.execute("PRAGMA synchronous = NORMAL")
con.execute("PRAGMA cache_size = %d" % self.sqlite_cache_size)
((self._page_size,),) = con.execute("PRAGMA page_size").fetchall()
con.execute("CREATE TABLE IF NOT EXISTS Settings (key TEXT NOT NULL UNIQUE, value)")
con.execute("INSERT OR REPLACE INTO Settings VALUES (?, ?)", ("size", 0))
con.execute("""CREATE TABLE IF NOT EXISTS Cache (
rowid INTEGER PRIMARY KEY,
key BLOB,
raw INTEGER,
store_time REAL,
expire_time REAL,
access_time REAL,
access_count INTEGER DEFAULT 0,
size INTEGER DEFAULT 0,
mode INTEGER DEFAULT 0,
filename TEXT,
value BLOB
)""")
con.execute("CREATE UNIQUE INDEX IF NOT EXISTS Cache_key_raw ON Cache(key, raw)")
con.execute("CREATE INDEX IF NOT EXISTS Cache_expire_time ON Cache (expire_time) WHERE expire_time IS NOT NULL")
# Use triggers to keep size metadata up to date
con.execute("""CREATE TRIGGER IF NOT EXISTS Settings_size_insert
AFTER INSERT ON Cache FOR EACH ROW BEGIN
UPDATE Settings SET value = value + NEW.size
WHERE key = "size"; END""")
con.execute("""CREATE TRIGGER IF NOT EXISTS Settings_size_update
AFTER UPDATE ON Cache FOR EACH ROW BEGIN
UPDATE Settings
SET value = value + NEW.size - OLD.size
WHERE key = "size"; END""")
con.execute("""CREATE TRIGGER IF NOT EXISTS Settings_size_delete
AFTER DELETE ON Cache FOR EACH ROW BEGIN
UPDATE Settings SET value = value - OLD.size
WHERE key = "size"; END""")
query = EVICTION_POLICY[self.eviction_policy]["init"]
if query is not None:
con.execute(query)
def connect(self):
local_pid = getattr(self._local, "pid", None)
pid = os.getpid()
if local_pid != pid:
self.close()
self._local.pid = pid
con = getattr(self._local, "con", None)
if con is None:
con = self._local.con = sqlite3.connect(
os.path.join(self.directory, self.dbname),
timeout=self.sqlite_timeout,
isolation_level=None,
)
return con
def close(self):
con: sqlite3.Connection = getattr(self._local, "con", None) # type: ignore[no-untyped-call]
if con is None:
return
con.close()
try:
delattr(self._local, "con")
except AttributeError:
pass
def _disk_remove(self, file_path):
"""Remove a file given by `file_path` with cross-thread and cross-process safety."""
full_path = os.path.join(self.directory, file_path)
full_dir, _ = os.path.split(full_path)
with contextlib.suppress(OSError):
os.remove(full_path)
with contextlib.suppress(OSError):
os.removedirs(full_dir)
@contextlib.contextmanager
def transact(self, retry=False, filename=None) -> t.Iterator[t.Tuple[sqlite3.Connection, t.Callable]]:
"""Transaction context manager locking the cache.
Args:
retry: whether to retry the transaction if it fails.
filename: filename to remove if the transaction fails.
Raises:
TimeoutError: if the transaction cannot acquire the database lock.
Example:
Wrap a block of code in a transaction:
>>> with cache.transact() as (con, _):
... con.execute("CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY, name TEXT)")
... con.execute("INSERT INTO test (name) VALUES (?)", ("Alice",))
"""
con: sqlite3.Connection = self.connect()
filenames = []
_disk_remove = self._disk_remove
tid = threading.get_ident()
txn_id = self._txn_id
if tid == txn_id:
begin = False
else:
while True:
try:
con.execute("BEGIN IMMEDIATE")
begin = True
self._txn_id = tid
break
except sqlite3.OperationalError:
if retry:
continue
if filename is not None:
_disk_remove(filename)
raise TimeoutError from None
try:
yield con, filenames.append
except BaseException:
if begin:
assert self._txn_id == tid
self._txn_id = None
con.execute("ROLLBACK")
raise
else:
if begin:
assert self._txn_id == tid
self._txn_id = None
con.execute("COMMIT")
for name in filenames:
if name is not None:
_disk_remove(name)
def setting(self, key: str, value: t.Any = ENOVAL, update: bool = True) -> t.Any:
"""Get or set a setting in the cache."""
con = self.connect()
if value is ENOVAL:
select = "SELECT value FROM Settings WHERE key = ?"
((value,),) = sqlite_execute_with_retry(con, select, (key,)).fetchall()
return value
if update:
statement = "UPDATE Settings SET value = ? WHERE key = ?"
sqlite_execute_with_retry(con, statement, (value, key))
def volume(self) -> int:
"""Return estimated total size of cache on disk in bytes."""
con = self.connect()
((page_count,),) = con.execute("PRAGMA page_count").fetchall()
total_size = self._page_size * page_count + self.setting("size")
return total_size
def _disk_put(self, key):
"""Convert `key` to fields key and raw for Cache table.
Args:
key: key to be stored in cache.
Returns:
Tuple[sqlite3.Binary, bool]: a tuple of the key and a boolean indicating whether the key is a byte string.
"""
# pylint: disable=unidiomatic-typecheck
type_key = type(key)
if type_key is bytes:
return sqlite3.Binary(key), True
elif (
(type_key is str)
or (type_key is int and -9223372036854775808 <= key <= 9223372036854775807)
or (type_key is float)
):
return key, True
else:
#
data = pickle.dumps(key, protocol=self.pickle_protocol)
result = pickletools.optimize(data)
return sqlite3.Binary(result), False
def touch(self, key: str, expire: float | None = None, retry: bool = False) -> bool:
"""Touch `key` in cache and update `expire` time."""
now = time.time()
db_key, raw = self._disk_put(key)
expire_time = None if expire is None else now + expire
with self.transact(retry) as (con, _):
rows = con.execute(
"SELECT rowid, expire_time FROM Cache WHERE key = ? AND raw = ?",
(db_key, raw),
).fetchall()
if rows:
((rowid, old_expire_time),) = rows
if old_expire_time is None or old_expire_time > now:
con.execute(
"UPDATE Cache SET expire_time = ? WHERE rowid = ?",
(expire_time, rowid),
)
return True
return False
def _disk_fetch(self, mode: int, filename: str, value: t.Any, read: bool):
"""Convert fields `mode`, `filename`, and `value` from Cache table to value.
If mode is MODE_RAW, return value as bytes. If mode is MODE_BINARY and read is true, return value as file handle,
otherwise return value as bytes. If mode is MODE_TEXT, return value as string. If mode is MODE_PICKLE, read value
as pickle and return the result.
Args:
mode: mode of the value. Options are MODE_RAW, MODE_BINARY, MODE_TEXT, MODE_PICKLE.
filename: filename of the value.
value: value to be fetched.
read: whether to read the value.
Returns:
Any: the fetched value as str, bytes, file handle, or any other type if mode is MODE_PICKLE.
"""
if mode == MODE_RAW:
return bytes(value) if type(value) is sqlite3.Binary else value
elif mode == MODE_BINARY:
if read:
return open(os.path.join(self.directory, filename), "rb")
else:
with open(os.path.join(self.directory, filename), "rb") as reader:
return reader.read()
elif mode == MODE_TEXT:
full_path = os.path.join(self.directory, filename)
with open(full_path, "r", encoding="UTF-8") as reader:
return reader.read()
elif mode == MODE_PICKLE:
if value is None:
with open(os.path.join(self.directory, filename), "rb") as reader:
return pickle.load(reader)
else:
return pickle.load(io.BytesIO(value))
def get(
self,
key: str,
default: t.Any = None,
read: bool = False,
return_expire_time: bool = False,
retry: bool = False,
):
"""Get `key` from cache.
Args:
key: key to be retrieved from cache.
default: default value to return if key is not found.
read: return file handle to value.
return_expire_time: also return the expire time along with the value.
retry: whether to retry on database lock.
"""
db_key, raw = self._disk_put(key)
update_column: str = EVICTION_POLICY[self.eviction_policy]["get"]
select = """SELECT rowid, expire_time, mode, filename, value
FROM Cache WHERE key = ? AND raw = ?
AND (expire_time IS NULL OR expire_time > ?)
"""
if return_expire_time:
default = (default, None)
with self.transact(retry) as (con, _):
rows = con.execute(select, (db_key, raw, time.time())).fetchall()
if not rows:
return default
((rowid, db_expire_time, mode, filename, db_value),) = rows
try:
value = self._disk_fetch(mode, filename, db_value, read)
except IOError:
# Key was deleted before we could retrieve result.
return default
if update_column is not None:
now = time.time()
update = "UPDATE Cache SET %s WHERE rowid = ?"
con.execute(update % update_column.format(now=now), (rowid,))
if return_expire_time:
return value, db_expire_time
return value
def __getitem__(self, key):
"""Return corresponding value for `key` from cache."""
value = self.get(key, default=ENOVAL, retry=True)
if value is ENOVAL:
raise KeyError(key)
return value
def read(self, key, retry=False):
"""Return file handle value corresponding to `key` from cache."""
handle = self.get(key, default=ENOVAL, read=True, retry=retry)
if handle is ENOVAL:
raise KeyError(key)
return handle
def __contains__(self, key: str) -> bool:
"""Return `True` if `key` matching item is found in cache."""
con = self.connect()
db_key, raw = self._disk_put(key)
select = "SELECT rowid FROM Cache WHERE key = ? AND raw = ? AND (expire_time IS NULL OR expire_time > ?)"
rows = con.execute(select, (db_key, raw, time.time())).fetchall()
return bool(rows)
def exists(self, key: str) -> bool:
"""Return `True` if `key` matching item is found in cache."""
return key in self
def _disk_filename(self, key: t.Any = UNKNOWN, value: t.Any = UNKNOWN):
"""Return filename and full-path tuple for file storage."""
hex_name = codecs.encode(os.urandom(16), "hex").decode("utf-8")
sub_dir = os.path.join(hex_name[:2], hex_name[2:4])
name = hex_name[4:] + ".val"
filename = os.path.join(sub_dir, name)
full_path = os.path.join(self.directory, filename)
return filename, full_path
def _safe_filename(self, base_filename: str, extension: str, max_length: int = 255) -> str:
"""Create a safe filename that doesn't exceed filesystem limits.
Args:
base_filename: Base filename without extension
extension: File extension (without dot)
max_length: Maximum filename length (default 255 for most filesystems)
Returns:
Safe filename that fits within length limits
"""
full_name = f"{base_filename}.{extension}"
# If filename is already short enough, return as-is
if len(full_name) <= max_length:
return full_name
# Calculate how much we need to truncate
# Reserve space for extension and a hash separator
available_length = max_length - len(extension) - 1 - 8 # 8 chars for hash
if available_length <= 0:
# Extension is too long, use only hash
hash_str = hashlib.md5(base_filename.encode()).hexdigest()[:8]
return f"{hash_str}.{extension}"
# Truncate base and add hash to maintain uniqueness
truncated_base = base_filename[:available_length]
hash_str = hashlib.md5(base_filename.encode()).hexdigest()[:8]
return f"{truncated_base}_{hash_str}.{extension}"
def _disk_store(self, value, read, key=UNKNOWN):
"""Convert `value` to fields size, mode, filename, and value for Cache table.
Args:
value: value to convert
read: True when value is file-like object
key: key for item (default UNKNOWN)
Returns:
(size, mode, filename, value) tuple for Cache table
"""
# pylint: disable=unidiomatic-typecheck
type_value = type(value)
min_file_size = self.min_file_size
if (
(type_value is str and len(value) < min_file_size)
or (type_value is int and -9223372036854775808 <= value <= 9223372036854775807)
or (type_value is float)
):
return 0, MODE_RAW, None, value
elif type_value is bytes:
if len(value) < min_file_size:
return 0, MODE_RAW, None, sqlite3.Binary(value)
else:
filename, full_path = self._disk_filename(key, value)
self._disk_write(full_path, io.BytesIO(value), "xb")
return len(value), MODE_BINARY, filename, None
elif type_value is str:
filename, full_path = self._disk_filename(key, value)
self._disk_write(full_path, io.StringIO(value), "x", "UTF-8")
size = os.path.getsize(full_path)
return size, MODE_TEXT, filename, None
elif read:
reader = partial(value.read, 2**22)
filename, full_path = self._disk_filename(key, value)
iterator = iter(reader, b"")
size = self._disk_write(full_path, iterator, "xb")
return size, MODE_BINARY, filename, None
else:
result = pickle.dumps(value, protocol=self.pickle_protocol)
if len(result) < min_file_size:
return 0, MODE_PICKLE, None, sqlite3.Binary(result)
else:
filename, full_path = self._disk_filename(key, value)
self._disk_write(full_path, io.BytesIO(result), "xb")
return len(result), MODE_PICKLE, filename, None
def _disk_write(self, full_path, iterator, mode, encoding=None):
full_dir, _ = os.path.split(full_path)
for count in range(1, 11):
try:
# Ensure directory exists - use exist_ok to handle race conditions
os.makedirs(full_dir, exist_ok=True)
# Try to open the file - if directory was deleted, this will fail
with open(full_path, mode, encoding=encoding) as writer:
size = 0
for chunk in iterator:
size += len(chunk)
writer.write(chunk)
return size
except (OSError, IOError) as e:
# Handle various filesystem errors including permission issues,
# directory deletion, disk full, etc.
if count == 10:
# Give up after 10 tries
raise OSError(f"Failed to write file after {count} attempts: {e}") from e
# Clean up partial file if it exists
with contextlib.suppress(OSError):
if os.path.exists(full_path):
os.remove(full_path)
# Brief delay before retry to allow for transient conditions to resolve
time.sleep(0.001 * count) # Exponential backoff
continue
def set(self, key, value, expire=None, read=False, retry=False):
"""Set corresponding `value` for `key` in cache
Args:
key: key name
value: value to store
expire: expire time in seconds.
read: whether to read the value.
retry: whether to retry on database lock.
"""
now = time.time()
db_key, raw = self._disk_put(key)
expire_time = None if expire is None else now + expire
size, mode, filename, db_value = self._disk_store(value, read, key=key)
columns = (expire_time, size, mode, filename, db_value)
with self.transact(retry, filename) as (con, cleanup):
rows = con.execute(
"SELECT rowid, filename FROM Cache WHERE key = ? AND raw = ?",
(db_key, raw),
).fetchall()
if rows:
((rowid, old_filename),) = rows
cleanup(old_filename)
self._row_update(rowid, now, columns)
else:
self._row_insert(db_key, raw, now, columns)
self._cull(now, con, cleanup)
return True
def __setitem__(self, key, value):
"""Set corresponding `value` for `key` in cache."""
self.set(key, value, retry=True)
def _row_update(self, rowid, now, columns):
con = self.connect()
expire_time, size, mode, filename, value = columns
con.execute(
"""UPDATE Cache SET
store_time = ?,
expire_time = ?,
access_time = ?,
access_count = ?,
size = ?,
mode = ?,
filename = ?,
value = ?
WHERE rowid = ?
""",
(
now, # store_time
expire_time,
now, # access_time
0, # access_count
size,
mode,
filename,
value,
rowid,
),
)
def _row_insert(self, key, raw, now, columns):
con = self.connect()
expire_time, size, mode, filename, value = columns
con.execute(
"""INSERT INTO Cache(
key, raw, store_time, expire_time, access_time,
access_count, size, mode, filename, value
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
key,
raw,
now, # store_time
expire_time,
now, # access_time
0, # access_count
size,
mode,
filename,
value,
),
)
def _cull(self, now: int | float, con: sqlite3.Connection, cleanup: t.Callable, limit: int | None = None):
cull_limit = self.cull_limit if limit is None else limit
if cull_limit == 0:
return
# Evict expired keys
select_expired_template = (
"SELECT %s FROM Cache WHERE expire_time IS NOT NULL AND expire_time < ? ORDER BY expire_time LIMIT ?"
)
select_expired = select_expired_template % "filename"
rows = con.execute(select_expired, (now, cull_limit)).fetchall()
if rows:
delete_expired = "DELETE FROM Cache WHERE rowid IN (%s)" % (select_expired_template % "rowid")
con.execute(delete_expired, (now, cull_limit))
for (filename,) in rows:
cleanup(filename)
cull_limit -= len(rows)
if cull_limit == 0:
return
# Evict keys by policy
select_policy = EVICTION_POLICY[self.eviction_policy]["cull"]
if select_policy is None or self.volume() < self.size_limit:
return
select_filename = select_policy.format(fields="filename", now=now)
rows = con.execute(select_filename, (cull_limit,)).fetchall()
if rows:
delete = "DELETE FROM Cache WHERE rowid IN (%s)" % (select_policy.format(fields="rowid", now=now))
con.execute(delete, (cull_limit,))
for (filename,) in rows:
cleanup(filename)
def clear(self, *args, retry=False):
"""Remove items from cache.
Args:
*args: Optional arguments to specify what to clear:
- No args: Clear all items (default behavior)
- str: Clear items matching the key name
- callable: Clear items for memoized/asset decorated functions
- list/tuple: Clear items matching multiple keys/functions
retry: Whether to retry on database lock
Returns:
int: Number of items cleared
Examples:
>>> cache = Cache()
>>> cache.clear() # Clear all items
>>> cache.clear("my_key") # Clear specific key
>>> cache.clear(my_func) # Clear memoized function cache
>>> cache.clear("key1", "key2", my_func) # Clear multiple items
"""
if not args:
# Clear all items (original behavior)
select = "SELECT rowid, filename FROM Cache WHERE rowid > ? ORDER BY rowid LIMIT ?"
select_args = [0, 100]
return self._select_delete(select, select_args, retry=retry)
# Build list of keys to clear
keys_to_clear = []
for arg in args:
if isinstance(arg, str):
# Direct key name
keys_to_clear.append(arg)
elif callable(arg):
# Memoized or asset decorated function
if hasattr(arg, "__cache_key__"):
# This is a decorated function - we need to clear all its cached results
# We'll use a pattern match on the function's full name
func_name = full_name(arg)
keys_to_clear.append(func_name)
else:
# Regular function, use its full name
keys_to_clear.append(full_name(arg))
elif isinstance(arg, (list, tuple)):
# Recursively process lists/tuples
for item in arg:
if isinstance(item, str):
keys_to_clear.append(item)
elif callable(item):
if hasattr(item, "__cache_key__"):
keys_to_clear.append(full_name(item))
else:
keys_to_clear.append(full_name(item))
else:
# Convert other types to string keys
keys_to_clear.append(str(arg))
if not keys_to_clear:
return 0
return self._clear_specific_keys(keys_to_clear, retry=retry)
def _clear_specific_keys(self, keys: t.List[str], retry: bool = False) -> int:
"""Clear cache entries for specific keys or key patterns.
Args:
keys: List of keys or function names to clear
retry: Whether to retry on database lock
Returns:
Number of items cleared
"""
total_cleared = 0
for key_pattern in keys:
# Handle exact key matches
db_key, raw = self._disk_put(key_pattern)
# Clear exact matches
select = "SELECT rowid, filename FROM Cache WHERE key = ? AND raw = ?"
try:
with self.transact(retry) as (con, cleanup):
rows = con.execute(select, (db_key, raw)).fetchall()
if rows:
rowids = [str(row[0]) for row in rows]
delete = f"DELETE FROM Cache WHERE rowid IN ({','.join(rowids)})"
con.execute(delete)
for row in rows:
cleanup(row[1]) # Clean up file
total_cleared += len(rows)
except TimeoutError:
pass
# Also clear function-based cache entries (for memoized/asset functions)
# These have keys that start with the function name tuple
try:
with self.transact(retry) as (con, cleanup):
# Get all non-raw cache entries and check them
select = "SELECT rowid, filename, key FROM Cache WHERE raw = 0"
rows = con.execute(select).fetchall()
matching_rows = []
for rowid, filename, key_blob in rows:
try:
# Deserialize the key to check if it matches our function
key_tuple = pickle.loads(bytes(key_blob))
if isinstance(key_tuple, tuple) and len(key_tuple) > 0:
if key_tuple[0] == key_pattern:
matching_rows.append((rowid, filename))
except (pickle.PickleError, TypeError, IndexError):
continue
if matching_rows:
rowids = [str(row[0]) for row in matching_rows]
delete = f"DELETE FROM Cache WHERE rowid IN ({','.join(rowids)})"
con.execute(delete)
for _, filename in matching_rows:
cleanup(filename)
total_cleared += len(matching_rows)
except (TimeoutError, sqlite3.Error):
pass
return total_cleared
def _select_delete(self, select, args, row_index=0, arg_index=0, retry=False):
count = 0
delete = "DELETE FROM Cache WHERE rowid IN (%s)"
try:
while True:
with self.transact(retry) as (con, cleanup):
rows = con.execute(select, args).fetchall()
if not rows:
break
count += len(rows)
con.execute(delete % ",".join(str(row[0]) for row in rows))
for row in rows:
args[arg_index] = row[row_index]
cleanup(row[-1])
except TimeoutError:
raise TimeoutError(count) from None
return count
def cleanup(self, force: bool = False, cache_timeout: int = 3600):
"""Remove all cache files that are older than `cache_timeout` seconds.
Args:
force (bool): If True, delete all cache files regardless of age.
cache_timeout (int): Cache timeout in seconds.
"""
for path in Path(self.directory).iterdir():
if force:
delete(Path(path))
continue
age = time.time() - os.stat(path).st_mtime
if age > cache_timeout:
delete(Path(path))
def memoize(self, name=None, typed=False, expire=None, ignore=()) -> t.Callable[[t.Callable[P, T]], t.Callable[P, T]]:
"""Memoizing cache decorator.
Decorator to wrap callable with memoizing function using cache.
Repeated calls with the same arguments will look up the result in the cache and avoid re-evaluating the function.
If name is set to None (default), the callable name will be determined
automatically.
When expire is set to zero, function results will not be set in the
cache. Cache lookups still occur, however.
If typed is set to True, function arguments of different types will be
cached separately. For example, f(3) and f(3.0) will be treated as
distinct calls with distinct results.
The original underlying function is accessible through the __wrapped__
attribute. This is useful for introspection, for bypassing the cache,
or for rewrapping the function with a different cache.
>>> cache = Cache()
>>> @cache.memoize(expire=1)
... def fibonacci(number):
... if number == 0:
... return 0
... elif number == 1:
... return 1
... else:
... return fibonacci(number - 1) + fibonacci(number - 2)
>>> print(fibonacci(100))
354224848179261915075
An additional `__cache_key__` attribute can be used to generate the
cache key used for the given arguments.
>>> key = fibonacci.__cache_key__(100)
>>> print(cache[key])
354224848179261915075
Remember to call memoize when decorating a callable. If you forget,
then a TypeError will occur. Note the lack of parentheses after
memoize below:
>>> @cache.memoize
... def test():
... pass
Traceback (most recent call last):
...
TypeError: name cannot be callable
Args:
name: name given for callable (default None, automatic)
typed: cache different types separately (default False)
expire: seconds until arguments expire (default None, no expiry)
ignore: positional or keyword args to ignore (default ())
Returns:
callable decorator
"""
if callable(name):
raise TypeError("name cannot be callable")
def decorator(func: t.Callable[P, T]) -> t.Callable[P, T]:
"""Decorator created by memoize() for callable `func`."""
base = (full_name(func),) if name is None else (name,)
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
"""Wrapper for callable to cache arguments and return values."""
key = wrapper.__cache_key__(*args, **kwargs) # type: ignore[no-untyped-call]
result = self.get(key, default=ENOVAL, retry=True)
if result is ENOVAL:
result = func(*args, **kwargs)
if expire is None or expire > 0:
self.set(key, result, expire, retry=True)
return result
return t.cast(T, result)
def __cache_key__(*args, **kwargs):
"""Make key for cache given function arguments."""
return args_to_key(base, args, kwargs, typed, ignore)
wrapper.__cache_key__ = __cache_key__ # type: ignore[no-untyped-call]
return t.cast(t.Callable[P, T], wrapper)
return decorator
def register_asset_handler(self, handler: AssetHandler, prepend: bool = False) -> None:
"""Register a new asset handler.
Args:
handler: The AssetHandler instance to register
prepend: If True, add to beginning of list (higher priority)
"""
if prepend:
self.asset_handlers.insert(0, handler)
else:
# Insert before PickleHandler if it exists
pickle_idx = None
for i, h in enumerate(self.asset_handlers):
if isinstance(h, PickleHandler):
pickle_idx = i
break
if pickle_idx is not None:
self.asset_handlers.insert(pickle_idx, handler)
else:
self.asset_handlers.append(handler)
def _get_asset_handler(self, data: t.Any, format: str | None = None) -> AssetHandler:
"""Get the appropriate asset handler for the data.
Args:
data: The data to find a handler for
format: Optional format hint
Returns:
AssetHandler: The handler that can process this data
"""
# If format is specified, try to find handler with that format name
if format:
for handler in self.asset_handlers:
if handler.format_name == format and handler.is_available():
return handler
# Otherwise, find first handler that can handle the data
for handler in self.asset_handlers:
if handler.is_available() and handler.can_handle(data):
return handler
# This should never happen if PickleHandler is in the list
raise ValueError("No handler found for data type")
def _store_asset(self, key: str, data: t.Any, expire: float | None, format: str | None = None) -> None:
"""Store an asset using the appropriate handler.
Args:
key: Cache key
data: Data to store
expire: Expiration time in seconds
format: Optional format hint
"""
handler = self._get_asset_handler(data, format)
# Generate filename with format extension and data type info to avoid collisions
# Include data type information in the filename to prevent collisions
data_type_info = f"{type(data).__module__}.{type(data).__name__}"
# Create a hash of the data type to keep filename manageable
type_hash = hashlib.md5(data_type_info.encode()).hexdigest()[:8]
filename, full_path = self._disk_filename(key, str(data)[:100]) # Use truncated str for filename
# Use safe filename to ensure it doesn't exceed filesystem limits
# Only apply to the actual filename part, not the directory path
dir_part, name_part = os.path.split(filename)
name_base, _ = name_part.rsplit(".", 1)
safe_name = self._safe_filename(f"{name_base}_{type_hash}", handler.format_name)
filename = os.path.join(dir_part, safe_name)
dir_part, name_part = os.path.split(full_path)
full_path = os.path.join(dir_part, safe_name)
# Save using handler
full_dir, _ = os.path.split(full_path)
os.makedirs(full_dir, exist_ok=True)
try:
handler.save(data, full_path)
# For numpy handler, the actual file created might have .npy extension
actual_path = full_path
if isinstance(handler, NumpyHandler) and not os.path.exists(full_path):
actual_path = full_path + '.npy'
size = os.path.getsize(actual_path)
except Exception as e:
# Clean up on failure - check both possible paths
for path_to_clean in [full_path, full_path + '.npy']:
if os.path.exists(path_to_clean):
os.remove(path_to_clean)
raise e
# Store metadata in database
now = time.time()
db_key, raw = self._disk_put(key)
expire_time = None if expire is None else now + expire
# Update filename to reflect the actual file created (for numpy handler)
if isinstance(handler, NumpyHandler) and not os.path.exists(os.path.join(self.directory, filename)):
filename = filename + '.npy'
with self.transact(retry=True, filename=filename) as (con, cleanup):
# Check if key already exists
rows = con.execute(
"SELECT rowid, filename FROM Cache WHERE key = ? AND raw = ?",
(db_key, raw),
).fetchall()
if rows:
((rowid, old_filename),) = rows
cleanup(old_filename)
# Update existing entry
con.execute(
"""UPDATE Cache SET
store_time = ?, expire_time = ?, access_time = ?, access_count = ?,
size = ?, mode = ?, filename = ?, value = ?
WHERE rowid = ?""",
(now, expire_time, now, 0, size, MODE_BINARY, filename, None, rowid),
)
else:
# Insert new entry
con.execute(
"""INSERT INTO Cache(
key, raw, store_time, expire_time, access_time,
access_count, size, mode, filename, value
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(db_key, raw, now, expire_time, now, 0, size, MODE_BINARY, filename, None),
)
self._cull(now, con, cleanup)
def _load_asset(self, path: str, format: str | None = None) -> t.Any:
"""Load an asset from file.
Args:
path: File path
format: Optional format hint
Returns:
The loaded data
"""
# Check if file exists
if not os.path.exists(path):
raise FileNotFoundError(f"Asset file not found: {path}")
# Determine format from file extension if not provided
if format is None:
_, ext = os.path.splitext(path)
# np.save appends .npy, so map that extension back to the numpy handler
format = "numpy" if ext == ".npy" else ext.lstrip(".")
# Try to load with the specified/detected format handler
format_handler = None
for handler in self.asset_handlers:
if handler.format_name == format and handler.is_available():
format_handler = handler
break
if format_handler:
try:
return format_handler.load(path)
except Exception as e:
# If specific format handler fails, try fallback strategies
if not isinstance(format_handler, PickleHandler):
# Try pickle fallback if the original format wasn't pickle
pickle_handler = None
for handler in self.asset_handlers:
if isinstance(handler, PickleHandler):
pickle_handler = handler
break
if pickle_handler:
try:
return pickle_handler.load(path)
except Exception:
# If pickle also fails, raise the original error
raise ValueError(f"Failed to load asset with {format} format: {e}") from e
# If no fallback or fallback failed, raise original error
raise ValueError(f"Failed to load asset with {format} format: {e}") from e
# No handler found for the format - try pickle as last resort
pickle_handler = None
for handler in self.asset_handlers:
if isinstance(handler, PickleHandler):
pickle_handler = handler
break
if pickle_handler:
try:
return pickle_handler.load(path)
except Exception as e:
raise ValueError(f"No handler found for format '{format}' and pickle fallback failed: {e}") from e
raise ValueError(f"No handler found for format: {format}")
def asset(
self,
name: str | None = None,
typed: bool | None = False,
expire: int | float | None = None,
ignore: tuple = (),
format=None,
) -> t.Callable[[t.Callable[P, T]], t.Callable[P, T]]:
"""Asset caching decorator.
Decorator to wrap callable with asset caching function.
Similar to memoize but stores data in native formats instead of pickle.
Supports various data types through duck typing:
- PyArrow Tables: stored as Arrow IPC files
- Polars DataFrames: stored as Parquet files
- Pandas DataFrames: stored as Parquet files
- NumPy arrays: stored as .npy files
- Dicts/Lists: stored as JSON or pickle based on content
- Other types: fallback to pickle
Args:
name: name given for callable (default None, automatic)
typed: cache different types separately (default False)
expire: seconds until arguments expire (default None, no expiry)
ignore: positional or keyword args to ignore (default ())
format: force specific format (arrow, parquet, numpy, json, pickle)
Returns:
callable decorator
"""
if callable(name):
raise TypeError("name cannot be callable")
def decorator(func: t.Callable[P, T]) -> t.Callable[P, T]:
"""Decorator created by asset() for callable `func`."""
base = (full_name(func),) if name is None else (name,)
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
"""Wrapper for callable to cache asset arguments and return values."""
key = wrapper.__cache_key__(*args, **kwargs) # type: ignore[attr-defined]
# Check if asset exists in cache
con = self.connect()
db_key, raw = self._disk_put(key)
select = """SELECT filename, mode FROM Cache WHERE key = ? AND raw = ?
AND (expire_time IS NULL OR expire_time > ?)"""
rows = con.execute(select, (db_key, raw, time.time())).fetchall()
if rows:
((filename, _),) = rows # mode not needed here, renamed to _
if filename:
# Load from file using appropriate loader
full_path = os.path.join(self.directory, filename)
return t.cast(T, self._load_asset(full_path, format))
# Asset not in cache, compute it
result = func(*args, **kwargs)
if expire is None or expire > 0:
# Store asset in appropriate format
self._store_asset(key, result, expire, format)
return result
def __cache_key__(*args, **kwargs):
"""Make key for cache given function arguments."""
return args_to_key(base, args, kwargs, typed, ignore)
wrapper.__cache_key__ = __cache_key__ # type: ignore[attr-defined]
return t.cast(t.Callable[P, T], wrapper)
return decorator