|
@@ -168,12 +168,17 @@ def test_load_state_from_peers(dpu: bool):
|
|
|
)
|
|
|
|
|
|
avgr1 = TrainingStateAverager(
|
|
|
- dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, extra_tensors=extras1,
|
|
|
- **common_kwargs
|
|
|
+ dht=dht1,
|
|
|
+ params=model1.parameters(),
|
|
|
+ allow_state_sharing=False,
|
|
|
+ start=True,
|
|
|
+ extra_tensors=extras1,
|
|
|
+ **common_kwargs,
|
|
|
)
|
|
|
|
|
|
- avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, extra_tensors=extras2,
|
|
|
- **common_kwargs)
|
|
|
+ avgr2 = TrainingStateAverager(
|
|
|
+ dht=dht2, params=model2.parameters(), start=True, extra_tensors=extras2, **common_kwargs
|
|
|
+ )
|
|
|
|
|
|
avgr2.local_epoch = 1337
|
|
|
model2.weight.data[...] = 42
|