math.py

import torch
import torch.nn.functional as F


@torch.jit.script
def orthogonalize_(matrix, eps: float = 1e-8):
    """Orthogonalize a 2d tensor in-place over the last dimension"""
    n, m = matrix.shape
    for i in range(m):
        # Classical Gram-Schmidt: normalize the i-th column in place, then
        # subtract its projection from every remaining column.
        col = matrix[:, i]
        F.normalize(col, dim=0, eps=eps, out=col)
        if i + 1 < m:
            rest = matrix[:, i + 1 :]
            rest.addmm_(col[:, None], (col @ rest)[None, :], alpha=-1)


def get_flatten_greedy_dims(tensor: torch.Tensor, max_ndim: int = 2):
    """Get dims to flatten tensor up to max_ndim dimensions by merging small axes together"""
    dims = list(tensor.shape)
    while len(dims) > max_ndim:
        # Merge the adjacent pair of axes with the smallest combined size.
        squeeze_ix = min(range(len(dims) - 1), key=lambda i: dims[i] * dims[i + 1])
        squeezed_dim = dims.pop(squeeze_ix)
        dims[squeeze_ix] *= squeezed_dim
    return dims
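
A minimal usage sketch for orthogonalize_ (not part of math.py; the shape 8x4 is chosen only for illustration): after the call, the columns of the matrix are orthonormal up to floating-point error.

matrix = torch.randn(8, 4)
orthogonalize_(matrix)
# Columns are now orthonormal: matrix.T @ matrix is close to the 4x4 identity.
assert torch.allclose(matrix.T @ matrix, torch.eye(4), atol=1e-5)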
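
Likewise, a hypothetical sketch of get_flatten_greedy_dims (example shape chosen arbitrarily): it only computes a target shape and does not reshape anything itself. Because it always merges the adjacent pair of axes with the smallest product, large axes are kept separate for as long as possible.

grad = torch.randn(64, 3, 7, 7)       # example 4d tensor
dims = get_flatten_greedy_dims(grad)  # [64, 147]: merges (3, 7) first, then (21, 7)
matrix = grad.reshape(dims)           # 2d view that orthogonalize_ can modify in-place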