from collections.abc import MutableMapping
import numpy as np
from typing import Optional, Dict, List, Union, Tuple
from numpy.typing import NDArray
import os
import tempfile
import h5py
from typing import Any
from loguru import logger
from ..convenience import ByteSize
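# [docs]  (appears to be a Sphinx source-link marker left over from HTML extraction; not part of the module code)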
class EyeDataDict(MutableMapping):
    """
    A dictionary that contains ndarrays of equal shape and with the same
    datatype (float), together with one integer mask per array.

    Drops empty entries (None or length-0 ndarrays).

    Keys stored in this dictionary have the following shape:
    (eye)_(variable) where eye can be any string identifier (e.g., "left", "right",
    "mean", "median", "regress", or any other custom identifier) and variable is any
    string identifier (e.g., "x", "y", "pupil", "baseline", "response", or any other
    custom identifier). The part before the first underscore is treated as the
    eye, everything after it as the variable.

    The dictionary can be indexed by a string "eye_variable" or by a tuple
    ("eye", "variable") like data["left","x"] or data["left_x"].

    Note that ``len()`` returns the number of samples per array (``self.length``),
    not the number of keys.
    """

    def __init__(self, *args, **kwargs) -> None:
        self.data: Dict[str, NDArray] = dict()
        self.mask: Dict[str, NDArray] = dict()  # mask for missing/artifactual values
        self.length: int = 0
        self.shape: Optional[Tuple[int, ...]] = None
        self.update(dict(*args, **kwargs))  # use the free update to set keys

    @staticmethod
    def _join_key(key):
        """Turn an ``(eye, variable)`` tuple into ``"eye_variable"``; pass strings through."""
        if isinstance(key, tuple):
            return "_".join(key)
        return key

    def _subset(self, keep) -> 'EyeDataDict':
        """Build a new EyeDataDict from all keys for which ``keep(key)`` is true.

        Data is copied on assignment and the corresponding masks are carried
        over, so masked samples are not silently lost in the subset.
        """
        subset = EyeDataDict()
        for key, values in self.data.items():
            if not keep(key):
                continue
            subset[key] = values
            mask = self.mask.get(key)
            if mask is not None and key in subset.data:
                subset.mask[key] = mask.copy()
        return subset

    def get_available_eyes(self, variable=None):
        """
        Return a list of available eyes (unique, in insertion order).

        Parameters
        ----------
        variable : str, optional
            If specified, return only eyes for this variable.
        """
        suffix = None if variable is None else "_" + variable
        eyes = [
            k.split("_", 1)[0]
            for k in self.data
            if suffix is None or k.endswith(suffix)
        ]
        # dict.fromkeys de-duplicates while keeping a deterministic order
        return list(dict.fromkeys(eyes))

    def get_available_variables(self):
        """
        Return a list of available variables (unique, in insertion order).
        """
        # split only on the first underscore so that variables which contain
        # underscores themselves (e.g. "pupil_baseline") are not truncated
        variables = [k.split("_", 1)[1] for k in self.data]
        return list(dict.fromkeys(variables))

    def get_eye(self, eye: str) -> 'EyeDataDict':
        """
        Return a subset EyeDataDict with all variables for a given eye.

        Parameters
        ----------
        eye : str
            The eye to get data for ('left', 'right', 'mean', etc.)

        Returns
        -------
        EyeDataDict
            A new EyeDataDict containing only data (and masks) for the specified eye

        Examples
        --------
        >>> d = EyeDataDict(left_x=[1,2], left_y=[3,4], right_x=[5,6])
        >>> left_data = d.get_eye('left')
        >>> print(left_data.data.keys())
        dict_keys(['left_x', 'left_y'])
        """
        prefix = eye + "_"
        return self._subset(lambda k: k.startswith(prefix))

    def get_variable(self, variable):
        """
        Return a subset EyeDataDict with all eyes for a given variable.

        Parameters
        ----------
        variable : str
            The variable to get data for ('x', 'y', 'pupil', etc.)

        Returns
        -------
        EyeDataDict
            A new EyeDataDict containing only data (and masks) for that variable
        """
        suffix = "_" + variable
        return self._subset(lambda k: k.endswith(suffix))

    def __contains__(self, key):
        # accept the same ("eye", "variable") tuples that __getitem__ accepts
        if isinstance(key, tuple):
            if not all(isinstance(part, str) for part in key):
                return False
            key = "_".join(key)
        return key in self.data

    def __getitem__(self, key):
        return self.data[self._join_key(key)]

    def __setitem__(self, key: str, value: NDArray) -> None:
        """Store ``value`` as a float array under ``key`` with a fresh all-zero mask.

        Empty values (None or length 0) are silently dropped.

        Raises
        ------
        TypeError, ValueError
            If the key is malformed (see `_validate_key`).
        ValueError
            If the array shape differs from the arrays already stored.
        """
        if value is None or len(value) == 0:
            return
        # Validate and convert *before* touching length/shape, so that a
        # rejected assignment leaves the container unchanged.
        key = self._validate_key(key)
        value = np.array(value).astype(float)
        if self.length > 0 and self.shape is not None:
            if value.shape != self.shape:
                raise ValueError(
                    f"Array must have shape {self.shape}, got {value.shape}"
                )
        else:
            # first array defines the shape all later arrays must match
            self.length = value.shape[0]
            self.shape = value.shape
        self.data[key] = value
        self.mask[key] = np.zeros(self.shape, dtype=int)

    def __delitem__(self, key):
        key = self._join_key(key)
        del self.data[key]
        del self.mask[key]
        if not self.data:
            # nothing left: forget the shape so differently sized data can be added
            self.length = 0
            self.shape = None

    def __iter__(self):
        return iter(self.data)

    def __len__(self):
        return self.length

    def __repr__(self) -> str:
        n_shown = min(5, self.length)
        parts = [
            "EyeDataDict(vars=%i,n=%i,shape=%s): \n"
            % (len(self.data), self.length, str(self.shape))
        ]
        for k, v in self.data.items():
            line = " %s (%s): " % (k, v.dtype)
            line += ", ".join(v.flat[0:n_shown].astype(str).tolist())
            if self.length > 5:
                line += "..."
            parts.append(line + "\n")
        return "".join(parts)

    def _validate_key(self, key: Union[str, Tuple[str, str]]) -> str:
        """Validate and normalize key format.

        Returns the key as an "eye_variable" string.

        Raises
        ------
        TypeError
            If the key is neither a string nor a tuple.
        ValueError
            If a tuple key does not have exactly two elements or the
            resulting string contains no underscore.
        """
        if not isinstance(key, (str, tuple)):
            raise TypeError("Key must be string or tuple")
        if isinstance(key, tuple):
            if len(key) != 2:
                raise ValueError("Tuple key must have exactly 2 elements (eye, variable)")
            key = "_".join(key)
        if "_" not in key:
            raise ValueError("Key must be in format 'eye_variable'")
        return key

    def copy(self) -> 'EyeDataDict':
        """Create a deep copy of the dictionary, including the masks."""
        return self._subset(lambda _key: True)

    def set_mask(self, key: str, mask: NDArray) -> None:
        """Set mask for a specific key.

        Parameters
        ----------
        key : str or tuple
            "eye_variable" string or ("eye", "variable") tuple.
        mask : array-like
            Mask with the same shape as the stored arrays; stored as int.

        Raises
        ------
        KeyError
            If there is no data for ``key``.
        ValueError
            If the mask shape does not match the data shape.
        """
        key = self._join_key(key)
        if key not in self.data:
            raise KeyError(f"No data for key {key}")
        mask = np.asarray(mask)  # also accept plain lists
        if mask.shape != self.shape:
            raise ValueError(f"Mask must have shape {self.shape}")
        self.mask[key] = mask.astype(int)

    def get_mask(self, key: str) -> NDArray:
        """Get mask for a specific key (string or ("eye", "variable") tuple)."""
        return self.mask[self._join_key(key)]

    @property
    def variables(self) -> List[str]:
        """List of all available variables."""
        return self.get_available_variables()

    @property
    def eyes(self) -> List[str]:
        """List of all available eyes."""
        return self.get_available_eyes()

    def get_size(self) -> 'ByteSize':
        """Return the size of the dictionary in bytes.

        Returns
        -------
        ByteSize
            Total size in bytes of all data and mask arrays.
        """
        total_size = sum(arr.nbytes for arr in self.data.values())
        total_size += sum(arr.nbytes for arr in self.mask.values())
        return ByteSize(total_size)
class CachedEyeDataDict(EyeDataDict):
    """EyeDataDict that mirrors every array into an HDF5 file and keeps a
    size-bounded, least-recently-used cache of arrays for fast access.

    NOTE(review): ``self.data`` and ``self.mask`` (inherited from the base
    class) still receive every array in `__setitem__`, so the memory limit
    bounds only the ``_in_memory_*`` cache dictionaries, not the total memory
    held by the object. Changing that would alter what the inherited
    methods see in ``self.data``, so it is left as is here.
    """

    def __init__(self, *args, cache_dir: Optional[str] = None, max_memory_mb: float = 100, **kwargs):
        """Initialize a cached version of EyeDataDict.

        Parameters
        ----------
        cache_dir : str, optional
            Directory to store cache files. If None (or empty), creates a
            temporary directory.
        max_memory_mb : float, optional
            Maximum memory usage in MB. Default is 100MB.
        """
        # The base-class __init__ is intentionally not called: it would run
        # update() -> __setitem__ before the HDF5 file exists.
        self.data: Dict[str, np.ndarray] = {}
        self.mask: Dict[str, np.ndarray] = {}
        self.length: int = 0
        self.shape: Optional[tuple] = None

        # Cache settings
        self._cache_dir = cache_dir or tempfile.mkdtemp()
        if cache_dir:  # an empty string falls back to the temp dir above
            os.makedirs(cache_dir, exist_ok=True)
            logger.info(f"Created cache directory at {cache_dir}")
        self._max_memory_bytes = max_memory_mb * 1024 * 1024
        self._current_memory_bytes = 0

        # Track memory usage and access patterns
        self._in_memory_data: Dict[str, np.ndarray] = {}
        self._in_memory_mask: Dict[str, np.ndarray] = {}
        self._array_sizes: Dict[str, int] = {}
        self._access_counts: Dict[str, int] = {}

        # Initialize single HDF5 file
        self._init_h5_file()

        # Add any initial data
        if args or kwargs:
            self.update(dict(*args, **kwargs))

    def _init_h5_file(self):
        """Initialize single HDF5 file with data and mask groups."""
        if not hasattr(self, '_h5_file'):
            self._h5_path = os.path.join(self._cache_dir, 'eyedata_cache.h5')
            self._h5_file = h5py.File(self._h5_path, 'a')
        # Create groups if they don't exist
        for group in ('data', 'mask'):
            if group not in self._h5_file:
                self._h5_file.create_group(group)

    def _get_array_size(self, arr: np.ndarray) -> int:
        """Calculate size of numpy array in bytes."""
        return arr.nbytes

    def _lookup_key(self, key) -> str:
        """Normalize a key for *reading*.

        A malformed key cannot be stored, so looking it up is reported as
        KeyError (like the base class and the Mapping protocol expect)
        instead of the TypeError/ValueError raised when *setting* it.
        """
        try:
            return self._validate_key(key)
        except (TypeError, ValueError):
            raise KeyError(key) from None

    def _touch(self, key: str) -> None:
        """Mark ``key`` as the most recently used entry."""
        self._access_counts[key] = max(self._access_counts.values(), default=0) + 1

    def _drop_from_memory(self, key: str) -> None:
        """Remove ``key`` from the memory cache and fix up the byte counter."""
        if key not in self._in_memory_data:
            return
        self._current_memory_bytes -= self._array_sizes.pop(key, 0)
        del self._in_memory_data[key]
        self._in_memory_mask.pop(key, None)
        self._access_counts.pop(key, None)

    def _evict_lru(self) -> None:
        """Evict the least recently used entry that is actually cached.

        Only cached keys are candidates; ``_access_counts`` may also hold
        keys that were never cached (e.g. arrays larger than the limit).
        """
        lru_key = min(self._in_memory_data, key=lambda k: self._access_counts.get(k, 0))
        self._drop_from_memory(lru_key)

    def _write_dataset(self, group: str, key: str, array: np.ndarray) -> None:
        """Replace dataset ``key`` inside HDF5 group ``group`` with ``array``."""
        h5_group = self._h5_file[group]
        if key in h5_group:
            del h5_group[key]
        h5_group.create_dataset(key, data=array)

    def _update_cache(self, key: str, data: np.ndarray, mask: np.ndarray):
        """Update memory cache using LRU strategy with size limits."""
        total_size = self._get_array_size(data) + self._get_array_size(mask)
        # Drop a previous entry for this key first, so its size is not
        # counted twice and no stale arrays stay cached.
        self._drop_from_memory(key)
        # If arrays are too large to fit in cache, don't cache them
        if total_size > self._max_memory_bytes:
            return
        # Remove least recently used arrays until we have enough space
        while (self._in_memory_data and
               self._current_memory_bytes + total_size > self._max_memory_bytes):
            self._evict_lru()
        # Add new arrays to cache
        self._in_memory_data[key] = data
        self._in_memory_mask[key] = mask
        self._array_sizes[key] = total_size
        self._current_memory_bytes += total_size
        self._touch(key)

    def __setitem__(self, key: str, value: np.ndarray):
        """Set item in cache and HDF5.

        Empty values (None or length 0) are silently dropped.

        Raises
        ------
        TypeError, ValueError
            If the key is malformed.
        ValueError
            If the array shape differs from the arrays already stored.
        """
        if value is None or len(value) == 0:
            return
        # Validate and convert before touching any state, so that a rejected
        # assignment leaves length/shape untouched.
        key = self._validate_key(key)
        value = np.array(value).astype(float)
        if self.length > 0 and self.shape is not None:
            if value.shape != self.shape:
                raise ValueError(f"Array must have shape {self.shape}, got {value.shape}")
        else:
            self.length = value.shape[0]
            self.shape = value.shape
        mask = np.zeros(self.shape, dtype=int)

        # Store in HDF5
        self._write_dataset('data', key, value)
        self._write_dataset('mask', key, mask)
        self._h5_file.flush()

        # Update memory cache
        self._update_cache(key, value, mask)

        # Update base class data
        self.data[key] = value
        self.mask[key] = mask

        # Initialize access count for new key
        self._touch(key)

    def __getitem__(self, key: str) -> np.ndarray:
        """Get item from cache or disk.

        Raises
        ------
        KeyError
            If the key is unknown (or malformed).
        """
        key = self._lookup_key(key)
        # Update access count for LRU
        if key in self._access_counts:
            self._touch(key)
        # Try to get from memory cache first
        if key in self._in_memory_data:
            return self._in_memory_data[key]
        # If not in memory, try to get from HDF5
        stored = self._h5_file['data']
        if key in stored:
            data = stored[key][:]
            masks = self._h5_file['mask']
            mask = masks[key][:] if key in masks else np.zeros_like(data, dtype=int)
            # Try to cache the data in memory
            self._update_cache(key, data, mask)
            return data
        raise KeyError(key)

    def __delitem__(self, key):
        """Delete ``key`` from the dictionary, the memory cache and the HDF5 file."""
        super().__delitem__(key)  # raises KeyError for unknown keys
        if isinstance(key, tuple):
            key = "_".join(key)
        self._drop_from_memory(key)
        self._access_counts.pop(key, None)
        for group in ('data', 'mask'):
            if key in self._h5_file[group]:
                del self._h5_file[group][key]
        self._h5_file.flush()
        if not self.data:
            self.length = 0
            self.shape = None

    def get_mask(self, key: str) -> np.ndarray:
        """Get mask for a specific key.

        Raises
        ------
        KeyError
            If no mask is stored for the key (or the key is malformed).
        """
        key = self._lookup_key(key)
        # Try memory cache first
        if key in self._in_memory_mask:
            return self._in_memory_mask[key]
        # Load from HDF5
        masks = self._h5_file['mask']
        if key in masks:
            mask = masks[key][:]
            stored = self._h5_file['data']
            if key in stored:
                self._update_cache(key, stored[key][:], mask)
            return mask
        raise KeyError(key)

    def set_mask(self, key: str, mask: np.ndarray):
        """Set mask for a specific key.

        Raises
        ------
        TypeError, ValueError
            If the key is malformed.
        KeyError
            If there is no data for ``key``.
        ValueError
            If the mask shape does not match the data shape.
        """
        key = self._validate_key(key)
        if key not in self.data:
            raise KeyError(f"No data for key {key}")
        mask = np.array(mask, dtype=int)
        if mask.shape != self.shape:
            raise ValueError(f"Mask must have shape {self.shape}")
        # Store in HDF5
        self._write_dataset('mask', key, mask)
        self._h5_file.flush()
        # Keep the base-class view in sync with what is on disk
        self.mask[key] = mask
        # Update memory cache if key is cached
        if key in self._in_memory_data:
            self._update_cache(key, self._in_memory_data[key], mask)

    def set_max_memory(self, max_memory_mb: float):
        """Set maximum memory usage in MB."""
        self._max_memory_bytes = max_memory_mb * 1024 * 1024
        # If new limit is lower, remove excess arrays
        while self._current_memory_bytes > self._max_memory_bytes and self._in_memory_data:
            self._evict_lru()

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get detailed cache statistics.

        Returns
        -------
        dict
            Memory used/limit in MB, number of arrays in memory and on disk,
            and the per-array memory usage in MB.
        """
        mb = 1024 * 1024
        per_array = {
            k: (self._in_memory_data[k].nbytes + self._in_memory_mask[k].nbytes) / mb
            for k in self._in_memory_data
        }
        # Calculate actual memory usage of in-memory arrays
        memory_used = sum(arr.nbytes for arr in self._in_memory_data.values())
        memory_used += sum(arr.nbytes for arr in self._in_memory_mask.values())
        return {
            'memory_used_mb': memory_used / mb,
            'memory_limit_mb': self._max_memory_bytes / mb,
            'arrays_in_memory': len(self._in_memory_data),
            'arrays_on_disk': len(self._h5_file['data']),
            'memory_usage_per_array': per_array,
        }

    def clear_cache(self):
        """Clear all cached data from memory and disk."""
        # Clear memory cache
        self._in_memory_data.clear()
        self._in_memory_mask.clear()
        self._array_sizes.clear()
        self._access_counts.clear()
        self._current_memory_bytes = 0
        # Clear HDF5 file
        if hasattr(self, '_h5_file'):
            del self._h5_file['data']
            del self._h5_file['mask']
            self._h5_file.create_group('data')
            self._h5_file.create_group('mask')
            self._h5_file.flush()
        # Clear base class data
        self.data.clear()
        self.mask.clear()
        self.length = 0
        self.shape = None

    def __del__(self):
        """Clean up HDF5 file."""
        h5_file = getattr(self, '_h5_file', None)
        if h5_file is None:
            return
        try:
            h5_file.close()
        except Exception:
            # Finalizers may run during interpreter shutdown when h5py is
            # already partly torn down; there is nothing useful to do then.
            pass

    def get_size(self) -> ByteSize:
        """Return the size of the dictionary in bytes, split by storage location.

        Returns
        -------
        ByteSize
            Total size in bytes, with cached portion if applicable.
        """
        memory_size = sum(arr.nbytes for arr in self._in_memory_data.values())
        memory_size += sum(arr.nbytes for arr in self._in_memory_mask.values())
        disk_size = 0
        stored = self._h5_file['data']
        masks = self._h5_file['mask']
        for key in stored.keys():
            if key in self._in_memory_data:
                continue  # Only count arrays not in memory
            disk_size += stored[key].nbytes
            if key in masks:
                disk_size += masks[key].nbytes
        return ByteSize({
            'memory': memory_size,
            'disk': disk_size
        })

    def __repr__(self) -> str:
        """Return a string representation of the cached dictionary."""
        mb = 1024 * 1024
        n_shown = min(5, self.length)
        parts = [
            "CachedEyeDataDict(vars=%i,n=%i,shape=%s): \n" % (len(self.data), self.length, str(self.shape)),
            f" Cache dir: {self._cache_dir}\n",
            f" Memory limit: {self._max_memory_bytes / mb:.1f} MB\n",
            f" Current memory: {self._current_memory_bytes / mb:.1f} MB\n",
            f" Arrays in memory: {len(self._in_memory_data)}\n",
            f" Arrays on disk: {len(self._h5_file['data'])}\n",
        ]
        for k, v in self.data.items():
            line = " %s (%s): " % (k, v.dtype)
            line += ", ".join(v.flat[0:n_shown].astype(str).tolist())
            if self.length > 5:
                line += "..."
            parts.append(line + "\n")
        return "".join(parts)