justheuristic committed 2 years ago
parent commit 4ab44ebd0f
5 changed files with 130 additions and 132 deletions
  1. src/client/__init__.py (+1 -1)
  2. src/client/remote_forward_backward.py (+121 -42)
  3. src/client/sequence_manager.py (+3 -1)
  4. src/client/sequential_autograd.py (+4 -87)
  5. src/client/spending_policy.py (+1 -1)

+ 1 - 1
src/client/__init__.py

@@ -2,4 +2,4 @@ from src.client.inference_session import RemoteSequentialInferenceSession, Remot
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
-from src.client.spending_policy import DummySpendingPolicy, SpendingPolicyBase
+from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 121 - 42
src/client/remote_forward_backward.py

@@ -1,55 +1,77 @@
 """
 """
 Utility functions that call RPC forward or backward on a single remote server
 Utility functions that call RPC forward or backward on a single remote server
 """
 """
-from typing import Iterable, List, Sequence
+import asyncio
+from typing import Iterable, List, Sequence, Tuple, Optional
 
 
 import torch
 import torch
+from hivemind import nested_compare, nested_flatten, serialize_torch_tensor, nested_pack
 from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
 from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
+from hivemind.p2p import StubBase
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
 from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
 from hivemind.utils.streaming import split_for_streaming
 
 
-
-async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
-
-    grad_inputs = await stub.rpc_backward_stream(
-        amap_in_executor(
-            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
-            iter_as_aiter(split),
-        ),
+from src.data_structures import ModuleUID, RPCInfo
+
+
+async def run_remote_forward(
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b'', **kwargs
+) -> Tuple[torch.Tensor, ...]:
+    """
+    Serializes input tensors and calls "rpc_forward" on a remote server.
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+    """
+
+    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+    # detach to avoid pickling the computation graph
+    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
+    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
+
+    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
+    forward_inputs = (inputs, kwargs)
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: rm this assert when support arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    if not nested_compare(forward_inputs, forward_schema_with_prompts):
+        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
+    forward_inputs = nested_flatten(forward_inputs)
+    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
+
+    # Asynchronous serialization
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
+        )
     )
-    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
-    return await deserialize_tensor_stream(tensors_stream)
 
+    # call RPC on remote server
+    size = sum(t.element_size() * t.nelement() for t in inputs)
+    if size > MAX_UNARY_PAYLOAD_SIZE:
+        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
+    else:
+        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)
 
-async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
-    )
-    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+    return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
 
 
-async def remote_backward(
-    uid: str, inputs_and_grads: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+async def _forward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
 ) -> List[torch.Tensor]:
-    """Call rpc_backward (unary or stream) on a single remote server, return grads w.r.t. arguments"""
-    size = 0
-    for t in inputs_and_grads:
-        size += t.element_size() * t.nelement()
-        if size > MAX_UNARY_PAYLOAD_SIZE:
-            return await _backward_stream(uid, serialized_tensors, stub)
-    else:
-        return await _backward_unary(uid, serialized_tensors, stub)
-
-
-async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
     split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
 
     outputs = await stub.rpc_forward_stream(
         amap_in_executor(
-            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
             iter_as_aiter(split),
         ),
     )
@@ -58,21 +80,78 @@ async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Ten
     return await deserialize_tensor_stream(tensors_stream)
 
 
-async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+async def _forward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
     outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
     )
     return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
 
-async def remote_forward(
-    uid: str, inputs: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+
+async def _backward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
 ) -> List[torch.Tensor]:
-    """Call rpc_forward (unary or stream) on a single remote server, return block outputs"""
-    size = 0
-    for t in inputs:
-        size += t.element_size() * t.nelement()
-        if size > MAX_UNARY_PAYLOAD_SIZE:
-            return await _forward_stream(uid, serialized_tensors, stub)
+    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
+
+    grad_inputs = await stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
+            iter_as_aiter(split),
+        ),
+    )
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def run_remote_backward(
+    uid: ModuleUID,
+    stub: StubBase,
+    rpc_info: RPCInfo,
+    inputs: torch.Tensor,
+    grad_outputs: List[torch.Tensor],
+    *extra_tensors: torch.Tensor,
+    metadata: bytes = b'',
+    **kwargs
+) -> Sequence[torch.Tensor]:
+    """
+    Serializes grad outputs and calls "rpc_backward" on a remote server.
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+    """
+
+    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
+    # TODO generalize this
+    prompts_schema = next(iter(args_schema))
+    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
+
+    # Asynchronous serialization
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        )
+    )
+
+    size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
+    if size > MAX_UNARY_PAYLOAD_SIZE:
+        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs)
     else:
-        return await _forward_unary(uid, serialized_tensors, stub)
+        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)
+
+    return deserialized_grad_inputs
+
+
+async def _backward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
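The subtle part of the new run_remote_forward is the schema change: the server's forward_schema declares a single positional tensor, but the client now sends two (hidden states plus per-span prompts), so the one-entry args schema is tiled to match before nested_compare validates the inputs. A plain-tuple sketch of that step, with the string "TensorProto" as a hypothetical stand-in for hivemind's per-tensor descriptors:

    # Plain-tuple sketch of forward_schema_with_prompts; "TensorProto" is a
    # hypothetical stand-in for hivemind's per-tensor compression descriptors.
    args_schema = ("TensorProto",)           # server declares one positional input
    kwargs_schema = {}
    inputs = ("hidden_states", "prompts")    # client now sends two tensors
    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
    assert forward_schema_with_prompts == (("TensorProto", "TensorProto"), {})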

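Both helpers also keep the transport rule that used to live in remote_forward / remote_backward: after serialization, pick a unary or streaming RPC based on total payload size. A self-contained sketch of that rule, assuming an illustrative 4 MiB limit (hivemind's MAX_UNARY_PAYLOAD_SIZE is the authoritative value):

    # Sketch of the size-based unary-vs-stream dispatch; the threshold below is
    # illustrative only, the real limit is hivemind's MAX_UNARY_PAYLOAD_SIZE.
    import torch

    DEMO_MAX_UNARY_PAYLOAD_SIZE = 4 * 1024 ** 2  # assumed 4 MiB for the demo

    def demo_pick_transport(tensors) -> str:
        size = sum(t.element_size() * t.nelement() for t in tensors)
        return "stream" if size > DEMO_MAX_UNARY_PAYLOAD_SIZE else "unary"

    print(demo_pick_transport([torch.zeros(16, 1024)]))    # unary: 64 KiB of float32
    print(demo_pick_transport([torch.zeros(2048, 1024)]))  # stream: 8 MiB of float32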
+ 3 - 1
src/client/sequence_manager.py

@@ -9,6 +9,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+from src import NoSpendingPolicy
 from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
 from src.dht_utils import get_remote_module_infos
 from src.server.handler import TransformerConnectionHandler
@@ -30,6 +31,7 @@ class RemoteSequenceManager:
         self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
         self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
         self.last_update_time: DHTExpiration = -float("inf")
+        self.spending_policy = NoSpendingPolicy()
         self.max_retries = max_retries
         self._rpc_info = None
         self.lock_changes = threading.Lock()
@@ -39,7 +41,7 @@ class RemoteSequenceManager:
             assert info is not None, f"Found no remote peers for block {uid}"
         assert self.spans_by_priority and self.spans_containing_block
 
-    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]:
+    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
         """
         """
         Form a sequence of remote servers that collectively serve all consecutive layers
         Form a sequence of remote servers that collectively serve all consecutive layers
 
 

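The make_sequence annotation tightens from Sequence to List because callers mutate the returned object: sequential_forward in the next file pops spans off the front as it consumes them. A typing-only sketch of why the wider annotation was misleading:

    # Minimal sketch: typing.Sequence does not promise .pop(), so List is the
    # honest annotation for a result that sequential_forward consumes in place.
    from typing import List

    def consume(spans: List[str]) -> None:
        while spans:
            span = spans.pop(0)  # mutation a bare Sequence would not allow
            print("processing", span)

    consume(["block0:block3", "block3:block7"])  # toy span labels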
+ 4 - 87
src/client/sequential_autograd.py

@@ -6,100 +6,17 @@ import logging
 from typing import List, Optional, Sequence, Tuple
 
 import torch
-from hivemind import serialize_torch_tensor
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import StubBase
-from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
 
-from src.client.remote_forward_backward import remote_backward, remote_forward
+from src.client.remote_forward_backward import run_remote_forward, run_remote_backward
 from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from src.server.handler import TransformerConnectionHandler
 from src.utils.misc import DUMMY, is_dummy
 
 MAX_TOKENS_IN_BATCH = 1024
 
 
-async def run_remote_forward(
-    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
-) -> Tuple[torch.Tensor, ...]:
-    """
-    Serializes input tensors and calls "rpc_forward" on a remote server.
-    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
-    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
-    """
-
-    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
-    # detach to avoid pickling the computation graph
-    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
-    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
-
-    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
-    forward_inputs = (inputs, kwargs)
-
-    # Modify forward_schema to support prompts
-    args_schema, kwargs_schema = rpc_info["forward_schema"]
-    # TODO: rm this assert when support arbitrary number of input tensors
-    assert len(args_schema) == 1 and len(inputs) == 2
-    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
-
-    if not nested_compare(forward_inputs, forward_schema_with_prompts):
-        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
-
-    forward_inputs = nested_flatten(forward_inputs)
-    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
-
-    # Asynchronous serialization
-    loop = asyncio.get_running_loop()
-    serialized_tensors = await asyncio.gather(
-        *(
-            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
-        )
-    )
-
-    deserialized_outputs = await remote_forward(uid, inputs, serialized_tensors, stub)
-    flat_outputs = tuple(deserialized_outputs)
-    return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
-
-
-async def run_remote_backward(
-    uid: ModuleUID,
-    stub: StubBase,
-    rpc_info: RPCInfo,
-    inputs: torch.Tensor,
-    grad_outputs: List[torch.Tensor],
-    *extra_tensors: torch.Tensor,
-) -> Sequence[torch.Tensor]:
-    """
-    Serializes grad outputs and calls "rpc_backward" on a remote server.
-    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
-    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
-    """
-
-    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
-
-    # Modify forward_schema to support prompts
-    args_schema, kwargs_schema = rpc_info["forward_schema"]
-    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
-    # TODO generalize this
-    prompts_schema = next(iter(args_schema))
-    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
-
-    # Asynchronous serialization
-    loop = asyncio.get_running_loop()
-    serialized_tensors = await asyncio.gather(
-        *(
-            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        )
-    )
-
-    deserialized_grad_inputs = await remote_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
-    return deserialized_grad_inputs
-
-
 async def sequential_forward(
     inputs: torch.Tensor,
     prompts: torch.Tensor,
@@ -127,9 +44,9 @@ async def sequential_forward(
 
     while len(sequences) > 0:
         while True:
+            span = sequences.pop(0)
+            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             try:
-                span = sequences.pop(0)
-                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
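Moving the pop and the uid join above the try block guarantees both names are bound before the failure handler runs. Abridged, the loop now reads roughly as follows (only lines visible in this diff; the run_remote_forward call and error handling are elided with comments):

    # Abridged sketch of sequential_forward's inner loop after this commit;
    # everything elided with "..." is outside the visible diff.
    span = sequences.pop(0)
    span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
    try:
        stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
        inputs_and_prompts = [inputs, prompts[span.start : span.end]]
        ...  # run_remote_forward(span_uids, stub, ..., *inputs_and_prompts) and bookkeeping
    except Exception:
        ...  # the loop retries with another candidate span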

+ 1 - 1
src/client/spending_policy.py

@@ -9,6 +9,6 @@ class SpendingPolicyBase(ABC):
         pass
 
 
-class DummySpendingPolicy(SpendingPolicyBase):
+class NoSpendingPolicy(SpendingPolicyBase):
     def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
         return 0.0
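NoSpendingPolicy is the neutral default that RemoteSequenceManager now instantiates: every request bids zero points. A real policy would subclass the same base; for example, a hypothetical size-proportional policy (illustrative only, not part of this commit, and assuming runtime_pb2.Tensor's buffer field holds the serialized payload):

    # Hypothetical example of a non-trivial policy; not part of this commit.
    from hivemind.proto.runtime_pb2 import ExpertRequest

    from src.client.spending_policy import SpendingPolicyBase


    class SizeProportionalSpendingPolicy(SpendingPolicyBase):
        """Bid one point per MiB of serialized tensor payload (illustrative)."""

        def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
            payload_bytes = sum(len(tensor.buffer) for tensor in request.tensors)
            return payload_bytes / 2**20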