瀏覽代碼

s/expert/block/g

justheuristic 2 年之前
父節點
當前提交
2906cec6fe
共有 1 個文件被更改,包括 8 次插入7 次删除
  1. 8 7
      src/server/handler.py

+ 8 - 7
src/server/handler.py

@@ -46,15 +46,15 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def _gather_inputs(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> Tuple[str, List[torch.Tensor], Dict]:
-        expert_uid, metadata = None, None
+        block_uid, metadata = None, None
 
         def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
-            nonlocal expert_uid, metadata
+            nonlocal block_uid, metadata
 
-            if expert_uid is None:
-                expert_uid = req.uid
-            elif expert_uid != req.uid:
-                raise ValueError("Expert uids differ in one request")
+            if block_uid is None:
+                block_uid = req.uid
+            elif block_uid != req.uid:
+                raise ValueError("Block uids differ in one request")
 
             if metadata is None:
                 metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}
@@ -63,7 +63,8 @@ class TransformerConnectionHandler(ConnectionHandler):
 
         tensors_stream = amap_in_executor(_unpack, requests)
         inputs = await deserialize_tensor_stream(tensors_stream)
-        return expert_uid, inputs, metadata
+        assert isinstance(block_uid, str) and isinstance(metadata, dict)
+        return block_uid, inputs, metadata
 
     async def rpc_inference(
         self,