浏览代码

Rename _get_stub to _get_peer_stub

Aleksandr Borzunov 4 年之前
父节点
当前提交
f615693e5b
共有 1 个文件被更改,包括 3 次插入3 次删除
  1. 3 3
      hivemind/averaging/allreduce.py

+ 3 - 3
hivemind/averaging/allreduce.py

@@ -110,7 +110,7 @@ class AllReduceRunner(ServicerBase):
     def group_size(self):
         return len(self.ordered_group_endpoints)
 
-    def _get_stub(self, peer: Endpoint) -> StubBase:
+    def _get_peer_stub(self, peer: Endpoint) -> StubBase:
         return self._servicer.get_stub(self._p2p, peer)
 
     async def run(self) -> AsyncIterator[torch.Tensor]:
@@ -152,7 +152,7 @@ class AllReduceRunner(ServicerBase):
         else:
             loop = asyncio.get_event_loop()
             code = None
-            stream = self._get_stub(peer_endpoint).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
+            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
             async for part_index, msg in aenumerate(stream):
                 if code is None:
                     code = msg.code
@@ -229,7 +229,7 @@ class AllReduceRunner(ServicerBase):
 
     async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
         error = averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint.to_base58(), code=code)
-        async for _ in self._get_stub(peer_endpoint).rpc_aggregate_part(aiter(error)):
+        async for _ in self._get_peer_stub(peer_endpoint).rpc_aggregate_part(aiter(error)):
             pass
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):