Преглед на файлове

Lower payload size threshold for stream handlers (#251)

Hotfix: we add "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space.
Alexander Borzunov преди 2 години
родител
ревизия
4091db10bf
променен е 1 файл, в който са добавени 4 реда и са изтрити 2 реда
  1. +4 -2
      src/petals/client/remote_forward_backward.py

+ 4 - 2
src/petals/client/remote_forward_backward.py

@@ -108,7 +108,8 @@ async def run_remote_forward(
 
     # call RPC on remote server
     size = sum(t.element_size() * t.nelement() for t in inputs)
-    forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary
+    forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
+    # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
     deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
 
@@ -150,6 +151,7 @@ async def run_remote_backward(
     )
 
     size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
-    backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _backward_unary
+    backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
+    # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
     deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
     return deserialized_grad_inputs