cache.py

  1. """
  2. A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.
  3. For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
  4. TODO In future, one could modify cache to implement, among other things,
  5. - in allocate_cache, if there is not enough memory, wait for memory to be freed by existing tasks up to a given timeout.
  6. -- note: this can be done using mp.Condtion
  7. - allocate cache as one contigous buffer to avoid fragmentation
  8. - quantize cached values using bitsandbytes
  9. - LRU offloading from gpu to ram
  10. """
import contextlib
import ctypes
import multiprocessing as mp
import os
from typing import Dict, Optional, Union

import hivemind
import torch
from hivemind import use_hivemind_log_handler
from hivemind.utils import TensorDescriptor, get_logger

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)

Handle = int


class MemoryCache:
    """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""

    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
        self.device = device
        self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
        # shared counters use lock=False: all updates happen under lock_metadata
        self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
        self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
        self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
        self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
        self.runtime_pid = os.getpid()

        self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
        self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)

    @property
    def current_size_bytes(self) -> int:
        return self._current_size.value

    @current_size_bytes.setter
    def current_size_bytes(self, value: int):
        self._current_size.value = value

    @property
    def handle_counter(self) -> int:
        return self._handle_counter.value

    @handle_counter.setter
    def handle_counter(self, value: int):
        self._handle_counter.value = value

    @contextlib.asynccontextmanager
    async def allocate_cache(self, descr: TensorDescriptor) -> Handle:
        """
        Create a handle that is associated with buffers on a unique device. If the cache is full, raise AllocationFailed.

        :param descr: allocate a tensor of this size, dtype, etc

        :note: This function should be called by connection handlers; it can be called concurrently from multiple processes.
          Furthermore, it can be called concurrently with at most one use_cache call in runtime.
        """
        assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
        assert descr.device is None and descr  # the cache chooses the device itself
        allocated_handle = None
        allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
        try:
            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
                    raise AllocationFailed(
                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
                    )

                allocated_handle = int(self.handle_counter)
                self.current_size_bytes += allocated_size_bytes
                self.handle_counter += 1  # note: this will eventually overflow and it is okay
                self._pending_messages.value += 1
                self._pipe_send.send((allocated_handle, descr))

            yield allocated_handle
        finally:
            if allocated_handle is not None:
                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
                    self._pending_messages.value += 1
                    self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                    self.current_size_bytes -= allocated_size_bytes
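
    # Illustrative note (not in the original source): a connection handler might reserve
    # space for past attention KVs roughly like this, with hypothetical shapes and names:
    #
    #     descr = TensorDescriptor((2, batch_size, seq_len, head_dim), dtype=torch.float16)
    #     async with memory_cache.allocate_cache(descr) as handle:
    #         ...  # forward `handle` to the runtime, which reads the tensor via use_cache
    #
    # The reserved bytes are released when the `async with` block exits, even on error.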

    @contextlib.contextmanager
    def use_cache(self, handle: Handle) -> torch.Tensor:
        """
        Return a tensor that was previously allocated with allocate_cache.

        :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
          However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache.
        """
        assert os.getpid() == self.runtime_pid
        # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here

        with self.lock_metadata:
            if self._allocated_tensors is None:
                self._allocated_tensors = {}

            # read creation/deletion requests from connection handlers
            for i in range(int(self._pending_messages.value)):
                recv_handle, recv_data = self._pipe_recv.recv()
                self._pending_messages.value -= 1
                if isinstance(recv_data, TensorDescriptor):
                    self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
                elif recv_data is None:
                    if recv_handle not in self._allocated_tensors:
                        logger.warning(
                            f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
                        )
                    self._allocated_tensors.pop(recv_handle, None)
                else:
                    logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")

        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
        yield self._allocated_tensors[handle]


class AllocationFailed(Exception):
    pass
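

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module). It
# assumes that a handle obtained via allocate_cache in a ConnectionHandler
# process has been forwarded to the runtime process, which then does:
# ---------------------------------------------------------------------------

def _example_runtime_side(cache: MemoryCache, handle: Handle) -> None:
    # runtime process only: resolve the handle into its cached tensor and update it in place
    with cache.use_cache(handle) as past_kv:
        past_kv.zero_()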