__init__.py 786 B

12345678910111213141516171819202122232425
  1. from socket import socket
  2. import torch
  3. def print_device_info(device=None):
  4. # prints device stats. Code from https://stackoverflow.com/a/53374933/12891528
  5. device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
  6. print('Using device:', device)
  7. # Additional Info when using cuda
  8. if device.type == 'cuda':
  9. print(torch.cuda.get_device_name(0))
  10. print('Memory Usage:')
  11. print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
  12. print('Cached: ', round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 'GB')
  13. def find_open_port():
  14. try:
  15. sock = socket()
  16. sock.bind(('', 0))
  17. return sock.getsockname()[1]
  18. except:
  19. raise ValueError("Could not find open port")