justheuristic 2 năm trước cách đây
mục cha
commit
65012ac88e
2 tập tin đã thay đổi với 4 bổ sung1 xóa
  1. 3 0
      src/client/remote_forward_backward.py
  2. 1 1
      src/utils/convert_8bit.py

+ 3 - 0
src/client/remote_forward_backward.py

@@ -58,6 +58,7 @@ async def run_remote_forward(
     size = sum(t.element_size() * t.nelement() for t in inputs)
     if size > MAX_UNARY_PAYLOAD_SIZE:
         deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
+        raise 123
     else:
         deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)
 
@@ -83,9 +84,11 @@ async def _forward_stream(
 async def _forward_unary(
     uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
 ) -> List[torch.Tensor]:
+    print(end='client - forward - before\n', flush=True)
     outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
         runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
     )
+    print(end='client - forward - after\n', flush=True)
     return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
 

+ 1 - 1
src/utils/convert_8bit.py

@@ -1,4 +1,4 @@
-import bitsandbytes as bnb
+# import bitsandbytes as bnb
 import torch