|
@@ -21,7 +21,7 @@ from hivemind.utils.crypto import RSAPrivateKey
|
|
|
@dataclass(frozen=True)
|
|
|
class TrainingArguments:
|
|
|
seed: int = 42
|
|
|
- prefix: str = "my_exp"
|
|
|
+ run_id: str = "my_exp"
|
|
|
|
|
|
num_peers: int = 8
|
|
|
num_clients: int = 3
|
|
@@ -77,7 +77,7 @@ def benchmark_optimizer(args: TrainingArguments):
|
|
|
assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
|
|
|
|
|
|
optimizer = Optimizer(
|
|
|
- run_id=args.prefix,
|
|
|
+ run_id=args.run_id,
|
|
|
target_batch_size=args.target_batch_size,
|
|
|
batch_size_per_step=batch_size,
|
|
|
params=model.parameters(),
|