@@ -8,7 +8,8 @@ from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.threading import run_in_background

 FP16_MAX = 65_504
+FP32_EPS = 1e-06
 NUM_BYTES_FLOAT32 = 4
 NUM_BYTES_FLOAT16 = 2
 NUM_BITS_QUANTILE_COMPRESSION = 8
@@ -86,6 +86,7 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
         tensor.sub_(means)

         stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
+        stds.clamp_min_(FP32_EPS)
         tensor.div_(stds)
         tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)

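Why the clamp is needed (reviewer's note, not part of the patch): if a slice of the tensor is constant along the last axis, its standard deviation is exactly zero, so the in-place division turns that slice into NaNs (0 / 0) that survive both the clamp and the float16 cast. Clamping stds to FP32_EPS first makes zero-variance slices compress to zeros instead. Below is a minimal sketch of both behaviors; the meanstd_compress helper is hypothetical and only mirrors the lines in the hunk above:

    import torch

    FP16_MAX = 65_504
    FP32_EPS = 1e-06

    def meanstd_compress(tensor: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
        # Mirror of the normalization above; eps=0.0 reproduces the pre-patch
        # behavior, eps=FP32_EPS the patched one.
        tensor = tensor.clone().detach()
        means = tensor.mean(dim=-1, keepdim=True)
        tensor.sub_(means)
        stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
        stds.clamp_min_(eps)
        tensor.div_(stds)
        return tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)

    constant_rows = torch.ones(2, 4)                  # zero variance along dim=-1
    print(meanstd_compress(constant_rows))            # pre-patch behavior: all NaN
    print(meanstd_compress(constant_rows, FP32_EPS))  # patched behavior: all zeros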