5
0
justheuristic 3 жил өмнө
parent
commit
50a647d410
1 өөрчлөгдсөн 4 нэмэгдсэн , 2 устгасан
  1. 4 2
      src/server/handler.py

+ 4 - 2
src/server/handler.py

@@ -1,5 +1,5 @@
 import contextlib
-from typing import AsyncIterator, Dict, List, Optional, Sequence
+from typing import AsyncIterator, Dict, List, Optional, Sequence, Union
 
 import torch
 from hivemind import (
@@ -244,7 +244,9 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     return hidden_states
 
 
-async def _rpc_backward(*flat_tensors: torch.Tensor, requested_backends):
+async def _rpc_backward(
+    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
     inputs, grad_outputs, *prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)