@@ -2,12 +2,12 @@ import asyncio
import contextlib
import faulthandler
import math
-import torch
import multiprocessing as mp
-import numpy as np
-
from typing import Any, Iterable, Optional, Sequence
+import numpy as np
+import torch
+
import hivemind
from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
from hivemind.averaging.control import AveragingStage, StepControl
@@ -64,16 +64,20 @@ class PowerEFGradientAverager(GradientAverager):
self.rank = averager_rank
self.parameters = tuple(parameters)
self._uncompressed_gradients = set(
- i for i, grad in enumerate(self._grads_from_parameters())
- if len(tuple(grad.size())) == 1 or
- (self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) > 1 - min_comprasion_ratio)
+ i
+ for i, grad in enumerate(self._grads_from_parameters())
+ if len(tuple(grad.size())) == 1
+ or (
+ self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) > 1 - min_comprasion_ratio
+ )
)
self._gradient_rests = list(torch.zeros_like(grad, device="cpu") for grad in self._grads_from_parameters())
self._qs = list(
torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device="cpu")
- for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
+ for idx, grad in enumerate(self._grads_from_parameters())
+ if idx not in self._uncompressed_gradients
)
- for tensor in (self._qs + self._gradient_rests):
+ for tensor in self._qs + self._gradient_rests:
if tensor is not None:
assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
tensor.share_memory_()
@@ -87,22 +91,22 @@ class PowerEFGradientAverager(GradientAverager):
client_mode=client_mode,
warn=warn,
grad_extra_tensors=grad_extra_tensors,
- **kwargs
+ **kwargs,
)
@contextlib.contextmanager
def _register_allreduce_group(self, group_info: GroupInfo):
"""registers a given all-reduce runner to listen for incoming connections"""
try:
- self._running_groups[group_info.group_id + b'.phase1'] = asyncio.Future()
- self._running_groups[group_info.group_id + b'.phase2'] = asyncio.Future()
+ self._running_groups[group_info.group_id + b".phase1"] = asyncio.Future()
+ self._running_groups[group_info.group_id + b".phase2"] = asyncio.Future()
self._pending_groups_registered.set()
yield
finally:
- maybe_future = self._running_groups.pop(group_info.group_id + b'.phase1', None)
+ maybe_future = self._running_groups.pop(group_info.group_id + b".phase1", None)
if maybe_future and not maybe_future.done():
logger.warning(f"All-reduce group {group_info.group_id + b'.phase1'} did not finish.")
- maybe_future = self._running_groups.pop(group_info.group_id + b'.phase2', None)
+ maybe_future = self._running_groups.pop(group_info.group_id + b".phase2", None)
if maybe_future and not maybe_future.done():
logger.warning(f"All-reduce group {group_info.group_id + b'.phase2'} did not finish.")
self._pending_groups_registered.set()
@@ -122,18 +126,22 @@ class PowerEFGradientAverager(GradientAverager):
)
async with enter_asynchronously(self.get_tensors()) as local_tensors:
- compressed_tensors = [lt.to("cpu") for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients]
+ compressed_tensors = [
+ lt.to("cpu") for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients
+ ]
cs = [rest for idx, rest in enumerate(self._gradient_rests) if idx not in self._uncompressed_gradients]
ps = [torch.zeros((grad.size(0), self.rank), device="cpu") for grad in compressed_tensors]
for p, q, rest in zip(ps, self._qs, cs):
torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)
- first_all_reduced = ps + [rest for idx, rest in enumerate(self._gradient_rests) if idx in self._uncompressed_gradients]
+ first_all_reduced = ps + [
+ rest for idx, rest in enumerate(self._gradient_rests) if idx in self._uncompressed_gradients
+ ]
allreduce1 = AllReduceRunner(
p2p=self._p2p,
servicer_type=type(self),
prefix=self.prefix,
- group_id=group_info.group_id + b'.phase1',
+ group_id=group_info.group_id + b".phase1",
tensors=first_all_reduced,
ordered_peer_ids=group_info.peer_ids,
peer_fractions=peer_fractions,
@@ -141,7 +149,7 @@ class PowerEFGradientAverager(GradientAverager):
modes=modes,
**kwargs,
)
- self._running_groups[group_info.group_id + b'.phase1'].set_result(allreduce1)
+ self._running_groups[group_info.group_id + b".phase1"].set_result(allreduce1)
if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
async for tensor, update in azip(as_aiter(*first_all_reduced), allreduce1):
@@ -163,7 +171,7 @@ class PowerEFGradientAverager(GradientAverager):
p2p=self._p2p,
servicer_type=type(self),
prefix=self.prefix,
- group_id=group_info.group_id + b'.phase2',
+ group_id=group_info.group_id + b".phase2",
tensors=self._qs,
ordered_peer_ids=group_info.peer_ids,
peer_fractions=peer_fractions,
@@ -171,7 +179,7 @@ class PowerEFGradientAverager(GradientAverager):
modes=modes,
**kwargs,
)
- self._running_groups[group_info.group_id + b'.phase2'].set_result(allreduce2)
+ self._running_groups[group_info.group_id + b".phase2"].set_result(allreduce2)
if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
async for tensor, update in azip(as_aiter(*self._qs), allreduce2):
@@ -212,8 +220,8 @@ class PowerEFGradientAverager(GradientAverager):
def orthogonalize(matrix, eps=torch.tensor(1e-8)):
n, m = matrix.shape
for i in range(m):
- col = matrix[:, i: i + 1]
+ col = matrix[:, i : i + 1]
col /= torch.sqrt(torch.sum(col ** 2)) + eps
if i + 1 < m:
- rest = matrix[:, i + 1:]
+ rest = matrix[:, i + 1 :]
rest -= torch.sum(col * rest, dim=0) * col
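
For context, here is a minimal single-process sketch of the rank-r round-trip that the two all-reduce phases above implement. All names below are illustrative, not from the patch; in the patch itself the compressed operand is the accumulated error-feedback residual (self._gradient_rests) rather than the raw gradient, and the two averaging steps run over the peer group via AllReduceRunner instead of being local no-ops.

import torch

def orthogonalize(matrix: torch.Tensor, eps: float = 1e-8) -> None:
    # In-place column-wise Gram-Schmidt, mirroring the patched helper above.
    n, m = matrix.shape
    for i in range(m):
        col = matrix[:, i : i + 1]
        col /= torch.sqrt(torch.sum(col**2)) + eps
        if i + 1 < m:
            rest = matrix[:, i + 1 :]
            rest -= torch.sum(col * rest, dim=0) * col

rank = 4
acc = torch.randn(64, 128)         # error-feedback accumulator for one 2-D gradient
q = torch.rand(acc.size(1), rank)  # shared right factor, one entry of self._qs

p = acc @ q                        # phase 1 payload: P = A Q, shape (64, rank)
# ... phase 1: all-reduce / average p across peers ...
orthogonalize(p)                   # orthonormalize columns before recomputing Q
q = acc.t() @ p                    # phase 2 payload: Q = A^T P, shape (128, rank)
# ... phase 2: all-reduce / average q across peers ...
approx = p @ q.t()                 # rank-r estimate of the averaged tensor
acc -= approx                      # keep the compression error for the next round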