@@ -1,12 +1,15 @@
 from dataclasses import dataclass
 
+from pydantic import BaseModel, StrictFloat, conint
+
 import time
 from threading import Thread, Lock, Event
-from typing import Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple, Dict
 
 import torch
 
 from hivemind.dht import DHT
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.client import TrainingAverager
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
@@ -19,6 +22,14 @@ class TrainingState:
     max_epoch: int = 0
     total_steps: int = 0
 
+class WorkerProgress(BaseModel):
+    time: StrictFloat
+    epoch: conint(ge=0, strict=True)
+    step: conint(ge=0, strict=True)
+
+class WorkerProgressSchema(BaseModel):
+    progress: Dict[BytesWithPublicKey, Optional[WorkerProgress]]
+
 class DecentralizedOptimizer(DecentralizedOptimizerBase):
     """
     A simple optimizer that trains a shared model by averaging with peers in variety of ways. Supports
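The two pydantic models added above make each worker's progress record self-validating: StrictFloat and conint(ge=0, strict=True) reject malformed or silently coerced values before they can reach the scheduler state. A minimal standalone sketch (not part of the patch) of how the model behaves, assuming pydantic v1 semantics as used here:

# Standalone sketch: behaviour of the WorkerProgress model from this patch under pydantic v1.
from pydantic import BaseModel, StrictFloat, ValidationError, conint

class WorkerProgress(BaseModel):
    time: StrictFloat
    epoch: conint(ge=0, strict=True)
    step: conint(ge=0, strict=True)

ok = WorkerProgress(time=1700000000.0, epoch=3, step=120)
print(ok.dict())  # {'time': 1700000000.0, 'epoch': 3, 'step': 120} -- the payload written to the DHT

for bad in ({"time": 1.0, "epoch": -1, "step": 0},    # negative epoch violates ge=0
            {"time": 1.0, "epoch": "3", "step": 0}):  # str instead of int is rejected by strict=True
    try:
        WorkerProgress.parse_obj(bad)
    except ValidationError:
        print("rejected:", bad)

parse_obj is the same call the fetch path later in this diff uses to revalidate records read back from the DHT.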
@@ -75,6 +86,9 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.report_progress_event, self.fetch_training_state_event = Event(), Event()
         self.lock_scheduler_params = Lock()
         self.training_state = TrainingState(max_epoch=0, total_steps=0)
+
+        self.dht.add_validators([SchemaValidator(WorkerProgressSchema, prefix=self.report_progress_key)])
+
         self._fetch_training_state()
         self._sync_if_needed()
 
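Registering the SchemaValidator on the DHT (the add_validators call above) means that records stored under report_progress_key are checked against WorkerProgressSchema, so a peer publishing an entry with the wrong shape or types should be rejected at the DHT layer rather than surfacing later in _fetch_training_state; the BytesWithPublicKey subkey type additionally ties each per-peer subkey to the writer's public key when hivemind's signature validation is enabled.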
@@ -184,8 +198,8 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                 break
             current_time = get_dht_time()
             with self.lock_scheduler_params:
-                local_state_info = [self.local_step, current_time, self.local_epoch]
-            self.dht.store(key=self.report_progress_key, subkey=self.averager.endpoint, value=local_state_info,
+                local_worker_progress = WorkerProgress(step=self.local_step, time=current_time, epoch=self.local_epoch)
+            self.dht.store(key=self.report_progress_key, subkey=self.averager.endpoint, value=local_worker_progress.dict(),
                            expiration_time=current_time + self.report_progress_expiration, return_future=False)
 
     @torch.no_grad()
@@ -208,16 +222,17 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             self.training_state = TrainingState(max_epoch=self.local_epoch, total_steps=self.local_step)
             return
 
-        valid_peer_states = [peer_state.value for peer_state in response.values() if isinstance(peer_state, ValueWithExpiration)]
+        valid_peer_states = [WorkerProgress.parse_obj(peer_state.value) for peer_state in response.values()
+                             if isinstance(peer_state, ValueWithExpiration) and peer_state.value is not None]
         num_peers = len(valid_peer_states)
         with self.lock_scheduler_params:
             global_epoch = self.local_epoch
-            for step, time, epoch in valid_peer_states:
-                global_epoch = max(global_epoch, epoch)
+            for state in valid_peer_states:
+                global_epoch = max(global_epoch, state.epoch)
             total_steps = 0
-            for step, time, epoch in valid_peer_states:
-                if epoch == global_epoch:
-                    total_steps += step
+            for state in valid_peer_states:
+                if state.epoch == global_epoch:
+                    total_steps += state.step
             self.training_state = TrainingState(max_epoch=global_epoch, total_steps=total_steps)
 
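The rewritten loops above amount to: take the highest epoch reported by any peer (or by this worker), then count only the steps of peers that are already at that epoch. A standalone sketch of that arithmetic (aggregate_progress is a hypothetical helper, not part of the patch; it reuses the WorkerProgress model redeclared in the earlier sketch):

# Standalone sketch: the aggregation performed by the rewritten loops above.
# aggregate_progress is a hypothetical helper, shown only to illustrate the arithmetic.
from typing import Sequence, Tuple

def aggregate_progress(states: Sequence[WorkerProgress], local_epoch: int = 0) -> Tuple[int, int]:
    global_epoch = local_epoch
    for state in states:  # the global epoch is the newest epoch anyone reported
        global_epoch = max(global_epoch, state.epoch)
    total_steps = 0
    for state in states:  # count steps only from peers already at that epoch
        if state.epoch == global_epoch:
            total_steps += state.step
    return global_epoch, total_steps

peers = [WorkerProgress(time=1.0, epoch=2, step=40),
         WorkerProgress(time=2.0, epoch=3, step=10),
         WorkerProgress(time=3.0, epoch=3, step=25)]
print(aggregate_progress(peers))  # (3, 35): the 40 steps from stale epoch 2 are ignored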