1234567891011121314151617181920212223242526272829 |
- from collections import defaultdict
- from typing import Sequence
- import torch
- from hivemind import DHT
- from torch import nn
- from src import DistributedBloomConfig
- from src.server.backend import MAX_LENGTH
- class RemoteInferenceChain(nn.Module):
- """An auxiliary class that manages distributed inference in a chain of one or more remote transformer modules"""
- def __init__(self, dht: DHT, config: DistributedBloomConfig, block_names: Sequence[str]):
- super().__init__()
- self.dht = dht
- self.config, self.block_names = config, block_names
- self.block_caches = {name: torch.zeros(1, MAX_LENGTH, config.hidden_size) for name in block_names}
- self.current_position = 0
- def step(self, hidden_states: torch.Tensor):
- pass
- # plan:
- # - run inference STUB from a jupyter notebook
- # - extend to run actual inference
- # - extend to run multiple layers at a time
|