inference_chain.py 909 B

1234567891011121314151617181920212223242526272829
  1. from collections import defaultdict
  2. from typing import Sequence
  3. import torch
  4. from hivemind import DHT
  5. from torch import nn
  6. from src import DistributedBloomConfig
  7. from src.server.backend import MAX_LENGTH
  8. class RemoteInferenceChain(nn.Module):
  9. """An auxiliary class that manages distributed inference in a chain of one or more remote transformer modules"""
  10. def __init__(self, dht: DHT, config: DistributedBloomConfig, block_names: Sequence[str]):
  11. super().__init__()
  12. self.dht = dht
  13. self.config, self.block_names = config, block_names
  14. self.block_caches = {name: torch.zeros(1, MAX_LENGTH, config.hidden_size) for name in block_names}
  15. self.current_position = 0
  16. def step(self, hidden_states: torch.Tensor):
  17. pass
  18. # plan:
  19. # - run inference STUB from a jupyter notebook
  20. # - extend to run actual inference
  21. # - extend to run multiple layers at a time