Max Ryabinin 3 years ago
parent
commit
9b5ee08bd6

+ 1 - 0
hivemind/hivemind_cli/run_server.py

@@ -47,6 +47,7 @@ def main():
                         help='Target group size for decentralized averaging')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
+    parser.add_argument('--fp16', action='store_true', help='Use mixed precision during forward and backward steps')
 
     parser.add_argument('--optimizer', type=str, default='adam', required=False, help='adam, sgd or none')
     parser.add_argument('--scheduler', type=str, choices=schedule_name_to_scheduler.keys(), default='none',
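
As a quick illustration of the new flag's semantics, here is a minimal stand-alone argparse sketch (only the new option, not the full run_server.py parser):

import argparse

# Illustrative parser containing just the new option; run_server.py defines many more.
parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true',
                    help='Use mixed precision during forward and backward steps')

assert parser.parse_args(['--fp16']).fp16 is True   # flag present -> True
assert parser.parse_args([]).fp16 is False          # flag absent  -> False (store_true default)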

+ 2 - 0
hivemind/moe/server/__init__.py

@@ -119,6 +119,7 @@ class Server(threading.Thread):
         averaging_timeout=30,
         reuse_grad_buffers=True,
         device=None,
+        fp16=False,
         no_dht=False,
         dht_port=None,
         dht_listen_on=None,
@@ -314,6 +315,7 @@ class Server(threading.Thread):
                 args_schema=args_schema,
                 optimizer=optim,
                 device=device,
+                fp16=fp16,
                 clip_grad_norm=clip_grad_norm,
                 min_batch_size=min_batch_size,
                 max_batch_size=max_batch_size,
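
For readers following the flag through the codebase, here is a hypothetical, simplified sketch of this plumbing (the Sketch suffix marks these as illustrations, not hivemind classes): Server accepts fp16 with the same default as the diff and forwards it unchanged to each ExpertBackend it constructs.

class ExpertBackendSketch:
    def __init__(self, *, device=None, fp16=False):
        self.device, self.fp16 = device, fp16

class ServerSketch:
    def __init__(self, device=None, fp16=False):
        # default False matches the diff: mixed precision stays off unless requested
        self.backend = ExpertBackendSketch(device=device, fp16=fp16)

assert ServerSketch(fp16=True).backend.fp16 is True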

+ 5 - 2
hivemind/moe/server/expert_backend.py

@@ -1,4 +1,5 @@
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
+from contextlib import nullcontext
 
 import torch
 from torch import nn
@@ -48,6 +49,7 @@ class ExpertBackend:
         optimizer: torch.optim.Optimizer,
         *,
         device: torch.device,
+        fp16: bool = False,
         scheduler: Callable = None,
         args_schema: Tuple[BatchTensorDescriptor, ...] = None,
         kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
@@ -61,6 +63,7 @@ class ExpertBackend:
         self.expert = expert.to(device)
         self.optimizer, self.name = optimizer, name
         self.device = device
+        self.fp16 = fp16
 
         if scheduler is None:
             self.scheduler = None
@@ -115,7 +118,7 @@ class ExpertBackend:
         if args[0].shape[0] == 0:
             raise RuntimeError("Batch should contain more than 0 samples")
 
-        with torch.no_grad():
+        with torch.no_grad(), torch.cuda.amp.autocast() if self.fp16 else nullcontext():
             outputs = self.expert(*args, **kwargs)
 
         # Note: TaskPool requires the function to accept and return a flat tuple of values; we pack/unpack it on the client side
@@ -140,7 +143,7 @@ class ExpertBackend:
         """
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
-        with torch.enable_grad():
+        with torch.enable_grad(), torch.cuda.amp.autocast() if self.fp16 else nullcontext():
             args = [
                 tensor.detach().requires_grad_(True)
                 if tensor.dtype in (torch.half, torch.float, torch.double)
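
The core of the change is the conditional context manager: when fp16 is off, contextlib.nullcontext() acts as a no-op stand-in for torch.cuda.amp.autocast(), so a single with statement covers both modes. A minimal, self-contained sketch of the pattern (forward_pass is a hypothetical helper, not hivemind code):

from contextlib import nullcontext

import torch

def forward_pass(module: torch.nn.Module, batch: torch.Tensor, fp16: bool = False):
    # nullcontext() makes the fp16=False path identical to plain no_grad inference
    autocast_ctx = torch.cuda.amp.autocast() if fp16 else nullcontext()
    with torch.no_grad(), autocast_ctx:
        return module(batch)

if __name__ == '__main__':
    layer = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4)
    print(forward_pass(layer, x).dtype)  # torch.float32: autocast disabled
    if torch.cuda.is_available():
        out = forward_pass(layer.cuda(), x.cuda(), fp16=True)
        print(out.dtype)  # torch.float16: Linear runs in half precision under autocast

Note that autocast only changes the compute dtype of the forward and backward passes; gradient (loss) scaling via torch.cuda.amp.GradScaler is a separate mechanism and does not appear in this diff.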