@@ -223,7 +223,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     @allow_state_sharing.setter
     def allow_state_sharing(self, value: bool):
         if value and self.client_mode:
-            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
+            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state")
         else:
             old_value, self._allow_state_sharing.value = self._allow_state_sharing.value, value
             if value != old_value:
@@ -237,7 +237,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     @state_sharing_priority.setter
     def state_sharing_priority(self, value: float):
         if value and self.client_mode:
-            raise ValueError("State sharing priority is unused: averager in client mode cannot share its state.")
+            raise ValueError("State sharing priority is unused: averager in client mode cannot share its state")
         else:
             old_value, self._state_sharing_priority.value = self._state_sharing_priority.value, value
             if self.allow_state_sharing and value != old_value:
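
Note for reviewers: both setters guard the same invariant, namely that a peer in client mode never serves its state, so enabling sharing (or assigning it a priority) is rejected eagerly. A minimal caller-side sketch of the behavior these hunks touch; the constructor arguments below are illustrative, not a full configuration:

```python
import torch
import hivemind

dht = hivemind.DHT(start=True)
averager = hivemind.averaging.DecentralizedAverager(
    averaged_tensors=[torch.zeros(8)],  # local tensors to average with peers
    dht=dht,
    prefix="demo_run",  # hypothetical experiment prefix
    start=True,
)

averager.allow_state_sharing = False  # temporarily stop serving state to newcomers
averager.allow_state_sharing = True   # re-enable; a changed value wakes the sharing loop

# Had the averager been created with client_mode=True, assigning a truthy value to
# either property would raise the ValueError whose message is changed above.
```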
@@ -279,7 +279,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     if not self.client_mode:
                         await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
                     else:
-                        logger.debug(f"The averager is running in client mode.")
+                        logger.debug("The averager is running in client mode")

                     self._matchmaking = Matchmaking(
                         self._p2p,
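
Context for this hunk: in client mode the averager deliberately skips `add_p2p_handlers`, so it can initiate averaging rounds through matchmaking but cannot be dialed by other peers. A hedged sketch, continuing the example above:

```python
# client_mode suits peers behind NAT or a firewall: they join groups via
# matchmaking but never accept inbound requests (hence no handlers above).
client = hivemind.averaging.DecentralizedAverager(
    averaged_tensors=[torch.zeros(8)],
    dht=dht,
    prefix="demo_run",
    client_mode=True,
    start=True,
)
```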
@@ -340,7 +340,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             self._inner_pipe.send(("_SHUTDOWN", None))  # shut down background thread in master
             self.join(self.shutdown_timeout)
             if self.is_alive():
-                logger.warning("Averager did not shut down within the grace period; terminating it the hard way.")
+                logger.warning("Averager did not shut down within the grace period; terminating it the hard way")
                 self.terminate()
         else:
             logger.exception("Averager shutdown has no effect: the process is already not alive")
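
The shutdown path here is: signal the daemon process and the background thread, `join` bounded by `shutdown_timeout`, then fall back to `terminate()`. Caller-side, continuing the sketch:

```python
averager.shutdown()  # graceful stop; hard-terminates only after shutdown_timeout expires
assert not averager.is_alive()
```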
@@ -381,7 +381,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         if self.mode == AveragingMode.AUX and weight is not None:
-            logger.warning("Averager is running in auxiliary mode, weight is unused.")
+            logger.warning("Averager is running in auxiliary mode, weight is unused")
         if scheduled_time is None:
             scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
         if weight is None:
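
The `weight` argument only matters for regular peers: auxiliary peers (`AveragingMode.AUX`) hold no data of their own, so a supplied weight is ignored with the warning above. Continuing the sketch; the parameter values are illustrative:

```python
# Regular peer: weight its contribution, e.g. by the number of locally processed samples.
group_info = averager.step(weight=256, timeout=60)

# An auxiliary peer (created with auxiliary=True) would log the warning above
# if step() were called with a non-None weight.
```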
@@ -449,7 +449,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                    group_info = await matchmaking_task

                    if group_info is None:
-                        raise AllreduceException("Averaging step failed: could not find a group.")
+                        raise AllreduceException("Averaging step failed: could not find a group")

                    step.stage = AveragingStage.RUNNING_ALLREDUCE

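
When matchmaking yields no group before the deadline, the step fails with `AllreduceException`. A hedged sketch of handling this on the caller side, assuming the exception propagates out of `step()` once retries are exhausted; the import path is my assumption:

```python
from hivemind.averaging.allreduce import AllreduceException  # assumed location

try:
    group_info = averager.step(timeout=30)
except AllreduceException:
    group_info = None  # no group formed in time, e.g. too few live peers; retry later
```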
@@ -502,7 +502,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             stub = type(self).get_stub(self._p2p, peer_id, namespace=self.prefix)
             await afirst(await stub.rpc_aggregate_part(as_aiter(error)))
         except Exception as e:
-            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}.")
+            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}")

     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
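
For readers unfamiliar with the helpers in this hunk: the error is delivered as a one-message stream, and only the first response is awaited. Simplified sketches of their semantics, not hivemind's actual implementations:

```python
from typing import AsyncIterator, Optional, TypeVar

T = TypeVar("T")

async def as_aiter(*items: T) -> AsyncIterator[T]:
    """Wrap plain values into an async iterator."""
    for item in items:
        yield item

async def afirst(aiter: AsyncIterator[T], default: Optional[T] = None) -> Optional[T]:
    """Return the first item of an async iterator, or `default` if it is empty."""
    async for item in aiter:
        return item
    return default
```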
@@ -690,7 +690,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             }

             if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
-                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}.")
+                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}")
                 future.set_result(None)
                 return

@@ -715,7 +715,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))

                 if not metadata:
-                    logger.debug(f"Peer {peer} did not send its state.")
+                    logger.debug(f"Peer {peer} did not send its state")
                     continue

                 logger.info(f"Finished downloading state from {peer}")
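
These last two hunks sit on the state-download path: the averager ranks candidate peers by their advertised sharing priority, then streams metadata and tensor parts from the best one. Caller-side, continuing the sketch above:

```python
# Fetch the freshest shared state from the best-ranked peer (None if nobody responded).
loaded_state = averager.load_state_from_peers()
if loaded_state is not None:
    metadata, tensors = loaded_state
```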