@@ -1,5 +1,5 @@
 import contextlib
-from typing import AsyncIterator, Dict, List, Optional, Sequence
+from typing import AsyncIterator, Dict, List, Optional, Sequence, Union
 
 import torch
 from hivemind import (
@@ -244,7 +244,9 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     return hidden_states
 
 
-async def _rpc_backward(*flat_tensors: torch.Tensor, requested_backends):
+async def _rpc_backward(
+    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
     inputs, grad_outputs, *prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)