streaming.py

  1. """
  2. Utilities for streaming tensors
  3. """
  4. from __future__ import annotations
  5. from typing import Iterable, Iterator
  6. from hivemind.proto import runtime_pb2
  7. from hivemind.utils.logging import get_logger
  8. logger = get_logger(__name__)
  9. STREAMING_CHUNK_SIZE_BYTES = 2**16


def split_for_streaming(
    serialized_tensor: runtime_pb2.Tensor,
    chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
) -> Iterator[runtime_pb2.Tensor]:
    """Split serialized_tensor into multiple chunks for streaming"""
    buffer = memoryview(serialized_tensor.buffer)
    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
    # The first chunk carries the leading payload bytes plus all tensor metadata
    # (compression, size, dtype, requires_grad) and the total chunk count, so the
    # receiver can reassemble the stream without a separate header message.
    yield runtime_pb2.Tensor(
        compression=serialized_tensor.compression,
        buffer=buffer[:chunk_size_bytes].tobytes(),
        chunks=num_chunks,
        size=serialized_tensor.size,
        dtype=serialized_tensor.dtype,
        requires_grad=serialized_tensor.requires_grad,
    )
    # Remaining chunks carry only their slice of the payload, keeping per-message overhead low.
    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
        yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())
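
# Illustrative chunk-count example (added note, not in the original file): a raw
# 150_000-byte payload with the default 64 KiB chunk size splits into
# ceil(150_000 / 65_536) = 3 messages, and only the first carries metadata:
#
#   raw = runtime_pb2.Tensor(buffer=b"\x00" * 150_000)
#   parts = list(split_for_streaming(raw))
#   assert len(parts) == 3 and parts[0].chunks == 3
#   assert all(not part.dtype for part in parts[1:])  # metadata lives on parts[0] only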


def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
    """Restore the output of split_for_streaming into a single serialized tensor"""
    stream = iter(stream)
    first_chunk = next(stream)
    serialized_tensor = runtime_pb2.Tensor()
    serialized_tensor.CopyFrom(first_chunk)  # copies the metadata along with the first buffer slice
    buffer_chunks = [first_chunk.buffer]
    for tensor_part in stream:
        buffer_chunks.append(tensor_part.buffer)
    serialized_tensor.buffer = b"".join(buffer_chunks)
    return serialized_tensor
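

if __name__ == "__main__":
    # Minimal round-trip sketch (added example, not part of the original module).
    # It assumes torch is installed and that this hivemind version exposes
    # serialize_torch_tensor / deserialize_torch_tensor from hivemind.compression;
    # adjust those imports if your version keeps them elsewhere.
    import torch
    from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor

    tensor = torch.randn(512, 512)
    serialized = serialize_torch_tensor(tensor)

    chunks = list(split_for_streaming(serialized, chunk_size_bytes=STREAMING_CHUNK_SIZE_BYTES))
    restored = combine_from_streaming(chunks)

    assert torch.allclose(deserialize_torch_tensor(restored), tensor)
    logger.info(f"Round-trip OK with {len(chunks)} chunks")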