justheuristic 3 éve
szülő
commit
64b04c37cf
1 módosított fájl, 12 hozzáadás és 5 törlés
  1. 12 5
      src/server/handler.py

+ 12 - 5
src/server/handler.py

@@ -1,9 +1,16 @@
 import contextlib
-from typing import AsyncIterator, Dict, Sequence, Optional, List
+from typing import AsyncIterator, Dict, List, Optional, Sequence
 
 import torch
-from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, \
-    serialize_torch_tensor, MSGPackSerializer
+from hivemind import (
+    DHT,
+    MSGPackSerializer,
+    P2PContext,
+    TensorDescriptor,
+    deserialize_torch_tensor,
+    nested_flatten,
+    serialize_torch_tensor,
+)
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -231,7 +238,7 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     for backend, prompt in zip(requested_backends, prompts):
         (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         if not is_dummy(prompt):
-            hidden_states[:, :min(seq_length, prompt.shape[1]), ...] += prompt
+            hidden_states[:, : min(seq_length, prompt.shape[1]), ...] += prompt
         assert isinstance(hidden_states, torch.Tensor)
         assert hidden_states.ndim == 3, f"{type(backend)} must return a list with a single 3d tensor of hidden states"
 
@@ -256,7 +263,7 @@ async def _rpc_backward(inputs: torch.Tensor, prompts: torch.Tensor, grad_output
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs = inputs.clone() # TODO
+            inputs = inputs.clone()  # TODO
             inputs[:, :pre_seq_len] += prompt
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)