@@ -15,6 +15,7 @@ from torch.utils.data import Dataset
 import hivemind
 from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.utils.crypto import RSAPrivateKey
+from contextlib import nullcontext


 @dataclass(frozen=True)
@@ -26,6 +27,8 @@ class TrainingArguments:
     num_clients: int = 3
     target_batch_size: int = 128
     reuse_grad_buffers: bool = True
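+    # new knobs: optionally delay the optimizer step and enable mixed-precision (AMP) training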
+    delay_optimizer_step: bool = False
+    use_amp: bool = True

     lr_base: float = 0.1
     lr_gamma: int = 0.1
@@ -44,7 +47,7 @@ class TrainingArguments:
     winddown_time: float = 5.0
     verbose: bool = True

-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    device: str = "cpu"
     make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
     make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
         nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
@@ -74,6 +77,7 @@ def benchmark_optimizer(args: TrainingArguments):
         optimizer = Optimizer(
             prefix=args.prefix,
             target_batch_size=args.target_batch_size,
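+            # report each peer's local batch size so its contribution toward target_batch_size is counted correctly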
+            batch_size_per_step=batch_size,
             params=model.parameters(),
             optimizer=partial(torch.optim.SGD, lr=args.lr_base),
             scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
@@ -82,22 +86,39 @@ def benchmark_optimizer(args: TrainingArguments):
             matchmaking_time=args.matchmaking_time,
             averaging_timeout=args.averaging_timeout,
             reuse_grad_buffers=args.reuse_grad_buffers,
+            delay_optimizer_step=args.delay_optimizer_step,
             client_mode=client_mode,
             verbose=verbose,
         )

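+        # hivemind's GradScaler is required when AMP is combined with reuse_grad_buffers; a plain PyTorch scaler suffices otherwise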
+        if args.use_amp and args.reuse_grad_buffers:
+            grad_scaler = hivemind.GradScaler()
+        else:
+            grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
+
         prev_time = time.perf_counter()

         while optimizer.local_epoch < args.max_epoch:
             time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))

             batch = torch.randint(0, len(X_train), (batch_size,))
-            loss = F.cross_entropy(model(X_train[batch]), y_train[batch])
-            loss.backward()

-            optimizer.step(batch_size=batch_size)
+            with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
+                loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
+                grad_scaler.scale(loss).backward()
+
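+            # unscale gradients before the step; this is a no-op when the scaler is disabled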
+            grad_scaler.unscale_(optimizer)
+
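+            # under AMP, let the scaler drive the step so it can be skipped on inf/nan gradients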
+            if args.use_amp:
+                grad_scaler.step(optimizer)
+            else:
+                optimizer.step()
+
+            grad_scaler.update()
+
             if not args.reuse_grad_buffers:
                 optimizer.zero_grad()
+
             prev_time = time.perf_counter()

         time.sleep(args.winddown_time)
@@ -112,6 +133,7 @@ def benchmark_optimizer(args: TrainingArguments):
             mp.Process(
                 target=run_trainer,
                 name=f"trainer-{index}",
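+                # trainer processes must not be daemonic, since hivemind itself spawns child processes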
+                daemon=False,
                 kwargs=dict(
                     batch_size=batch_size,
                     batch_time=batch_time,
@@ -121,8 +143,12 @@ def benchmark_optimizer(args: TrainingArguments):
             )
         )

-    for peer in peers[1:]:
-        peer.start()
-    peers[0].run()
-    for peer in peers[1:]:
-        peer.join()
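+    # run the first peer in the main process and make sure spawned peers are terminated even if the run fails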
+    try:
+        for peer in peers[1:]:
+            peer.start()
+        peers[0].run()
+        for peer in peers[1:]:
+            peer.join()
+    finally:
+        for peer in peers[1:]:
+            peer.kill()