Aleksandr Borzunov 3 éve
szülő
commit
2816c5c09a
4 módosított fájl, 11 hozzáadás és 10 törlés
  1. 3 3
      arguments.py
  2. 7 3
      manage_scaleset.py
  3. 0 2
      run_trainer.py
  4. 1 2
      task.py

+ 3 - 3
arguments.py

@@ -64,13 +64,13 @@ class CollaborativeArguments:
         metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
     )
     matchmaking_time: float = field(
-        default=30.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
+        default=15.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
     )
     allreduce_timeout: float = field(
-        default=60, metadata={"help": "Give up on a given all-reduce round after this many seconds"}
+        default=120, metadata={"help": "Give up on a given all-reduce round after this many seconds"}
     )
     averaging_timeout: float = field(
-        default=180, metadata={"help": "Give up on averaging step after this many seconds"}
+        default=120, metadata={"help": "Give up on averaging step after this many seconds"}
     )
     reuse_grad_buffers: bool = field(default=True, metadata={
         "help": "Whether or not to use model's .grad buffers for accumulating gradients across local steps. This "

+ 7 - 3
manage_scaleset.py

@@ -7,15 +7,19 @@ from azure.mgmt.compute import ComputeManagementClient
 from azure.mgmt.network import NetworkManagementClient
 from azure.mgmt.resource import ResourceManagementClient
 
+
+print("=======================WARNING=======================")
+print("= The code may fail to import 'gi' but that is okay =")
+print("===================END OF WARNING====================")
 SUBSCRIPTION_ID = os.environ["SUBSCRIPTION_ID"]
-GROUP_NAME = "dalle_northeu"
+GROUP_NAME = "dalle_west2"
 NETWORK_NAME = "vnet"
 SUBNET_NAME = "subnet"
-LOCATION = "northeurope"
+LOCATION = "westus2"
 ADMIN_PASS = os.environ['AZURE_PASS']
 
 SCALE_SETS = ('worker',)
-SWARM_SIZE = 64
+SWARM_SIZE = 4
 
 WORKER_CLOUD_INIT = """#cloud-config
 package_update: true

+ 0 - 2
run_trainer.py

@@ -28,8 +28,6 @@ def main():
     training_peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
 
     logger.info(f"Trying {len(training_peer_args.initial_peers)} initial peers: {training_peer_args.initial_peers}")
-    # if len(training_peer_args.initial_peers) == 0:
-    #     logger.warning("Please specify at least one network endpoint in initial peers.")
 
     utils.log_process_rank(trainer_args)
     task = TrainingTask(training_peer_args, trainer_args, collab_args)

+ 1 - 2
task.py

@@ -121,8 +121,7 @@ class TrainingTask:
             self._collaborative_optimizer = hivemind.Optimizer(
                 dht=self.dht, run_id=self.peer_args.experiment_prefix,
                 params=params, optimizer=opt, scheduler=scheduler,
-                offload_optimizer=True,
-                delay_grad_averaging=False, delay_optimizer_step=True,
+                offload_optimizer=True, delay_grad_averaging=False, delay_optimizer_step=True,
                 batch_size_per_step=self.trainer_args.batch_size_per_step,
                 grad_compression=averaging_compression, state_averaging_compression=averaging_compression,
                 client_mode=self.peer_args.client_mode, verbose=True,