瀏覽代碼

actually unscale

justheuristic 3 年之前
父節點
當前提交
80ced17866
共有 1 個文件被更改,包括 35 次插入和 9 次删除
  1. 35 9
      benchmarks/benchmark_optimizer.py

+ 35 - 9
benchmarks/benchmark_optimizer.py

@@ -15,6 +15,7 @@ from torch.utils.data import Dataset
 import hivemind
 from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.utils.crypto import RSAPrivateKey
+from contextlib import nullcontext
 
 
 @dataclass(frozen=True)
@@ -26,6 +27,8 @@ class TrainingArguments:
     num_clients: int = 3
     target_batch_size: int = 128
     reuse_grad_buffers: bool = True
+    delay_optimizer_step: bool = False
+    use_amp: bool = True
 
     lr_base: float = 0.1
     lr_gamma: int = 0.1
@@ -44,7 +47,7 @@ class TrainingArguments:
     winddown_time: float = 5.0
     verbose: bool = True
 
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    device: str = "cpu"
     make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
     make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
         nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
@@ -74,6 +77,7 @@ def benchmark_optimizer(args: TrainingArguments):
         optimizer = Optimizer(
             prefix=args.prefix,
             target_batch_size=args.target_batch_size,
+            batch_size_per_step=batch_size,
             params=model.parameters(),
             optimizer=partial(torch.optim.SGD, lr=args.lr_base),
             scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
@@ -82,22 +86,39 @@ def benchmark_optimizer(args: TrainingArguments):
             matchmaking_time=args.matchmaking_time,
             averaging_timeout=args.averaging_timeout,
             reuse_grad_buffers=args.reuse_grad_buffers,
+            delay_optimizer_step=args.delay_optimizer_step,
             client_mode=client_mode,
             verbose=verbose,
         )
 
+        if args.reuse_grad_buffers:
+            grad_scaler = hivemind.GradScaler()
+        else:
+            grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
+
         prev_time = time.perf_counter()
 
         while optimizer.local_epoch < args.max_epoch:
             time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))
 
             batch = torch.randint(0, len(X_train), (batch_size,))
-            loss = F.cross_entropy(model(X_train[batch]), y_train[batch])
-            loss.backward()
 
-            optimizer.step(batch_size=batch_size)
+            with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
+                loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
+                grad_scaler.scale(loss).backward()
+
+            grad_scaler.unscale_(optimizer)
+
+            if args.use_amp:
+                grad_scaler.step(optimizer)
+            else:
+                optimizer.step()
+
+            grad_scaler.update()
+
             if not args.reuse_grad_buffers:
                 optimizer.zero_grad()
+
             prev_time = time.perf_counter()
 
         time.sleep(args.winddown_time)
@@ -112,6 +133,7 @@ def benchmark_optimizer(args: TrainingArguments):
             mp.Process(
                 target=run_trainer,
                 name=f"trainer-{index}",
+                daemon=False,
                 kwargs=dict(
                     batch_size=batch_size,
                     batch_time=batch_time,
@@ -121,8 +143,12 @@ def benchmark_optimizer(args: TrainingArguments):
             )
         )
 
-    for peer in peers[1:]:
-        peer.start()
-    peers[0].run()
-    for peer in peers[1:]:
-        peer.join()
+    try:
+        for peer in peers[1:]:
+            peer.start()
+        peers[0].run()
+        for peer in peers[1:]:
+            peer.join()
+    finally:
+        for peer in peers[1:]:
+            peer.kill()