
wip some more

Your Name 1 year ago
parent commit
49474e5477

+ 11 - 9
src/petals/client/inference_session.py

@@ -4,7 +4,7 @@ import asyncio
 import itertools
 import time
 import uuid
-from typing import AsyncIterator, List, Optional, Tuple
+from typing import AsyncIterator, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
@@ -34,7 +34,7 @@ class _ServerInferenceSession:
         self,
         config: ClientConfig,
         span: RemoteSpanInfo,
-        uid: ModuleUID,
+        span_uids: Sequence[ModuleUID],
         rpc_info: RPCInfo,
         inputs_queue: asyncio.Queue,
         outputs_aiter: AsyncIterator,
@@ -43,8 +43,8 @@ class _ServerInferenceSession:
         **metadata,
     ):
         self.config = config
-        self.span, self.uid, self.rpc_info = span, uid, rpc_info
-        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
+        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
+        self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
         self.session_id = str(uuid.uuid4())
@@ -62,18 +62,19 @@ class _ServerInferenceSession:
         config: ClientConfig,
         p2p: P2P,
         span: RemoteSpanInfo,
-        uid: ModuleUID,
+        span_uids: Sequence[ModuleUID],
         rpc_info: RPCInfo,
         **metadata,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+        # TODO YOZH you don't need rpc info here
         stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
             stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
             config.connect_timeout,
         )
-        return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
+        return cls(config, span, span_uids, rpc_info, inputs_queue, outputs_stream, **metadata)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@@ -142,6 +143,7 @@ class _ServerInferenceSession:
 
         request_metadata["args_structure"] = args_structure
 
+        # TODO YOZH FIX THIS BEFORE THE END OF THIS PR
         # TODO: make possible to use different compression method for different tensors
         server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
         compression = server_side_inference_schema[0].compression
@@ -155,7 +157,7 @@ class _ServerInferenceSession:
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
-                    uid=self.uid,
+                    uid=CHAIN_DELIMITER.join(self.span_uids),
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                         for tensor, proto in zip(input_tensors, inference_schema)
@@ -244,8 +246,8 @@ class InferenceSession:
         server_sessions = []
         try:
             for span in chosen_spans:
-                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
-                metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
+                span_uids = self._sequence_manager.block_uids[span.start : span.end]
+                metadata = self._sequence_manager.get_request_metadata(span.peer_id, "rpc_inference", span_uids)
                 session = RemoteExpertWorker.run_coroutine(
                     _ServerInferenceSession.create(
                         self._sequence_manager.config,

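The core change above is that _ServerInferenceSession keeps the per-block UIDs as a list (span_uids) and only joins them into a chained UID when an ExpertRequest is built, so num_blocks becomes a simple len() instead of counting delimiters. A minimal sketch of the idea, outside the diff; the UID strings below are hypothetical, while CHAIN_DELIMITER and ModuleUID come from petals.data_structures as in the diff:

from typing import Sequence

from petals.data_structures import CHAIN_DELIMITER, ModuleUID


def chained_uid(span_uids: Sequence[ModuleUID]) -> str:
    # The wire format still carries a single uid string, so per-block UIDs
    # are joined with CHAIN_DELIMITER only at request-building time.
    return CHAIN_DELIMITER.join(span_uids)


span_uids = ["model.h.3", "model.h.4"]  # hypothetical block UIDs
num_blocks = len(span_uids)             # replaces uid.count(CHAIN_DELIMITER) + 1
request_uid = chained_uid(span_uids)
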
+ 47 - 35
src/petals/client/remote_forward_backward.py

@@ -2,50 +2,49 @@
 Utility functions that call RPC forward or backward on a single remote server
 """
 import asyncio
-from typing import Iterable, List, Optional, Sequence, Tuple
+from typing import Iterable, List, Sequence, Tuple
 
 import torch
-from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor, PeerID
+from hivemind import PeerID, nested_flatten, serialize_torch_tensor
 from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
-from hivemind.p2p import StubBase
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
+from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils.streaming import split_for_streaming
-from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
-from petals import RemoteSequenceManager
 from petals.client.config import ClientConfig
-from petals.data_structures import ModuleUID, RPCInfo, CHAIN_DELIMITER
+from petals.client.routing import RemoteSequenceManager
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.packaging import pack_args_kwargs
 
 
 async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
         timeout=config.request_timeout,
     )
     return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
 
 async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
         timeout=config.request_timeout,
     )
     return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
 
 
 async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     parts = (
-        runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
         for tensor in serialized_tensors
         for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
     )
@@ -55,10 +54,10 @@ async def _forward_stream(
 
 
 async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     parts = (
-        runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
         for tensor in serialized_tensors
         for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
     )
@@ -81,35 +80,39 @@ async def run_remote_forward(
     """
     merged_uid = CHAIN_DELIMITER.join(span_uids)
     stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)
-    flat_inputs, args_structure = pack_args_kwargs(*args, **kwargs)
     metadata = sequence_manager.get_request_metadata(peer_id, "rpc_forward", span_uids, *args, **kwargs)
-    compressions = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs)
-    if compressions is None:
-        compressions = [runtime_pb2.CompressionType.NONE] * len(flat_inputs)
-    compressions = list(nested_flatten(compressions))
-    assert len(compressions) == len(flat_inputs), f"got {len(flat_inputs)} tensors but {len(compressions)} codecs"
-    inputs = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_inputs)
+    codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs)
+    flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
+    flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)
+    args_structure = metadata.setdefault("args_structure", args_structure)
+    if codecs is None:
+        codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)
+    else:
+        codecs = list(nested_flatten(codecs))
+        assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs"
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
     serialized_tensors = await asyncio.gather(
         *(
             loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)
-            for tensor, compression in zip(inputs, compressions)
+            for tensor, compression in zip(flat_tensors, codecs)
         )
     )
 
     # call RPC on remote server
-    size = sum(t.element_size() * t.nelement() for t in inputs)
+    size = sum(t.element_size() * t.nelement() for t in flat_tensors)
     forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
     # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR
-    return await forward_fn(merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=metadata)
+    return await forward_fn(
+        merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata)
+    )
 
 
 async def run_remote_backward(
     sequence_manager: RemoteSequenceManager,
+    peer_id: PeerID,
     span_uids: Sequence[ModuleUID],
-    stub: StubBase,
     grad_outputs: Sequence[torch.Tensor],
     *args: torch.Tensor,
     **kwargs: torch.Tensor,
@@ -119,23 +122,32 @@ async def run_remote_backward(
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
-    flat_tensors, args_structure = pack_args_kwargs(
-        [grad.cpu() for grad in grad_outputs], args, kwargs
-    )
-    metadata = sequence_manager.get_request_metadata(
-        "rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
-    )
+    merged_uid = CHAIN_DELIMITER.join(span_uids)
+    stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)
+    metadata = sequence_manager.get_request_metadata(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs)
+    codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs)
+    flat_tensors, args_structure = pack_args_kwargs(grad_outputs, *args, **kwargs)
+    flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)
+    args_structure = metadata.setdefault("args_structure", args_structure)
+
+    if codecs is None:
+        codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)
+    else:
+        codecs = list(nested_flatten(codecs))
+        assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs"
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
     serialized_tensors = await asyncio.gather(
         *(
-            loop.run_in_executor(None, serialize_torch_tensor, compression)
-            for tensor, proto in zip(flat_inputs_and_grad_outputs, backward_schema)
+            loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)
+            for tensor, compression in zip(flat_tensors, codecs)
         )
     )
 
-    size = sum(t.element_size() * t.nelement() for t in flat_inputs_and_grad_outputs)
+    size = sum(t.element_size() * t.nelement() for t in flat_tensors)
     backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
     # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
-    return await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata)
+    return await backward_fn(
+        merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata)
+    )

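After this change, run_remote_forward and run_remote_backward share the same preamble: request metadata and optional compression codecs come from the sequence manager, missing codecs fall back to uncompressed serialization, the unary-vs-streaming RPC is picked by payload size, and the metadata dict is MessagePack-encoded before being attached to the request. A condensed sketch of that shared logic; resolve_codecs and pick_rpc_fn are illustrative helper names, not part of the codebase:

from hivemind import nested_flatten
from hivemind.p2p.p2p_daemon_bindings.control import MAX_UNARY_PAYLOAD_SIZE
from hivemind.proto import runtime_pb2


def resolve_codecs(codecs, flat_tensors):
    # No codecs from the sequence manager -> send every tensor uncompressed;
    # otherwise the flattened codec list must match the flattened tensors 1:1.
    if codecs is None:
        return [runtime_pb2.CompressionType.NONE] * len(flat_tensors)
    codecs = list(nested_flatten(codecs))
    assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} codecs"
    return codecs


def pick_rpc_fn(flat_tensors, unary_fn, stream_fn):
    # Same heuristic as the diff: stream when the payload would not fit in one
    # unary message ("// 2" because hivemind==1.1.5 serializes bfloat16 as float32).
    size = sum(t.element_size() * t.nelement() for t in flat_tensors)
    return stream_fn if size > MAX_UNARY_PAYLOAD_SIZE // 2 else unary_fn
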
+ 5 - 2
src/petals/client/routing/sequence_manager.py

@@ -474,7 +474,9 @@ class RemoteSequenceManager:
             return 0
         return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
 
-    def get_request_metadata(self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs) -> Optional[Dict[str, Any]]:
+    def get_request_metadata(
+        self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
+    ) -> Optional[Dict[str, Any]]:
         """
         :param peer_id: remote server's PeerID
         :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
@@ -488,7 +490,8 @@ class RemoteSequenceManager:
         )
 
     def get_compression_codecs(
-            self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]:
+        self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
+    ) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]:
         """
         :param peer_id: remote server's PeerID
         :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"

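The reformatted signatures above also spell out the call convention used throughout this commit: peer_id first, then the protocol name, then the span's block UIDs. A hypothetical call site, assuming a RemoteSequenceManager instance and a chosen span are already in scope:

# `sequence_manager` and `span` are assumed to exist; this mirrors the updated
# call sites in inference_session.py and remote_forward_backward.py.
span_uids = sequence_manager.block_uids[span.start : span.end]
metadata = sequence_manager.get_request_metadata(span.peer_id, "rpc_inference", span_uids)
codecs = sequence_manager.get_compression_codecs(span.peer_id, "rpc_inference", span_uids)
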
+ 20 - 27
src/petals/client/sequential_autograd.py

@@ -4,19 +4,16 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 import asyncio
 import itertools
 from collections import deque
-from typing import List, Optional, Sequence, Tuple, Dict, Any
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import torch
-from hivemind import MSGPackSerializer
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.utils.logging import get_logger
 
 from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
 from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
-from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
-from petals.server.handler import TransformerConnectionHandler
+from petals.data_structures import RemoteSpanInfo
 from petals.utils.misc import DUMMY, is_dummy
-from petals.utils.packaging import pack_args_kwargs
 
 logger = get_logger(__name__)
 
@@ -24,12 +21,12 @@ MAX_TOKENS_IN_BATCH = 1024
 
 
 async def sequential_forward(
+    sequence_manager: RemoteSequenceManager,
     inputs: torch.Tensor,
     prompts: torch.Tensor,
-    sequence_manager: RemoteSequenceManager,
     start_index: int = 0,
     end_index: Optional[int] = None,
-    block_kwargs: Sequence[Dict[str, Any]] = (),
+    *block_kwargs: Dict[str, Any],
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
@@ -45,13 +42,6 @@ async def sequential_forward(
     :param block_kwargs: optional per-block keyword arguments. Must be a sequence with one dictionary for each block
     """
 
-    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
-    assert len(block_kwargs) in (
-        0,
-        1,
-        end_index - start_index,
-    ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
-
     inputs_device = inputs.device
     inputs_dtype = inputs.dtype
     inputs = inputs.cpu()
@@ -59,6 +49,9 @@ async def sequential_forward(
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+    assert len(block_kwargs) in (0, 1, end_index - start_index), \
+        f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
+    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
     assert is_dummy(prompts) or len(prompts) == len(
         sequence_manager.block_uids
     )  # should be n_layers - 1 but add extra prompts for convenience
@@ -87,13 +80,13 @@ async def sequential_forward(
                     sequence_manager.block_uids[span.start : span.end],
                     inputs,
                     prompts[span.start : span.end],
-                    *block_kwargs[span.start : span.end]
+                    *block_kwargs[span.start : span.end],
                 )
 
                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
 
-                # Save intermediate inputs and subsequ_peerences if the forward is already done for them
+                # Save intermediate inputs and subsequences if the forward is already done for them
                 intermediate_inputs.append(inputs)
                 done_sequences.append(span)
 
@@ -118,11 +111,12 @@ async def sequential_forward(
 
 
 async def sequential_backward(
+    sequence_manager: RemoteSequenceManager,
+    forward_sequences: List[RemoteSpanInfo],
     grad_outputs: Sequence[torch.Tensor],
     intermediate_inputs: List[torch.Tensor],
     prompts: torch.Tensor,
-    forward_sequences: List[RemoteSpanInfo],
-    sequence_manager: RemoteSequenceManager,
+    *block_kwargs: Dict[str, Any],
 ) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
@@ -148,7 +142,7 @@ async def sequential_backward(
             try:
                 if attempt_no >= 1:
                     _, backup_inputs, backup_sequences = await sequential_forward(
-                        inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                        sequence_manager, inputs, prompts, start_index=span.start, end_index=span.end
                     )
                     assert len(backup_inputs) == len(backup_sequences)
                     assert backup_sequences[0].start == span.start
@@ -159,14 +153,13 @@ async def sequential_backward(
                     inputs = intermediate_inputs.pop()
                     span = forward_sequences.pop()
 
-
-                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
-                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
                     sequence_manager,
-                    sequence_manager.block_uids[span.start: span.end],
-                    span_uids,
-                    grad_outputs, inputs,
+                    span.peer_id,
+                    sequence_manager.block_uids[span.start : span.end],
+                    grad_outputs,
+                    *inputs,
+                    *block_kwargs[span.start : span.end],
                 )
                 grad_outputs = [grad_outputs]
                 grad_prompts_reversed.extend(span_grad_prompts)
@@ -198,7 +191,7 @@ async def _gather_forward(input_batches, prompt_batches, sequence_manager):
     """Wrapper for asyncio.gather to perform parallel sequential forwards"""
     return await asyncio.gather(
         *[
-            sequential_forward(input_batch, prompt_batch, sequence_manager)
+            sequential_forward(sequence_manager, input_batch, prompt_batch)
             for input_batch, prompt_batch in zip(input_batches, prompt_batches)
         ]
     )
@@ -210,7 +203,7 @@ async def _gather_backward(
     """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
-            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
+            sequential_backward(sequence_manager, spans, (grad_output,), input_batch, prompt_batch)
             for grad_output, input_batch, prompt_batch, spans in zip(
                 grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
             )
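
With this commit, sequential_forward and sequential_backward take the sequence manager as the first argument and accept optional per-block kwargs as trailing positional arguments. A sketch of the new call convention, assuming an already-initialized RemoteSequenceManager; the roundtrip helper itself is hypothetical:

import torch

from petals.client.sequential_autograd import sequential_backward, sequential_forward
from petals.utils.misc import DUMMY


async def roundtrip(sequence_manager, inputs: torch.Tensor):
    # inputs must be a 3D tensor (batch, seq_len, hidden_size); DUMMY prompts
    # mean "no deep prompts". Per-block kwargs would follow the positional
    # arguments (none are passed in this sketch).
    outputs, intermediates, spans = await sequential_forward(sequence_manager, inputs, DUMMY)

    # Backward mirrors the convention: manager first, then the forward spans.
    grad_outputs = (torch.ones_like(outputs),)
    grad_inputs, grad_prompts = await sequential_backward(
        sequence_manager, spans, grad_outputs, intermediates, DUMMY
    )
    return grad_inputs, grad_prompts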