
fix autograd backward & black & isort

dbaranchuk 3 years ago
Parent
Commit
648a5d6e1d
3 changed files with 36 additions and 17 deletions
  1. +23 -12
      src/client/sequential_autograd.py
  2. +1 -1
      src/server/backend.py
  3. +12 -4
      src/server/server.py

+ 23 - 12
src/client/sequential_autograd.py

@@ -3,8 +3,8 @@ import logging
 from typing import List, Optional, Sequence, Tuple
 
 import torch
-from hivemind import serialize_torch_tensor
-from hivemind.moe.client.expert import expert_backward, expert_forward, _forward_stream
+from hivemind import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.moe.client.expert import _forward_stream, expert_backward, expert_forward
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
@@ -39,11 +39,14 @@ async def run_expert_forward(
     forward_inputs = nested_flatten(forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
 
+    # Asynchronous serialization
     loop = asyncio.get_running_loop()
-    serialized_tensors = await asyncio.gather(*(
-        loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-        for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
-    ))
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+        )
+    )
 
     deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
     flat_outputs = tuple(deserialized_outputs)
@@ -67,12 +70,15 @@ async def run_expert_backward(
     inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
     backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
 
+    # Asynchronous serialization
     loop = asyncio.get_running_loop()
-    serialized_tensors = await asyncio.gather(*(
-        loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-        for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-    ))
-    
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        )
+    )
+
     deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
     return deserialized_grad_inputs
 
@@ -189,7 +195,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
-        input_batches: Sequence[torch.Tensor] = inputs.split(batch_size)
+        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
 
         sequence_manager.rpc_info  # lazy init
         outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
@@ -217,6 +223,11 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         grad_input_batches = RemoteExpertWorker.run_coroutine(
             _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
         )
+        # grad_input_batches = [sequential_backward((grad_output,), input_batch, spans, ctx.sequence_manager)
+        #     for grad_output, input_batch, spans in zip(
+        #         grad_output_batches, intermediate_input_batches, forward_sequences
+        #     )
+        # ]
         grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
         grad_inputs = torch.cat(grad_inputs, dim=0)
         return (grad_inputs, None)
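
The core change above offloads tensor serialization to a background thread: each call to serialize_torch_tensor runs in the event loop's default thread-pool executor, and asyncio.gather awaits all of them together, so the loop is not blocked while large tensors are serialized. (The same hunks also detach the inputs before splitting them into batches in forward.) Below is a minimal, self-contained sketch of that pattern; serialize_fn here is a hypothetical stand-in for serialize_torch_tensor, with the schema/compression handling omitted.

import asyncio
import pickle
from typing import List, Sequence

import torch


def serialize_fn(tensor: torch.Tensor) -> bytes:
    # stand-in for serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
    return pickle.dumps(tensor)


async def serialize_all(tensors: Sequence[torch.Tensor]) -> List[bytes]:
    # offload each CPU-bound serialization to the default thread-pool executor
    loop = asyncio.get_running_loop()
    return await asyncio.gather(
        *(loop.run_in_executor(None, serialize_fn, tensor) for tensor in tensors)
    )


if __name__ == "__main__":
    payloads = asyncio.run(serialize_all([torch.randn(2, 3) for _ in range(4)]))
    print([len(p) for p in payloads])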

+ 1 - 1
src/server/backend.py

@@ -1,6 +1,6 @@
 """Code for serving bloom blocks via hivemind-server"""
 from queue import Empty
-from typing import Sequence, Tuple, Optional
+from typing import Optional, Sequence, Tuple
 
 import torch
 from hivemind import use_hivemind_log_handler

+ 12 - 4
src/server/server.py

@@ -13,7 +13,7 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import BloomConfig, declare_active_modules, BloomBlock
+from src import BloomBlock, BloomConfig, declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from src.dht_utils import get_remote_module_infos
@@ -194,10 +194,18 @@ class Server(threading.Thread):
                 module_uid,
                 block,
                 memory_cache=memory_cache,
-                backend_dtype=None if torch_dtype == 'auto' else torch_dtype,
-                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression),),
+                backend_dtype=None if torch_dtype == "auto" else torch_dtype,
+                args_schema=(
+                    BatchTensorDescriptor(
+                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                    ),
+                ),
                 kwargs_schema={},
-                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression),),
+                outputs_schema=(
+                    BatchTensorDescriptor(
+                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                    ),
+                ),
                 min_batch_size=min_batch_size,
                 max_batch_size=max_batch_size,
             )
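
For context, the reformatted args_schema and outputs_schema above each describe a single float32 activation tensor of shape (1, 2048, hidden_size) with the server's chosen wire compression. A minimal sketch of the same construction follows, assuming hivemind is installed; hidden_size = 4096 and CompressionType.NONE are placeholder values, not the project's actual configuration.

import torch
from hivemind import BatchTensorDescriptor
from hivemind.proto.runtime_pb2 import CompressionType

hidden_size = 4096  # placeholder; the real value comes from block_config.hidden_size
compression = CompressionType.NONE  # placeholder; the server picks this from its arguments

args_schema = (
    BatchTensorDescriptor(1, 2048, hidden_size, dtype=torch.float32, compression=compression),
)
outputs_schema = args_schema  # the block's outputs share the inputs' shape and dtype

print(args_schema[0])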