|
@@ -210,5 +210,7 @@ def test_adaptive_compression():
|
|
assert FLOAT32.mp_part_size.value == 1250 # four-byte tensors
|
|
assert FLOAT32.mp_part_size.value == 1250 # four-byte tensors
|
|
|
|
|
|
averager1.load_state_from_peers()
|
|
averager1.load_state_from_peers()
|
|
- assert STATE_FP16.mp_counter.value == STATE_FP32.mp_counter.value == 9
|
|
|
|
|
|
+ state_metadata, state_tensors, infos = averager1.get_current_state()
|
|
|
|
+ assert STATE_FP16.mp_counter.value == len([tensor for tensor in state_tensors if tensor.numel() >= 500])
|
|
|
|
+ assert STATE_FP32.mp_counter.value == len([tensor for tensor in state_tensors if tensor.numel() < 500])
|
|
assert STATE_FP16.mp_part_size.value == STATE_FP32.mp_part_size.value == 0 # not partitioned
|
|
assert STATE_FP16.mp_part_size.value == STATE_FP32.mp_part_size.value == 0 # not partitioned
|