|
@@ -110,7 +110,7 @@ class AllReduceRunner(ServicerBase):
|
|
|
def group_size(self):
|
|
|
return len(self.ordered_group_endpoints)
|
|
|
|
|
|
- def _get_stub(self, peer: Endpoint) -> StubBase:
|
|
|
+ def _get_peer_stub(self, peer: Endpoint) -> StubBase:
|
|
|
return self._servicer.get_stub(self._p2p, peer)
|
|
|
|
|
|
async def run(self) -> AsyncIterator[torch.Tensor]:
|
|
@@ -152,7 +152,7 @@ class AllReduceRunner(ServicerBase):
|
|
|
else:
|
|
|
loop = asyncio.get_event_loop()
|
|
|
code = None
|
|
|
- stream = self._get_stub(peer_endpoint).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
|
|
|
+ stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
|
|
|
async for part_index, msg in aenumerate(stream):
|
|
|
if code is None:
|
|
|
code = msg.code
|
|
@@ -229,7 +229,7 @@ class AllReduceRunner(ServicerBase):
|
|
|
|
|
|
async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
|
|
|
error = averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint.to_base58(), code=code)
|
|
|
- async for _ in self._get_stub(peer_endpoint).rpc_aggregate_part(aiter(error)):
|
|
|
+ async for _ in self._get_peer_stub(peer_endpoint).rpc_aggregate_part(aiter(error)):
|
|
|
pass
|
|
|
|
|
|
def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
|