justheuristic 3 years ago
parent
revision
477da687f9

+ 13 - 90
README.md

@@ -1,99 +1,22 @@
-# bloom-demo
-Early dev prototype for decentralized bloom. Not for public eyes **yet**.
-
-Roadmap: [issue #12](https://github.com/learning-at-home/bloom-demo/issues/12)
-
-Latest news @ main branch (max 5):
-- [Jul 1] @yozh added RemoteSequential and test for full model exact match
-- [June 28] @dbaranchuk added quick deployment scripts for testnet
-
-### install
+[
+Based on assorted code by shuf(mryab@ younesbelkada@ borzunov@ timdettmers@ dbaranchuk@ greenfatguy@ artek0chumak@)
+]
 
 
+# Install
 ```bash
-conda create -y --name bloom-demo python=3.8.12 pip
-conda activate bloom-demo
 
-conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-pip install accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
-pip install bitsandbytes-cuda113==0.26.0
-pip install https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
-```
+git clone https://github.com/CompVis/latent-diffusion.git
+git clone https://github.com/CompVis/taming-transformers
+pip install -e ./taming-transformers
+pip install "omegaconf>=2.0.0" "pytorch-lightning>=1.0.8" torch-fidelity einops
+mkdir -p models/ldm/cin256-v2/
+wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt 
 
-
-### run local inference:
-No networking whatsoever, used to verify architecture optimizations
-
-```bash
-# run one bloom block for a few steps -- on a local machine
-python -m cli.inference_one_block --config cli/config.json  # see other args
 ```
 
-### run distributed inference / training
 
-First, run one or more servers like this:
-```bash
-# minimalistic server with non-trained bloom blocks
-python -m cli.run_server --prefix bloom6b3 --converted_model_name_or_path bigscience/test-bloomd-6b3 \
-  --block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
-# when running multiple servers:
-# - give each server a unique --identity_path (or remove the --identity_path arg when debugging)
-# - if running multiple servers on the same machine, give each a unique port (last integer in --host_maddrs, 0 means random port)
-# - when running over the internet, change --host_maddrs according to https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet
-# - each server except first should have --initial_peers pointing to one of pre-existing servers 
-```
-
-Then open a python notebook or console and run:
 ```python
-import torch
-import hivemind
-from src import get_remote_module
-
-
-dht = hivemind.DHT(
-    initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS],  # e.g. /ip4/127.0.0.1/...
-    client_mode=True, start=True,
-)
-
-layer3, layer4 = get_remote_module(dht, ['bloom6b3.3', 'bloom6b3.4'])
-assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT"
-# test forward/backward, two blocks
-outputs, = layer4(*layer3(torch.randn(1, 64, 4096)))
-loss = (outputs * torch.randn_like(outputs)).norm()
-loss.backward()
-
-# test inference, one block
-with layer3.begin_inference_session() as sess:
-    for i in range(10):
-        res = sess.step(torch.ones(1, 1, 4096))
-```
-
-
-### convert regular bloom to distributed
-```bash
-
-# convert model from HF hub to a distributed format (can take hours depending on your connection!)
-MY_WRITE_TOKEN=TODO_WRITE_TOKEN_FROM_https://huggingface.co/settings/token
-python -m cli.convert_model --model bigscience/bloom-6b3  \
-  --output_path ./converted_model --output_repo bigscience/test-bloomd-6b3 \
-  --use_auth_token $MY_WRITE_TOKEN  # ^-- todo replace output repo with something you have access to
-```
-
-
-### test local vs remote block (allclose)
-
-To test distributed inference, run one or more servers, then open a new shell and run pytest with environment variables:
-```bash
-# shell A: serve blocks 3 and 4
-python -m cli.run_server --prefix bloom6b3 --converted_model_name_or_path bigscience/test-bloomd-6b3 \
-  --block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
-
-# shell B: connect to the swarm and test individual blocks for exact match
-export PYTHONPATH=. INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
-BLOCK_UID=bloom6b3.3 pytest tests/test_block_exact_match.py
-BLOCK_UID=bloom6b3.4 pytest tests/test_block_exact_match.py
-
-# the test below will fail because there is no server that serves layer 7
-# BLOCK_UID=bloom6b3.7 pytest tests/test_block_exact_match.py
-```
+hivemind-server --custom_module_path ./your_code_here.py --expert_cls ExampleModule --hidden_dim 512 --num_experts 1 \
+    --expert_pattern "expert.0.[0:9999]" --identity server1.id
+```
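
For reference, the file passed as `--custom_module_path` (here `./your_code_here.py`) is expected to define and register the expert class named by `--expert_cls`. Below is a minimal sketch, assuming hivemind's `register_expert_class` decorator (the "special decorator" mentioned in `run_server.py`'s `--custom_module_path` help); the import path, the `sample_input` signature, and all names here are illustrative and may differ between hivemind versions.

```python
# your_code_here.py -- hypothetical contents for --custom_module_path (names are placeholders)
import torch
import torch.nn as nn
from hivemind.moe.server.layers.custom_experts import register_expert_class  # assumed import path

# factory for a dummy input of shape [batch_size, hidden_dim], used to infer the expert's schema
sample_input = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim))


@register_expert_class("ExampleModule", sample_input)  # registered name must match --expert_cls
class ExampleModule(nn.Module):
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # simple residual feed-forward block over [batch_size, hidden_dim] inputs
        return x + self.ffn(x)
```

With such a module registered, the `hivemind-server` command above should announce one expert under the `expert.0.*` pattern, which the client code added in this commit can then discover.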

+ 0 - 0
cli/__init__.py


+ 0 - 20
cli/config.json

@@ -1,20 +0,0 @@
-{
-  "apply_residual_connection_post_layernorm": false,
-  "attention_dropout": 0.0,
-  "attention_softmax_in_fp32": true,
-  "bos_token_id": 1,
-  "eos_token_id": 2,
-  "hidden_dropout": 0.0,
-  "initializer_range": 0.02,
-  "layer_norm_epsilon": 1e-05,
-  "masked_softmax_fusion": true,
-  "model_type": "bloom",
-  "n_embed": 14336,
-  "n_layer": 70,
-  "num_attention_heads": 112,
-  "pretraining_tp": 4,
-  "slow_but_exact": false,
-  "transformers_version": "4.20.0.dev0",
-  "use_cache": true,
-  "vocab_size": 250880
-}

+ 0 - 87
cli/convert_model.py

@@ -1,87 +0,0 @@
-import argparse
-import os
-
-import psutil
-import torch.backends.quantized
-import torch.nn as nn
-import transformers
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from huggingface_hub import Repository
-from tqdm.auto import tqdm
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
-
-    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
-    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
-    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
-    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
-    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
-    parser.add_argument("--base_branch", type=str, default="main", help="Use this branch as reference point")
-    parser.add_argument("--client_branch", type=str, default="client", help="Save client version to this branch")
-    parser.add_argument(
-        "--block_branch_prefix", type=str, default="block_", help="Save blocks to branches with this prefix"
-    )
-    parser.add_argument(
-        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
-    )
-    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
-    args = parser.parse_args()
-
-    free_ram_gb = psutil.virtual_memory().available / 2**30
-    if args.model == "bigscience/bloom" and free_ram_gb < 400:
-        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
-
-    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
-    if os.path.exists(args.output_path) and (
-        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
-    ):
-        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
-
-    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
-    config = transformers.AutoConfig.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision
-    )
-    model = transformers.AutoModel.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
-    )
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision
-    )
-    os.makedirs(args.output_path, exist_ok=True)
-
-    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
-    repo.git_pull()
-
-    transformer_blocks = model.h
-    logger.info(
-        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
-        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
-    )
-    for i, block in enumerate(tqdm(transformer_blocks)):
-        repo.git_checkout(args.base_branch, create_branch_ok=True)
-        with repo.commit(
-            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
-        ):
-            torch.save(block.state_dict(), "./pytorch_model.bin")
-
-    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
-    repo.git_checkout(args.base_branch, create_branch_ok=True)
-    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
-        model.h = nn.ModuleList()
-        model.save_pretrained(".")
-
-    logger.info(f"Saving config and tokenizer to {args.output_repo}@{args.base_branch}")
-
-    repo.git_checkout(args.base_branch, create_branch_ok=True)
-    with repo.commit(commit_message=args.commit_message, branch=args.base_branch, track_large_files=True):
-        tokenizer.save_pretrained(".")
-        config.save_pretrained(".")
-
-    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")

+ 0 - 87
cli/deploy_server.sh

@@ -1,87 +0,0 @@
-#!/usr/bin/env bash
-
-#################
-# Parse options #
-#################
-
-instructions() {
-  echo "Usage: $0 [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
-  echo " -i: initial peer"
-  echo " -d: device" >&2
-  echo " -p: server identity path" >&2
-  echo " -b: block_ids" >&2
-  echo " -a: host maddrs" >&2
-  echo " -t: whether to run local tests" >&2
-  exit 1
-}
-
-if [ ! $# -ge 8 ]; then
-    instructions
-fi
-
-while getopts ":i:d:p:b:a:t:" option; do
-    case $option in
-        i)  INITIAL_PEER=${OPTARG}
-            ;;
-        d)  DEVICE=${OPTARG}
-            ;;
-        p)  SERVER_ID_PATH=${OPTARG}
-            ;;
-        b)  BLOCK_IDS=${OPTARG}
-            ;;
-        a)  HOST_MADDR=${OPTARG} # TODO: allow several maddrs 
-            ;;
-        t)  RUN_LOCAL_TESTS=true
-            ;;
-        \?) instructions
-            ;;
-   esac
-done
-
-
-echo "=========="
-echo "= Config ="
-echo "=========="
-echo "Initial peer: ${INITIAL_PEER}"
-echo "Device: ${DEVICE}"
-echo "Server name: ${SERVER_ID_PATH}"
-echo "Server address: ${HOST_MADDR}"
-echo "Bloom blocks: ${BLOCK_IDS}"
-
-
-###########################
-# Install or activate env #
-###########################
-
-# TODO fix bug with self calling
-source ~/miniconda3/etc/profile.d/conda.sh
-if conda env list | grep ".*bloom-demo.*"  >/dev/null 2>/dev/null; then
-    conda activate bloom-demo
-else
-    conda create -y --name bloom-demo python=3.8.12 pip
-    conda activate bloom-demo
-
-    conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
-    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
-    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
-fi
-
-
-##############
-# Local test #
-##############
-
-if [ "$RUN_LOCAL_TESTS" = true ] ; then
-    echo "Run test on your local machine"
-    python -m cli.inference_one_block --config cli/config.json --device ${DEVICE} # see other args
-fi
-
-
-##############
-# Run server #
-##############
-
-python -m cli.run_server --prefix bloom6b3 --converted_model_name_or_path bigscience/test-bloomd-6b3 --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
-  --block_indices ${BLOCK_IDS} --torch_dtype float32 --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} &> ${SERVER_ID_PATH}.log

+ 0 - 53
cli/inference_one_block.py

@@ -1,53 +0,0 @@
-import argparse
-
-import torch
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from tqdm.auto import trange
-
-from src.bloom.block import BloomBlock
-from src.bloom.model import DistributedBloomConfig
-from src.bloom.ops import build_alibi_tensor
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-logger.warning("inference_one_block will soon be deprecated in favour of tests!")
-
-
-def print_device_info(device=None):
-    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
-    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
-    logger.info(f"Using device: {device}")
-
-    # Additional Info when using cuda
-    if device.type == "cuda":
-        logger.info(torch.cuda.get_device_name(0))
-        logger.info(f"Memory Usage:")
-        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
-        logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
-    parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
-    parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
-    parser.add_argument("--layer_index", default=0, type=int, help="Optional path to saved block state dict")
-    parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
-    parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
-    args = parser.parse_args()
-
-    if args.device is None:
-        args.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    config = DistributedBloomConfig.from_json_file(args.config)
-    block = BloomBlock(config, args.layer_index).to(args.device)
-
-    cache = None
-
-    for i in trange(args.num_steps):
-        dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
-        alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
-        with torch.no_grad():
-            outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
-
-    print_device_info(args.device)

+ 0 - 5
cli/local_server_config_example.cfg

@@ -1,5 +0,0 @@
-device=cpu
-block_ids=2:3
-id_path=./server.id
-maddr=/ip4/127.0.0.1/tcp/30000
-#

+ 0 - 6
cli/remote_server_config_example.cfg

@@ -1,6 +0,0 @@
-name=bloom-peer-0.bloom.net
-device=cpu
-block_ids=1:3
-id_path=./server.id
-maddr=/ip4/0.0.0.0/tcp/30000
-#

+ 0 - 111
cli/run_local_servers.sh

@@ -1,111 +0,0 @@
-#!/usr/bin/env bash
-
-#################
-# Parse options #
-#################
-
-instructions() {
-  echo "Usage: $0 [-n] [-c]" >&2
-  echo " -n: number of servers to run" >&2
-  echo " -c: path to the server configs" >&2
-  exit 1
-}
-
-if [ $# != 4 ]; then
-    instructions
-fi
-
-while getopts ":n:c:t:" option; do
-    case $option in
-        n)  NUM_SERVERS=${OPTARG}
-            ;;
-        c)  CONFIG_PATH=${OPTARG}
-            ;;
-        \?) instructions
-            ;;
-   esac
-done
-
-
-###########################
-# Install or activate env #
-###########################
-
-source ~/miniconda3/etc/profile.d/conda.sh
-if conda env list | grep ".*bloom-demo.*"  &>/dev/null; then
-    conda activate bloom-demo
-else
-    conda create -y --name bloom-demo python=3.8.12 pip
-    conda activate bloom-demo
-
-    conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
-    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
-    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
-fi
-
-
-#######################
-# Create Initial peer #
-#######################
-
-hivemind-dht &> tmp.out &
-sleep 3
-INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
-echo "Initial peer: ${INITIAL_PEER}"
-
-
-##############################
-# Initialize the config file #
-##############################
-
-typeset -A cfg 
-cfg=( # set default values in config array
-    [device]="cpu"
-    [block_ids]="1:2"
-    [id_path]="server.id"
-    [maddr]="/ip4/127.0.0.1/tcp/30000"
-)
-
-###############
-# Run servers #
-###############
-
-for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
-do  
-    ###############
-    # Read config #
-    ###############
-
-    while read line
-    do
-        if echo $line | grep -F = &>/dev/null
-        then
-            varname=$(echo "$line" | cut -d '=' -f 1)
-            cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
-        fi
-    done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
-    
-    echo "=== Server #${SERVER_ID} ==="
-    echo "Server ID: ${id_path}"
-    echo "Device: ${cfg[device]}"
-    echo "Bloom block ids: ${cfg[block_ids]}"
-    echo "Host maddr: ${cfg[maddr]}"
-    echo ""
-    
-    ##############
-    # Run server #
-    ##############
-
-    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
-done
-
-
-#####################
-# Kill initial peer #
-#####################
-
-sleep 10
-pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht
-rm tmp.out

+ 0 - 112
cli/run_remote_servers.sh

@@ -1,112 +0,0 @@
-#!/usr/bin/env bash
-
-SSH_KEY_PATH="~/.ssh/<YOUR_KEY>"
-
-#################
-# Parse options #
-#################
-
-instructions() {
-  echo "Usage: $0 [-u] [-n] [-c]" >&2
-  echo " -u: username" >&2
-  echo " -n: number of servers to run" >&2
-  echo " -c: path to the server configs" >&2
-  exit 1
-}
-
-if [ $# != 6 ]; then
-    instructions
-fi
-
-while getopts ":u:n:c:" option; do
-    case $option in
-        u)  USERNAME=${OPTARG}
-            ;;
-        n)  NUM_SERVERS=${OPTARG}
-            ;;
-        c)  CONFIG_PATH=${OPTARG}
-            ;;
-        \?) instructions
-            ;;
-   esac
-done
-
-
-###########################
-# Install or activate env #
-###########################
-
-source ~/miniconda3/etc/profile.d/conda.sh
-if conda env list | grep ".*bloom-demo.*"  &>/dev/null; then
-    conda activate bloom-demo
-else
-    conda create -y --name bloom-demo python=3.8.12 pip
-    conda activate bloom-demo
-
-    conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
-    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
-    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
-fi
-
-
-#######################
-# Create Initial peer #
-#######################
-
-hivemind-dht &> tmp.out &
-
-sleep 3
-INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
-rm tmp.out
-echo "Initial peer: ${INITIAL_PEER}"
-
-
-##############################
-# Initialize the config file #
-##############################
-
-typeset -A cfg 
-cfg=( # set default values in config array
-    [name]=""
-    [device]="cpu"
-    [block_ids]="1:2"
-    [id_path]="server.id"
-    [maddr]="/ip4/0.0.0.0/tcp/30000"
-)
-
-###############
-# Run servers #
-###############
-
-for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
-do  
-    ###############
-    # Read config #
-    ###############
-
-    while read line
-    do
-        if echo $line | grep -F = &>/dev/null
-        then
-            varname=$(echo "$line" | cut -d '=' -f 1)
-            cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
-        fi
-    done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
-    
-    SERVER_NAME="${USERNAME}@${cfg[name]}"
-    echo "=== Server #${SERVER_ID} ==="
-    echo "Server name ${SERVER_NAME}"
-    echo "Server ID: ${cfg[id_path]}"
-    echo "Device: ${cfg[device]}"
-    echo "Bloom block ids: ${cfg[block_ids]}"
-    echo "Host maddr: ${cfg[maddr]}"
-    echo "================="
-    
-    ##############
-    # Run server #
-    ##############
-     
-    ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
-done

+ 0 - 85
cli/run_server.py

@@ -1,85 +0,0 @@
-import configargparse
-from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-
-from src.server.server import Server
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-def main():
-    # fmt:off
-    parser = configargparse.ArgParser(default_config_files=["config.yml"])
-    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
-
-    parser.add_argument('--converted_model_name_or_path', type=str, default='bigscience/test-bloomd-6b3',
-                        help="path or name of a pretrained model, converted with cli/convert_model.py (see README.md)")
-    parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
-    parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
-    parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default, "
-                                                                 "use the same name as in the converted model.")
-    parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
-                        help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
-    parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
-                        help='Visible multiaddrs the host announces for external connections from other p2p instances')
-
-    parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
-
-    parser.add_argument('--num_handlers', type=int, default=None, required=False,
-                        help='server will use this many processes to handle incoming requests')
-    parser.add_argument('--min_batch_size', type=int, default=1,
-                        help='Minimum required batch size for all expert operations')
-    parser.add_argument('--max_batch_size', type=int, default=16384,
-                        help='The total number of examples in the same batch will not exceed this value')
-    parser.add_argument('--cache_size_bytes', type=int, default=None,
-                        help='The size of memory cache for storing past attention keys/values between inference steps')
-    parser.add_argument('--device', type=str, default=None, required=False,
-                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
-    parser.add_argument("--torch_dtype", type=str, default="auto",
-                        help="Use this dtype to store block weights and do computations. "
-                             "By default, respect the dtypes in the pre-trained state dict.")
-
-    parser.add_argument('--update_period', type=float, required=False, default=30,
-                        help='Server will report experts to DHT once in this many seconds')
-    parser.add_argument('--expiration', type=float, required=False, default=None,
-                        help='DHT entries will expire after this many seconds')
-    parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
-                        help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
-    parser.add_argument('--increase_file_limit', action='store_true',
-                        help='On *nix, this will increase the max number of processes '
-                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
-    parser.add_argument('--stats_report_interval', type=int, required=False,
-                        help='Interval between two reports of batch processing performance statistics')
-
-    parser.add_argument('--custom_module_path', type=str, required=False,
-                        help='Path of a file with custom nn.modules, wrapped into special decorator')
-    parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
-    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
-
-    # fmt:on
-    args = vars(parser.parse_args())
-    args.pop("config", None)
-
-    if args.pop("increase_file_limit"):
-        increase_file_limit()
-
-    compression_type = args.pop("compression")
-    compression = getattr(CompressionType, compression_type)
-
-    use_auth_token = args.pop("use_auth_token")
-    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
-
-    server = Server.create(**args, start=True, compression=compression)
-
-    try:
-        server.join()
-    except KeyboardInterrupt:
-        logger.info("Caught KeyboardInterrupt, shutting down")
-    finally:
-        server.shutdown()
-
-
-if __name__ == "__main__":
-    main()

+ 130 - 0
load_balancer.py

@@ -0,0 +1,130 @@
+import heapq
+import random
+import threading
+from contextlib import contextmanager
+from typing import Dict, List, Tuple
+
+from hivemind import RemoteExpert, TimedStorage, PeerID
+from hivemind.dht import DHT
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertPrefix, ExpertUID, ExpertInfo
+from hivemind.utils.performance_ema import PerformanceEMA
+from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class LoadBalancer:
+    def __init__(self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
+                 **kwargs):
+        self.dht, self.key = dht, key
+        self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
+        self.experts = TimedStorage[ExpertUID, PeerID]()
+        self.blacklist = TimedStorage[ExpertUID, type(None)]()
+        self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
+        self.queue: List[Tuple[float, float, ExpertUID]] = []
+        self.uid_to_queue: Dict[ExpertUID, Tuple[float, float, ExpertUID]] = {}
+        self.lock = threading.Lock()
+        self.is_alive = threading.Event()
+        self.is_alive.set()
+        self.update_trigger, self.update_finished = threading.Event(), threading.Event()
+        self.update_period, self.last_update = update_period, get_dht_time()
+        self.update_thread = threading.Thread(target=self.update_experts_in_background, daemon=True)
+        self.update_thread.start()
+        self._p2p = RemoteExpertWorker.run_coroutine(self.dht.replicate_p2p())
+
+    def update_experts_in_background(self):
+        while self.is_alive.is_set():
+            time_to_next_update = max(0.0, self.last_update + self.update_period - get_dht_time())
+            # Event.wait returns either when the main thread sets the trigger or when update_period elapses
+            self.update_trigger.wait(timeout=time_to_next_update)
+
+            self.update_trigger.clear()
+            response = self.dht.get(self.key, latest=True)
+            if isinstance(response, ValueWithExpiration) and isinstance(response.value, dict):
+                for index, expert_info in response.value.items():
+                    try:
+                        (expert_uid, peer_id), expiration_time = expert_info
+
+                        maybe_banned = self.blacklist.get(expert_uid)
+                        if maybe_banned is None or expiration_time > maybe_banned.expiration_time:
+                            self._add_expert(expert_uid, peer_id, expiration_time)
+                        else:
+                            logger.debug(f"Not adding expert {expert_uid} (blacklisted).")
+                    except Exception as e:
+                        logger.warning(f"Skipping malformed expert info {expert_info} (exc={e})")
+            else:
+                logger.warning(f"Could not refresh experts, dht info key contains {response}, "
+                               f"will retry in {time_to_next_update}s")
+            if len(self.queue) == 0:
+                logger.warning("Update routine finished, but still no experts available.")
+
+            self.last_update = get_dht_time()
+            self.update_finished.set()
+
+    def _add_expert(self, uid: ExpertUID, peer_id: PeerID, expiration_time: DHTExpiration):
+        with self.lock:
+            self.experts.store(uid, peer_id, expiration_time)
+            if uid not in self.uid_to_queue:
+                logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
+                self.throughputs[uid] = PerformanceEMA(**self.ema_kwargs, paused=True)
+                base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
+                heap_entry = (base_load, random.random(), uid)
+                heapq.heappush(self.queue, heap_entry)
+                self.uid_to_queue[uid] = heap_entry
+            else:
+                logger.debug(f"Refreshing existing module: {uid}, new expiration time = {expiration_time:.3f}.")
+
+    def _ban_expert(self, uid: ExpertUID):
+        with self.lock:
+            maybe_expert = self.experts.get(uid)
+            expiration_time = maybe_expert.expiration_time if maybe_expert else get_dht_time()
+            self.blacklist.store(uid, None, expiration_time)
+            self.uid_to_queue.pop(uid, None)
+            self.throughputs.pop(uid, None)
+            del self.experts[uid]
+            logger.debug(f"Banned expert {uid} with expiration time = {expiration_time:.2f}.")
+
+    @contextmanager
+    def use_another_expert(self, task_size: float) -> RemoteExpert:
+        while True:
+            if len(self.queue) == 0:
+                self.update_finished.clear()
+                self.update_trigger.set()
+                self.update_finished.wait()
+                continue
+
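+            # pop the expert with the lowest estimated cumulative load, add this task's expected
+            # duration to its estimate, and push it back so later calls are spread across experts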
+            with self.lock:
+                current_runtime, _, uid = heap_entry = heapq.heappop(self.queue)
+                maybe_peer_id = self.experts.get(uid)
+                if maybe_peer_id is None:
+                    # remove expired expert from queue
+                    self.uid_to_queue.pop(uid, None)
+                    self.throughputs.pop(uid, None)
+                if self.uid_to_queue.get(uid) != heap_entry:
+                    continue  # skip uids that are banned or expired
+
+                if self.throughputs[uid].num_updates != 0:
+                    expected_time_taken = task_size / self.throughputs[uid].samples_per_second
+                else:
+                    expected_time_taken = self.initial_throughput * task_size
+                new_heap_entry = (current_runtime + expected_time_taken, random.random(), uid)
+                heapq.heappush(self.queue, new_heap_entry)
+                self.uid_to_queue[uid] = new_heap_entry
+                break
+        try:
+            with self.throughputs[uid].update_threadsafe(task_size):
+                logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
+                yield RemoteExpert(ExpertInfo(uid, PeerID.from_base58(maybe_peer_id.value)), self._p2p)
+        except BaseException:
+            self._ban_expert(uid)
+            raise
+
+    def shutdown(self):
+        self.is_alive.clear()
+        self.update_finished.clear()
+        self.update_trigger.set()
+        self.update_finished.wait()

+ 150 - 0
remote.py

@@ -0,0 +1,150 @@
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from torch.autograd.function import once_differentiable
+
+import hivemind
+from load_balancer import LoadBalancer
+from hivemind.moe.client.expert import DUMMY, expert_forward
+from hivemind.proto import runtime_pb2
+from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack
+
+logger = get_logger(__name__)
+
+
+class BalancedRemoteExpert(nn.Module):
+    """
+    A torch module that dispatches each forward call to one RemoteExpert from a pool, balancing load in proportion to each expert's measured throughput.
+    ToDo docstring, similar to hivemind.RemoteExpert
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: hivemind.DHT,
+        uid_prefix: str,
+        grid_size: Tuple[int, ...],
+        forward_timeout: Optional[float] = None,
+        backward_timeout: Optional[float] = None,
+        update_period: float = 30.0,
+        backward_task_size_multiplier: float = 2.5,
+        **kwargs,
+    ):
+        super().__init__()
+        if uid_prefix.endswith(".0."):
+            logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}0.")
+        assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"
+        self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
+        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
+        self.backward_task_size_multiplier = backward_task_size_multiplier
+        self.expert_balancer = LoadBalancer(dht, key=f"{self.uid_prefix}0.", update_period=update_period, **kwargs)
+        self._expert_info = None  # expert['info'] from one of experts in the grid
+
+    def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
+        """
+        Call one of the RemoteExperts for the specified inputs and return output. Compatible with pytorch.autograd.
+
+        :param args: input tensors that will be passed to the chosen expert, batch-first
+        :param kwargs: extra keyword tensors that will be passed to the chosen expert, batch-first
+        :returns: output of the chosen expert, nested structure of batch-first
+        """
+        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
+        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
+
+        if self._expert_info is None:
+            raise NotImplementedError()
+        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
+
+        forward_inputs = (args, kwargs)
+
+        if not nested_compare(forward_inputs, self.info["forward_schema"]):
+            raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
+        flat_inputs = list(nested_flatten(forward_inputs))
+        forward_task_size = flat_inputs[0].shape[0]
+
+        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
+        flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
+                                                       self.expert_balancer,
+                                                       self.info,
+                                                       self.forward_timeout,
+                                                       self.backward_timeout,
+                                                       forward_task_size,
+                                                       forward_task_size * self.backward_task_size_multiplier,
+                                                       *flat_inputs)
+
+        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
+
+    @property
+    def info(self):
+        while self._expert_info is None:
+            try:
+                with self.expert_balancer.use_another_expert(1) as chosen_expert:
+                    self._expert_info = chosen_expert.info
+            except BaseException as e:
+                logger.error(f"Tried to get expert info from {chosen_expert} but caught {repr(e)}")
+        return self._expert_info
+
+
+class _BalancedRemoteModuleCall(torch.autograd.Function):
+    """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""
+
+    @staticmethod
+    def forward(
+            ctx,
+            dummy: torch.Tensor,
+            expert_balancer: LoadBalancer,
+            info: Dict[str, Any],
+            forward_timeout: float,
+            backward_timeout: float,
+            forward_task_size: float,
+            backward_task_size: float,
+            *inputs: torch.Tensor,
+            ) -> Tuple[torch.Tensor, ...]:
+        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+        # detach to avoid pickling the computation graph
+        ctx.expert_balancer, ctx.info = expert_balancer, info
+        ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
+        ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
+        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
+        ctx.save_for_backward(*inputs)
+
+        serialized_tensors = [
+            serialize_torch_tensor(inp, proto.compression)
+            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        ]
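+        # retry with different experts until one forward call succeeds; an expert that raises
+        # inside use_another_expert is banned by the balancer before the next attempt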
+        while True:
+            try:
+                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
+                    deserialized_outputs = RemoteExpertWorker.run_coroutine(expert_forward(
+                        chosen_expert.uid, inputs, serialized_tensors, chosen_expert.stub))
+                break
+            except BaseException as e:
+                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
+
+        return tuple(deserialized_outputs)
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
+        raise NotImplementedError("Backward is not yet implemented in this example")
+        # grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+        # inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
+        # backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
+        # serialized_tensors = [
+        #     serialize_torch_tensor(tensor, proto.compression)
+        #     for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        # ]
+        # while True:
+        #     try:
+        #         with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
+        #             backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
+        #             grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
+        #         break
+        #     except BaseException as e:
+        #         logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
+        # deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
+        # return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)

+ 0 - 5
src/__init__.py

@@ -1,5 +0,0 @@
-from .bloom import *
-from .client import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
-from .dht_utils import declare_active_modules, get_remote_module
-
-__version__ = "0.1"

+ 0 - 1
src/bloom/__init__.py

@@ -1 +0,0 @@
-from src.bloom.model import BloomBlock, BloomForYou, BloomModel, DistributedBloomConfig

+ 0 - 264
src/bloom/block.py

@@ -1,264 +0,0 @@
-"""
-Bloom intermediate layer
-Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
-See commit history for authorship.
-"""
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.quantized.dynamic.modules.linear
-
-from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
-                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
-
-
-class BloomAttention(nn.Module):
-    def __init__(self, config, layer_number=None):
-        super().__init__()
-
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.n_head
-        self.head_dim = self.hidden_size // self.num_heads
-        self.split_size = self.hidden_size
-        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
-        self.masked_softmax_fusion = config.masked_softmax_fusion
-        self.hidden_dropout = config.hidden_dropout
-
-        if self.head_dim * self.num_heads != self.hidden_size:
-            raise ValueError(
-                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
-                f" {self.num_heads})."
-            )
-
-        # Layer-wise attention scaling
-        self.layer_number = max(1, layer_number)
-        self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
-
-        # Scaled Softmax
-        self.scale_mask_softmax = BloomScaledSoftmax(
-            self.masked_softmax_fusion,
-            attention_mask_func,
-            self.attention_softmax_in_fp32,
-            self.layer_number,
-        )
-
-        if config.compression == "qint8":
-            self.query_key_value = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8
-            )
-            self.dense = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
-            )
-        else:
-            self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
-            self.dense = nn.Linear(self.hidden_size, self.hidden_size)
-
-        self.attention_dropout = nn.Dropout(config.attention_dropout)
-
-    def forward(
-        self,
-        hidden_states,
-        residual,
-        layer_past=None,
-        attention_mask=None,
-        alibi=None,
-        head_mask=None,
-        use_cache=False,
-        output_attentions=False,
-    ):
-        if alibi is None:
-            current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
-            alibi = build_alibi_tensor(
-                current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
-            )
-
-        # hidden_states: [batch_size, seq_length, hidden_size]
-        # apply preprocessing if the input is padded
-        if attention_mask is not None:
-            alibi = pre_process_alibi_for_pad(alibi, attention_mask)
-        # otherwise repeat alibi tensor with the batch size
-        else:
-            alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
-
-        mixed_x_layer = self.query_key_value(hidden_states)
-
-        # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
-        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
-        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
-
-        # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3  [batch_size, seq_length, num_heads, head_dim]
-        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
-            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
-
-        if use_cache is True:
-            present = (key_layer, value_layer)
-        else:
-            present = None
-
-        # [batch_size, head_dim, q_length, k_length]
-        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
-
-        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
-        query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
-
-        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
-        key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
-
-        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
-        beta = 1.0 / self.layer_number
-
-        matmul_result = torch.baddbmm(
-            alibi,
-            query_layer.transpose(1, 0),
-            key_layer.transpose(1, 0).transpose(1, 2),
-            beta=beta,
-            alpha=(1.0 / self.norm_factor),
-        )
-
-        # change view to [batch_size, num_heads, q_length, k_length]
-        attention_scores = matmul_result.view(*output_size)
-
-        # attention scores and attention mask [b, np, sq, sk]
-        max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
-        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
-        attention_probs = self.attention_dropout(attention_probs)
-
-        if head_mask is not None:
-            attention_probs = attention_probs * head_mask
-
-        # context layer shape: [batch_size, num_heads, q_length, head_dim]
-        output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
-
-        # change view [k_length, batch_size x num_heads, head_dim]
-        value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
-
-        # change view [batch_size x num_heads, q_length, k_length]
-        attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
-
-        # matmul: [batch_size * num_heads, q_length, head_dim]
-        context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
-
-        # change view [batch_size, num_heads, q_length, head_dim]
-        context_layer = context_layer.view(*output_size)
-
-        # [batch_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
-        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
-
-        # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
-
-        context_layer = context_layer.view(*new_context_layer_shape)
-
-        # Output. [q_length, batch_size, hidden_size]
-
-        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
-        output_tensor = self.dense(context_layer)
-        output = output_tensor.transpose(1, 0)
-
-        output = dropout_add(output, residual, self.hidden_dropout, self.training)
-
-        outputs = (output, present)
-        if output_attentions:
-            outputs += (attention_probs,)
-
-        return outputs
-
-
-class BloomMLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-        if config.compression == "qint8":
-            self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8
-            )
-            self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
-                4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
-            )
-        else:
-            self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
-            self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
-        self.hidden_dropout = config.hidden_dropout
-        self.gelu_impl = BloomGelu()
-
-    def forward(self, hidden_states, residual):
-        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
-        intermediate_output = self.dense_4h_to_h(hidden_states)
-        output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
-        return output
-
-
-class BloomBlock(nn.Module):
-    def __init__(self, config, layer_number=None):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-
-        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
-        self.n_head = config.n_head
-        self.self_attention = BloomAttention(config, layer_number=layer_number)
-        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
-
-        self.mlp = BloomMLP(config)
-
-        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
-        self.hidden_dropout = config.hidden_dropout
-
-    def forward(
-        self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        use_cache=False,
-        output_attentions=False,
-        alibi=None,
-    ):
-        # hidden_states: [batch_size, seq_length, hidden_size]
-
-        # Layer norm at the beginning of the transformer layer.
-        layernorm_output = self.input_layernorm(hidden_states)
-
-        # Layer norm post the self attention.
-        if self.apply_residual_connection_post_layernorm:
-            residual = layernorm_output
-        else:
-            residual = hidden_states
-
-        # Self attention.
-        attn_outputs = self.self_attention(
-            layernorm_output,
-            residual,
-            layer_past=layer_past,
-            attention_mask=attention_mask,
-            alibi=alibi,
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-        )
-
-        attention_output = attn_outputs[0]
-
-        outputs = attn_outputs[1:]
-
-        layernorm_output = self.post_attention_layernorm(attention_output)
-
-        # Get residual
-        if self.apply_residual_connection_post_layernorm:
-            residual = layernorm_output
-        else:
-            residual = attention_output
-
-        # MLP.
-        output = self.mlp(layernorm_output, residual)
-
-        if use_cache:
-            outputs = (output,) + outputs
-        else:
-            outputs = (output,) + outputs[1:]
-
-        return outputs  # hidden_states, present, attentions

+ 0 - 80
src/bloom/from_pretrained.py

@@ -1,80 +0,0 @@
-"""
-Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
-If necessary, one can rewrite this to implement a different behavior, such as:
- - loading files from a local data source (e.g. S3)
- - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
- - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
-
-"""
-from __future__ import annotations
-
-from typing import Optional, OrderedDict, Union
-
-import torch
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from transformers.modeling_utils import WEIGHTS_NAME
-from transformers.utils.hub import cached_path, hf_bucket_url
-
-from src.bloom import BloomBlock, DistributedBloomConfig
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-CLIENT_BRANCH = "client"
-BLOCK_BRANCH_PREFIX = "block_"
-USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
-FORCE_DOWNLOAD = False
-RESUME_DOWNLOAD = False
-LOCAL_FILES_ONLY = False
-
-
-def load_pretrained_block(
-    converted_model_name_or_path: str,
-    block_index: int,
-    config: Optional[DistributedBloomConfig] = None,
-    torch_dtype: Union[torch.dtype, str] = "auto",
-    use_auth_token: Optional[str] = None,
-) -> BloomBlock:
-    """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
-    if config is None:
-        config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
-    block = BloomBlock(config, layer_number=block_index)
-    state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
-    block.load_state_dict(state_dict)
-
-    if torch_dtype == "auto":
-        with torch.no_grad():
-            for name, param in block.named_parameters():
-                assert name in state_dict, f"{name} not in state dict"
-                param.data = param.data.to(state_dict[name].dtype)
-    else:
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        block = block.to(dtype=torch_dtype)
-
-    report = block.load_state_dict(state_dict, strict=True)
-    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
-    return block
-
-
-def _load_state_dict(
-    pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
-) -> OrderedDict[str, torch.Tensor]:
-    revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
-    archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
-
-    # Load from URL or cache if already cached
-    resolved_archive_file = cached_path(
-        archive_file,
-        cache_dir=None,
-        force_download=FORCE_DOWNLOAD,
-        proxies=None,
-        resume_download=RESUME_DOWNLOAD,
-        local_files_only=LOCAL_FILES_ONLY,
-        use_auth_token=use_auth_token,
-        user_agent=USER_AGENT,
-    )
-    state_dict = torch.load(resolved_archive_file, map_location="cpu")
-    return state_dict
-
-
-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

+ 0 - 328
src/bloom/model.py

@@ -1,328 +0,0 @@
-"""
-PyTorch BLOOM model that implements several memory-efficient modes.
-Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
-See commit history for authorship.
-"""
-
-import torch
-import torch.utils.checkpoint
-from hivemind import use_hivemind_log_handler
-from torch import nn
-from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
-                                     add_start_docstrings_to_model_forward)
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
-from transformers.modeling_utils import PreTrainedModel
-from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
-from transformers.utils import logging
-
-from src.bloom.block import BloomBlock
-from src.bloom.ops import build_alibi_tensor
-
-use_hivemind_log_handler("in_root_logger")
-logger = logging.get_logger(__file__)
-
-_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
-_CONFIG_FOR_DOC = "DistributedBloomConfig"
-_TOKENIZER_FOR_DOC = "BloomTokenizer"
-
-
-class DistributedBloomConfig(_VanillaBloomConfig):
-    compression: str = "none"
-    slow_but_exact: bool = False
-
-
-class BloomPreTrainedModel(PreTrainedModel):
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-
-    config_class = DistributedBloomConfig
-    base_model_prefix = "transformer"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["BloomBlock"]
-
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
-    def _init_weights(self, module):
-        """Initialize the weights."""
-        if isinstance(module, (nn.Linear)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, BloomModel):
-            module.gradient_checkpointing = value
-
-
-BLOOM_START_DOCSTRING = r"""
-
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`DistributedBloomConfig`]): Model configuration class with all the parameters of the model.
-            Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-BLOOM_INPUTS_DOCSTRING = r"""
-    Args:
-        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
-            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
-            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
-            sequence tokens in the vocabulary.
-
-            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
-            `input_ids`.
-
-            Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            [What are input IDs?](../glossary#input-ids)
-        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
-            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
-            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
-            their past given to this model should not be passed as `input_ids` as they have already been computed.
-        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-            config.max_position_embeddings - 1]`.
-
-            [What are position IDs?](../glossary#position-ids)
-        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-            model's internal embedding lookup matrix.
-
-            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
-            `past_key_values`).
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-            `past_key_values`).
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
-@add_start_docstrings(
-    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
-    BLOOM_START_DOCSTRING,
-)
-class BloomModel(BloomPreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config)
-        assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
-
-        self.embed_dim = config.hidden_size
-        self.n_head = config.n_head
-
-        # Embedding + LN Embedding
-
-        # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
-        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)  # dtype=config.torch_dtype
-        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
-
-        # Transformer blocks
-        self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
-
-        # Final Layer Norm
-        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
-
-        self.gradient_checkpointing = False
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-        # Forbid accumulate grads for embeddings and layernorm
-        self.set_requires_grad(False)
-
-    def get_input_embeddings(self):
-        return self.word_embeddings
-
-    def set_input_embeddings(self, new_embeddings):
-        self.word_embeddings = new_embeddings
-
-    def set_requires_grad(self, value):
-        for p in self.parameters():
-            p.requires_grad = value
-
-    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=BaseModelOutputWithPastAndCrossAttentions,
-        config_class=_CONFIG_FOR_DOC,
-    )
-    def forward(
-        self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        if position_ids is not None:
-            logger.warning("position_ids are ignored in this bloom implementation")
-        if input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
-        elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        if past_key_values is None:
-            past_key_values = tuple([None] * len(self.h))
-
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_head x N x N
-        # head_mask has shape n_layer x batch x n_head x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
-
-        output_shape = input_shape + (hidden_states.size(-1),)
-
-        presents = () if use_cache else None
-        all_self_attentions = () if output_attentions else None
-        all_hidden_states = () if output_hidden_states else None
-
-        # Compute alibi tensor: check build_alibi_tensor documentation
-        current_sequence_length = hidden_states.shape[1]
-        if past_key_values and past_key_values[0]:
-            current_sequence_length += past_key_values[0][0].shape[1]
-
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            if self.gradient_checkpointing and self.training:
-
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions, alibi=None)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    None,
-                    attention_mask,
-                    head_mask[i],
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    alibi=None,
-                )
-
-            hidden_states = outputs[0]
-            if use_cache is True:
-                presents = presents + (outputs[1],)
-
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
-
-        # Add last hidden state
-        hidden_states = self.ln_f(hidden_states)
-
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        hidden_states = hidden_states.view(output_shape)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
-
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=presents,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-        )
-
-
-@add_start_docstrings(
-    """
-    The Bloom interface for various applications, e.g., inference, classification...
-    """,
-    BLOOM_START_DOCSTRING,
-)
-class BloomForYou(BloomPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.transformer = BloomModel(config)
-        self.lm_head = None
-
-        # Initialize weights and apply final processing
-        self.post_init()

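A rough sketch of how the removed `BloomModel` is driven. The tiny config values are illustrative only, and the exact `DistributedBloomConfig` constructor kwargs follow whatever the pinned transformers `BloomConfig` accepts:

```python
import torch

from src.bloom.model import BloomModel, DistributedBloomConfig

# deliberately tiny, illustrative config (real BLOOM configs are much larger)
config = DistributedBloomConfig(vocab_size=128, hidden_size=64, n_head=4, n_layer=2)
model = BloomModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
outputs = model(input_ids=input_ids, return_dict=True)
print(outputs.last_hidden_state.shape)  # expected: torch.Size([1, 8, 64])
```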
+ 0 - 246
src/bloom/ops.py

@@ -1,246 +0,0 @@
-"""
-Utility operations used in the BLOOM model
-Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
-See commit history for authorship.
-"""
-import math
-
-import torch
-import torch.autograd
-import torch.nn.functional as F
-from torch import nn
-
-
-def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
-    """Split a tensor along its last dimension.
-
-    Args:
-        tensor: ([`torch.tensor`], *required*):
-            input tensor to split
-        num_partitions ([`int`], *required*):
-            number of partitions to split the tensor
-        contiguous_split_chunks ([`bool`], *optional*, default=`False`):
-            If True, make each chunk contiguous in memory.
-    """
-    # Get the size and dimension.
-    last_dim = tensor.dim() - 1
-    numerator, denominator = tensor.size()[last_dim], num_partitions
-    if not (numerator % denominator == 0):
-        raise ValueError(f"{numerator} is not divisible by {denominator}")
-    last_dim_size = numerator // denominator
-    # Split.
-    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
-    # Note: torch.split does not create contiguous tensors by default.
-    if contiguous_split_chunks:
-        return tuple(chunk.contiguous() for chunk in tensor_list)
-
-    return tensor_list
-
-
-def attention_mask_func(attention_scores, attention_mask, causal_mask):
-    if attention_mask.dtype == torch.bool:
-        attention_mask_bool = ~attention_mask
-    else:
-        attention_mask_bool = (1 - attention_mask).bool()
-
-    query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
-    padded_causal_mask = (
-        attention_mask_bool[:, None, key_length - query_length : key_length, None]
-        + ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
-    ).bool()
-    padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
-    # Make use of floats
-    return (
-        attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
-        padded_causal_mask,
-    )
-
-
-def build_alibi_tensor(
-    max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
-) -> torch.Tensor:
-    """
-    Link to paper: https://arxiv.org/abs/2108.12409. The alibi tensor is not causal as the original paper mentions; it
-    relies on the translation invariance of softmax for a quick implementation: for a tensor l and a fixed scalar a,
-    `softmax(l + a) = softmax(l)` holds. Based on
-    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
-    Args:
-        max_seq_len: (`int`, *required*):
-            max sequence length
-        n_head: (`int`, *required*):
-            number of heads
-        dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
-            dtype of the output tensor
-        device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
-            device of the output alibi tensor
-    Returns:
-        a tensor shaped (n_head, 1, max_seq_len)
-    """
-    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
-    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
-    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
-    slopes = torch.pow(base, powers)
-
-    if closest_power_of_2 != n_head:
-        extra_base = torch.tensor(
-            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
-        )
-        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
-        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
-        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
-
-    lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
-    return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
-
-
-def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
-    """
-    Pre-process the alibi tensor for padding.
-
-    Args:
-        alibi: ([`torch.tensor`], *required*):
-            alibi tensor to pre-process
-        attention_mask: ([`torch.tensor`], *required*):
-            attention mask to pre-process
-    """
-    assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
-    unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
-    # ^-- [batch, max_len], values correspond to element indices after removing padding
-    # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
-    alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
-    return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
-
-
-def dropout_add(x, residual, prob, training):
-    """
-    Dropout add function
-
-    Args:
-        x (`torch.tensor`, *required*):
-            input tensor
-        residual (`torch.tensor`, *required*):
-            residual tensor
-        prob (`float`, *required*):
-            dropout probability
-        training (`bool`, *required*):
-            training mode
-    """
-    out = nn.functional.dropout(x, p=prob, training=training)
-    out = residual + out
-    return out
-
-
-def bloom_gelu_forward(x):
-    """
-    Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
-    make the model jitable.
-
-    Args:
-        x (`torch.tensor`, *required*):
-            input hidden states
-    """
-    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
-
-
-def bloom_gelu_back(g, x):
-    """
-    Gradient of the tanh approximation of GELU. For reference, the gradient of the exact GELU is:
-    0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
-
-    Args:
-        g (`torch.tensor`, *required*):
-            gradient output tensor
-        x (`torch.tensor`, *required*):
-            input tensor
-    """
-    x = x[0]  # x is a tuple of 1 element, needs to unpack it first
-    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
-    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
-    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
-    return ff * g
-
-
-class GeLUFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input):
-        ctx.save_for_backward(input)
-        return bloom_gelu_forward(input)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input = ctx.saved_tensors
-        tmp = bloom_gelu_back(grad_output, input)
-        return tmp
-
-
-class BloomGelu(nn.Module):
-    """
-    BloomGelu wrapper module that uses the simple forward function in inference mode (to keep the model
-    torchscriptable) and the autograd function in training mode (to get accurate gradients). Partly copied from
-    Megatron-DeepSpeed code and adapted for our needs.
-
-    See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
-
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x):
-        if self.training:
-            return GeLUFunction.apply(x)
-        else:
-            return bloom_gelu_forward(x)
-
-
-class BloomScaledSoftmax(nn.Module):
-    """
-    fused operation: scaling + mask + softmax
-
-    Args:
-        scaled_masked_softmax_fusion (`bool`, *required*):
-            flag to indicate whether the user wants to use softmax fusion
-        mask_func (`function`, *required*):
-            mask function to be applied.
-        softmax_in_fp32 (`bool`, *required*):
-            if true, softmax is performed at fp32 precision.
-        scale (`float`, *required*):
-            scaling factor used in input tensor scaling.
-    """
-
-    def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
-        super().__init__()
-        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
-        self.mask_func = mask_func
-        self.softmax_in_fp32 = softmax_in_fp32
-        self.scale = scale
-
-        if not (self.scale is None or softmax_in_fp32):
-            raise ValueError("softmax should be in fp32 when scaled")
-
-    def forward(self, input, mask, max_positions):
-        input_dtype = input.dtype
-        input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
-        softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
-
-        if self.scale is not None:
-            input = input * self.scale
-
-        if mask is None:
-            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
-
-        mask = mask.to(input.device)
-        causal_mask = (
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
-            .view(1, 1, max_positions, max_positions)
-            .to(input.device)
-        )
-        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
-        probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
-
-        if input_in_16bit and self.softmax_in_fp32:
-            probs = probs.to(dtype=input_dtype)
-
-        return probs

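Two of the removed helpers are easy to sanity-check in isolation; a short sketch of the shapes they produce:

```python
import torch

from src.bloom.ops import build_alibi_tensor, split_tensor_along_last_dim

# ALiBi biases: one slope per attention head, broadcast over token positions
alibi = build_alibi_tensor(max_seq_len=16, n_head=8, dtype=torch.float32)
print(alibi.shape)  # torch.Size([8, 1, 16])

# split a fused QKV projection (head dim 64) into three equal chunks along the last dim
fused_qkv = torch.randn(2, 16, 3 * 64)
query, key, value = split_tensor_along_last_dim(fused_qkv, num_partitions=3)
print(query.shape)  # torch.Size([2, 16, 64])
```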
+ 0 - 1
src/client/__init__.py

@@ -1 +0,0 @@
-from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession

+ 0 - 135
src/client/remote_block.py

@@ -1,135 +0,0 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-from __future__ import annotations
-
-import asyncio
-import random
-from typing import Any, AsyncIterator, Dict, Optional
-
-import torch
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertInfo
-from hivemind.p2p import P2P, StubBase
-from hivemind.proto import runtime_pb2
-from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
-
-from src.data_structures import RemoteModuleInfo
-from src.dht_utils import ModuleUID
-from src.server.handler import TransformerConnectionHandler
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-class RemoteTransformerBlock(RemoteExpert):
-    """A class that interacts with a remote module on a specific server for forward/backward or inference"""
-
-    def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids)))  # TODO replace this
-        super().__init__(peer_info, p2p)
-
-    @property
-    def stub(self) -> StubBase:
-        return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
-
-    def forward(self, inputs: torch.Tensor, **kwargs):
-        for k, v in kwargs.items():
-            assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
-        return super().forward(inputs)
-
-    def inference_session(self) -> RemoteTransformerBlockInferenceSession:
-        """Initialize a new inference session with the specified remote server"""
-        _ = self.info  # create _info manually since the built-in property will not work inside RemoteExpertWorker
-        return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
-
-    def begin_inference_session(self):
-        logger.warning("begin_inference_session was renamed to just inference_session")
-        return self.inference_session()
-
-
-class RemoteTransformerBlockInferenceSession:
-    """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
-
-    def __init__(self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
-        self.uid, self.info = uid, info
-        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
-        # using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
-        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
-        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
-        self.stepped = False
-        self.closed = False
-
-    @classmethod
-    async def _create(
-        cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
-    ) -> RemoteTransformerBlockInferenceSession:
-        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
-        inputs_queue = asyncio.Queue()
-        outputs_stream = await remote_module.stub.rpc_inference(
-            cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout
-        )
-        return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
-
-    @staticmethod
-    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
-        while True:
-            next_input_message = await asyncio.wait_for(queue.get(), timeout)
-            yield next_input_message
-            if not next_input_message.uid and not next_input_message.tensors:
-                break  # this message means "done sending"
-
-    def step(self, new_hidden_states: torch.Tensor):
-        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
-        if self.closed:
-            raise Exception("Session is closed, cannot perform step")
-        # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
-        outputs_serialized = RemoteExpertWorker.run_coroutine(
-            self._step(
-                runtime_pb2.ExpertRequest(
-                    uid=self.uid,
-                    tensors=[
-                        serialize_torch_tensor(tensor, proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
-                    ],
-                )
-            )
-        )
-        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
-        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
-        return outputs[0]
-
-    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
-        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
-        await self._inputs_queue.put(inputs_serialized)
-        self.stepped = True
-        return await anext(self._outputs_stream)
-
-    def close(self):
-        """Finish a given inference session, close the underlying connection"""
-        if self._outputs_stream is None:
-            return  # already closed
-        RemoteExpertWorker.run_coroutine(self._aclose_stream())
-        self._outputs_stream = self._inputs_queue = None
-        self.closed = True
-
-    async def _aclose_stream(self):
-        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
-        if self._outputs_stream is None:
-            return  # already closed
-        if self.stepped:
-            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
-            try:
-                await anext(self._outputs_stream)
-            except StopAsyncIteration:
-                pass
-
-    def __del__(self):
-        self.close()
-
-    def __enter__(self):
-        assert not self.closed
-        return self
-
-    def __exit__(self, *exc_details):
-        self.close()

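A hedged sketch of the single-block inference flow implemented above (`get_remote_module` comes from `src/dht_utils.py`, removed further below). The multiaddr, module UID, and hidden size are placeholders that must match whatever a running swarm actually serves:

```python
import hivemind
import torch

from src.dht_utils import get_remote_module

INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/REPLACE_WITH_SERVER_PEER_ID"]  # placeholder

dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
block = get_remote_module(dht, "bloom-test.0")  # hypothetical module UID
assert block is not None, "no active peers found for this block"

with block.inference_session() as session:  # begin_inference_session() is the deprecated alias
    hidden_states = torch.zeros(1, 1, 4096)  # [batch, 1 new token, hidden_size of the served model]
    for _ in range(3):
        hidden_states = session.step(hidden_states)
```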
+ 0 - 127
src/client/remote_model.py

@@ -1,127 +0,0 @@
-# this code is in active development, interfaces may change
-import os
-from typing import Optional, Tuple, Union
-
-import hivemind
-import torch
-from hivemind import DHT, get_logger, use_hivemind_log_handler
-from torch.nn import CrossEntropyLoss
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-
-from src.bloom import BloomForYou, DistributedBloomConfig
-from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
-from src.client.remote_sequential import RemoteSequential
-from src.data_structures import UID_DELIMITER
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-class DistributedBloomForYou(BloomForYou):
-    """BloomModel, but all transformer layers are hosted by the swarm"""
-
-    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
-        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
-        super().__init__(config)
-        assert len(self.transformer.h) == 0
-        config.n_layer = n_layer
-        self.transformer.h = RemoteSequential(config, dht, prefix)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
-        if "initial_peers" not in kwargs:
-            raise ValueError("Please specify initial_peers=...")
-
-        dht = hivemind.DHT(
-            initial_peers=kwargs.pop("initial_peers"), client_mode=kwargs.pop("client_mode", True), start=True
-        )
-
-        if "prefix" not in kwargs:
-            logger.debug(f"No DHT prefix specified; using automatic prefix {pretrained_model_name_or_path}")
-            assert (
-                UID_DELIMITER not in pretrained_model_name_or_path
-            ), f"Cannot infer prefix automatically from {pretrained_model_name_or_path}; please specify prefix=..."
-        prefix = kwargs.pop("prefix", pretrained_model_name_or_path)
-
-        config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
-        model = cls(config, dht, prefix)
-        model.transformer.load_state_dict(
-            _load_state_dict(pretrained_model_name_or_path, use_auth_token=kwargs.get("use_auth_token")), strict=True
-        )
-        return model
-
-
-class DistributedBloomForCausalLM(DistributedBloomForYou):
-    """Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm"""
-
-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-        # only last token for input_ids if past is defined in kwargs
-        if past:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-        else:
-            position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-        }
-
-    def forward(self, input_ids, labels=None, return_dict=None, **kwargs):
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
-            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
-
-        # Switch dtype in case word_embeddings are fp16
-        word_embeddings = self.transformer.word_embeddings.weight.t()
-        hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
-        lm_logits = (hidden_states @ word_embeddings).float()
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-        if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return CausalLMOutputWithCrossAttentions(
-            loss=loss,
-            logits=lm_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        )
-
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-        """
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
-        )

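A sketch of the client entry point defined above; the checkpoint name, DHT prefix, and peer multiaddr are placeholders, not a working deployment:

```python
import torch

from src.client.remote_model import DistributedBloomForCausalLM

INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/REPLACE_WITH_SERVER_PEER_ID"]  # placeholder

model = DistributedBloomForCausalLM.from_pretrained(
    "bigscience/test-bloomd", initial_peers=INITIAL_PEERS, prefix="bloom-test"
)

input_ids = torch.tensor([[1, 2, 3, 4]])
outputs = model(input_ids, labels=input_ids)  # logits are computed from the tied word embeddings
print(float(outputs.loss))
```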
+ 0 - 94
src/client/remote_sequence_info.py

@@ -1,94 +0,0 @@
-from __future__ import annotations
-
-import dataclasses
-import threading
-from functools import partial
-from typing import List, NamedTuple, Optional, Sequence, Tuple
-
-from hivemind import DHT, PeerID
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-
-from src.data_structures import ModuleUID, RemoteModuleInfo
-from src.dht_utils import _get_remote_module_infos
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
-
-
-@dataclasses.dataclass(frozen=False, init=False)  # TODO[borzunov@] this is not really a dataclass
-class RemoteSequenceInfo:
-    """Keeps and updates the meta-information about which peers host which blocks"""
-
-    dht: DHT
-    block_uids: List[ModuleUID]
-    block_infos: List[Optional[RemoteModuleInfo]]
-    spans_by_priority: List[Span]  # sorted from best to worst
-    spans_containing_block: Tuple[List[Span], ...]
-    lock_changes: threading.Lock
-
-    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
-        self.dht = dht
-        self.block_uids = list(block_uids)
-        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
-        self.spans_by_priority = []
-        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
-        self.lock_changes = threading.Lock()
-        self.update_()
-
-        for uid, info in zip(self.block_uids, self.block_infos):
-            assert info is not None, f"Found no remote peers for block {uid}"
-        assert self.spans_by_priority and self.spans_containing_block
-
-    def update_(self):
-        with self.lock_changes:
-            self.update_block_infos_()
-            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
-
-    def update_block_infos_(self):
-        new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
-            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
-        )
-        assert len(new_block_infos) == len(self.block_uids)
-        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
-            if info is None:
-                logger.warning(f"Found no block info for block {uid}")
-                continue
-            if not isinstance(info, RemoteModuleInfo):
-                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-            if not info.peer_ids:
-                logger.warning(f"Found no active peers for block {uid}")
-            if info.uid != uid:
-                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-            if not isinstance(info.peer_ids, set):
-                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
-            self.block_infos[block_index] = info
-
-    @staticmethod
-    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
-        closed_spans = []
-        active_spans = {}
-        for block_index, info in enumerate(block_infos):
-            for peer_id in info.peer_ids:
-                if peer_id not in active_spans:
-                    active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
-                else:  # peer_id in active_spans
-                    active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
-
-            for peer_id in list(active_spans.keys()):
-                if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
-                    closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans
-
-        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
-
-        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
-        for span in closed_spans:
-            for block_index in range(span.start, span.end):
-                spans_containing_block[block_index].append(span)
-
-        return closed_spans, spans_containing_block
-
-    def __len__(self):
-        return len(self.block_uids)

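`compute_spans` is a pure function over block infos, so its behavior is easy to illustrate offline; in this toy sketch plain strings stand in for hivemind `PeerID` objects:

```python
from src.client.remote_sequence_info import RemoteSequenceInfo
from src.data_structures import RemoteModuleInfo

# toy layout: peer "A" serves blocks 0-2, peer "B" serves blocks 1-3
infos = [
    RemoteModuleInfo("bloom-test.0", {"A"}),
    RemoteModuleInfo("bloom-test.1", {"A", "B"}),
    RemoteModuleInfo("bloom-test.2", {"A", "B"}),
    RemoteModuleInfo("bloom-test.3", {"B"}),
]

spans, spans_containing_block = RemoteSequenceInfo.compute_spans(infos)
print(spans)                           # Span(start=0, end=3, peer_id='A'), Span(start=1, end=4, peer_id='B')
print(len(spans_containing_block[2]))  # block 2 falls inside both spans -> 2
```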
+ 0 - 134
src/client/remote_sequential.py

@@ -1,134 +0,0 @@
-from __future__ import annotations
-
-import contextlib
-import logging
-import random
-
-import torch
-from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertInfo
-from torch import nn
-
-from src import DistributedBloomConfig, RemoteTransformerBlock
-from src.client.remote_sequence_info import RemoteSequenceInfo
-from src.data_structures import UID_DELIMITER
-from src.dht_utils import _create_remote_modules_from_infos
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-class RemoteSequential(nn.Module):
-    """
-    A sequence of transformer blocks hosted by the swarm.
-    """
-
-    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
-        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
-        if prefix.endswith(UID_DELIMITER):
-            logger.warning(
-                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'. "
-                f"This will cause {self.__class__.__name__} to look for modules under "
-                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
-            )
-
-        super().__init__()
-        self.config = config
-        self.dht = dht
-        self.prefix = prefix
-        self.max_retries = max_retries
-        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
-
-        block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
-
-        logger.debug(f"Remote block uids: {block_uids}")
-        self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
-
-    def forward(self, inputs: torch.Tensor):
-        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
-        for block_index in range(self.config.n_layer):
-            for retry_index in range(self.max_retries):
-                try:
-                    block = self[block_index]
-                    (outputs,) = block(inputs)
-                    assert isinstance(outputs, torch.Tensor)
-                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
-                    inputs = outputs
-                    break
-                except Exception as e:
-                    if retry_index == self.max_retries - 1:
-                        raise e
-                    else:
-                        logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
-        return inputs
-
-    def __getitem__(self, block_index: int):
-        assert 0 <= block_index < self.config.n_layer
-        (module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
-        return module
-
-    def __iter__(self):
-        for block_index in range(self.config.n_layer):
-            yield self[block_index]
-
-    def __len__(self):
-        return len(self.remote_sequence_info)
-
-    def inference_session(self) -> RemoteSequentialInferenceSession:
-        self.remote_sequence_info.update_()
-        return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
-
-
-class RemoteSequentialInferenceSession:
-    """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
-
-    def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
-        self.remote_sequence_info = remote_sequence_info
-        self.p2p = p2p
-        self.closed = False
-        self.stack = contextlib.ExitStack()
-        self.active_sessions = []
-
-    def __enter__(self):
-        assert not self.closed
-        self.stack.__enter__()
-        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
-        current_block = 0
-        while current_block != len(self.remote_sequence_info):
-            candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
-            chosen_span = random.choice(candidate_spans)  # TODO this is a temporary code
-            assert chosen_span.start <= current_block < chosen_span.end
-
-            # TODO begin throwaway prototype code
-            remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
-            _ = remote.info  # TODO fix
-            span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
-            remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
-            self.active_sessions.append(remote.inference_session())
-            self.stack.enter_context(self.active_sessions[-1])
-            current_block = chosen_span.end
-            # TODO end throwaway prototype code
-
-        return self
-
-    def step(self, inputs: torch.Tensor):
-        assert not self.closed
-        for session in self.active_sessions:
-            outputs = session.step(inputs)
-            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
-            inputs = outputs
-        return inputs
-
-    def close(self, *exc_details):
-        """Finish a given inference session, close the underlying connection"""
-        if not self.closed:
-            self.stack.__exit__(*exc_details or (None, None, None))
-            self.active_sessions.clear()
-            self.closed = True
-
-    def __exit__(self, *exc_details):
-        self.close(*exc_details)
-
-    def __del__(self):
-        self.close()

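A sketch of running a full remote forward pass through the removed `RemoteSequential`; again, the multiaddr, checkpoint name, and DHT prefix are placeholders that must match running servers:

```python
import hivemind
import torch

from src.bloom import DistributedBloomConfig
from src.client.remote_sequential import RemoteSequential

INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/REPLACE_WITH_SERVER_PEER_ID"]  # placeholder

dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = DistributedBloomConfig.from_pretrained("bigscience/test-bloomd")  # hypothetical checkpoint
remote_blocks = RemoteSequential(config, dht, prefix="bloom-test")

inputs = torch.randn(1, 4, config.hidden_size)
outputs = remote_blocks(inputs)  # forwards through every block, one remote call at a time
assert outputs.shape == inputs.shape
```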
+ 0 - 8
src/data_structures.py

@@ -1,8 +0,0 @@
-from typing import Collection, NamedTuple
-
-from hivemind import PeerID
-
-ModuleUID = str
-UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
-CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
-RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])

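The UID conventions above are simple string manipulation; a tiny sketch:

```python
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, RemoteModuleInfo

prefix = "bloom-test"
block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(3)]  # ['bloom-test.0', 'bloom-test.1', 'bloom-test.2']
chain = CHAIN_DELIMITER.join(block_uids)                        # 'bloom-test.0 bloom-test.1 bloom-test.2'

info = RemoteModuleInfo(uid=block_uids[0], peer_ids=set())      # a module record with no known peers yet
print(chain, info)
```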
+ 0 - 132
src/dht_utils.py

@@ -1,132 +0,0 @@
-"""
-Utilities for declaring and retrieving active model layers using a shared DHT.
-"""
-from __future__ import annotations
-
-from functools import partial
-from typing import Dict, List, Optional, Sequence, Union
-
-from hivemind.dht import DHT, DHTNode, DHTValue
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2P, PeerID
-from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
-
-import src
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-def declare_active_modules(
-    dht: DHT,
-    uids: Sequence[ModuleUID],
-    expiration_time: DHTExpiration,
-    throughput: Optional[float] = None,
-    wait: bool = True,
-) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
-    """
-    Declare that your node serves the specified modules; update timestamps if declared previously
-
-    :param uids: a list of module ids to declare
-    :param wait: if True, waits for the declaration to finish, otherwise runs in background
-    :param throughput: optionally specify your performance in terms of compute throughput
-    :param expiration_time: declared modules will be visible for this many seconds
-    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
-    """
-    if isinstance(uids, str):
-        uids = [uids]
-    if not isinstance(uids, list):
-        uids = list(uids)
-    for uid in uids:
-        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
-    return dht.run_coroutine(
-        partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
-        return_future=not wait,
-    )
-
-
-async def _declare_active_modules(
-    dht: DHT,
-    node: DHTNode,
-    uids: List[ModuleUID],
-    expiration_time: DHTExpiration,
-    throughput: Optional[float] = None,
-) -> Dict[ModuleUID, bool]:
-    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
-    return await node.store_many(
-        keys=uids,
-        subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[throughput] * len(uids),
-        expiration_time=expiration_time,
-        num_workers=num_workers,
-    )
-
-
-def get_remote_module(
-    dht: DHT,
-    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    expiration_time: Optional[DHTExpiration] = None,
-    return_future: bool = False,
-) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
-    """
-    :param uid_or_uids: find one or more modules with these ids from across the DHT
-    :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
-    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-    :returns: a list of [RemoteTransformerBlock if found else None]
-    """
-    single_uid = isinstance(uid_or_uids, ModuleUID)
-    uids = [uid_or_uids] if single_uid else uid_or_uids
-    infos = dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
-    )
-
-    if return_future:
-
-        async def _unpack(infos_future: MPFuture, dht: DHT):
-            p2p = await dht.replicate_p2p()
-            modules = _create_remote_modules_from_infos(await infos_future, p2p)
-            return modules[0] if single_uid else modules
-
-        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
-    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
-    modules = _create_remote_modules_from_infos(infos, p2p)
-    return modules[0] if single_uid else modules
-
-
-async def _get_remote_module_infos(
-    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
-) -> List[Optional[RemoteModuleInfo]]:
-    if expiration_time is None:
-        expiration_time = get_dht_time()
-    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
-    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
-
-    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
-    for i, uid in enumerate(uids):
-        metadata = found[uid]
-        if metadata is None or not isinstance(metadata.value, dict):
-            if metadata is not None:
-                logger.error(f"Incorrect metadata for {uid}: {metadata}")
-            continue
-        valid_entries = set()
-        for maybe_peer_id, _unused_value in metadata.value.items():
-            try:
-                valid_entries.add(PeerID.from_base58(maybe_peer_id))
-            except:
-                logger.error(f"Incorrect peer entry for {uid}: {maybe_peer_id}")
-        if valid_entries:
-            modules[i] = RemoteModuleInfo(uid, valid_entries)
-    return modules
-
-
-def _create_remote_modules_from_infos(
-    infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
-) -> List[Optional[src.RemoteTransformerBlock]]:
-    modules: List[Optional[src.RemoteTransformerBlock]] = []
-    for info in infos:
-        if info is not None:
-            modules.append(src.RemoteTransformerBlock(info, p2p))
-        else:
-            modules.append(None)
-    return modules

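A server-side sketch of announcing blocks with the removed `declare_active_modules`; the UIDs are placeholders and a standalone DHT is started just for illustration (a real server would re-announce periodically):

```python
import hivemind
from hivemind import get_dht_time

from src.dht_utils import declare_active_modules

dht = hivemind.DHT(start=True)  # standalone DHT, for illustration only

statuses = declare_active_modules(
    dht,
    uids=["bloom-test.3", "bloom-test.4"],  # hypothetical block UIDs
    expiration_time=get_dht_time() + 60,
    throughput=1.0,
)
print(statuses)  # per-key store status; True means the DHT accepted the record
```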
+ 0 - 0
src/server/__init__.py


+ 0 - 58
src/server/backend.py

@@ -1,58 +0,0 @@
-"""Code for serving bloom blocks via hivemind-server"""
-from typing import Sequence, Tuple
-
-import torch
-from hivemind.moe.server.module_backend import ModuleBackend
-from hivemind.moe.server.task_pool import TaskPool
-
-from src.bloom.from_pretrained import BloomBlock
-from src.server.cache import MemoryCache
-
-MAX_LENGTH = 2048
-
-
-class TransformerBackend(ModuleBackend):
-    """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
-
-    def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
-        super().__init__(*args, **kwargs)
-        assert isinstance(self.module, BloomBlock)
-        self.memory_cache = memory_cache
-        for name, param in self.module.named_parameters():
-            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
-        for name, buf in self.module.named_buffers():
-            assert not buf.requires_grad, f"Bloom layer buffers must not require gradients, but {name} does"
-
-        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
-
-    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-        with torch.inference_mode():
-            attention_cache_handle = int(cache_metadata[0, 0].item())
-            prefix_length = int(cache_metadata[0, 1].item())
-            hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
-            assert (
-                hidden_states.ndim == 3
-            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-
-            with self.memory_cache.use_cache(attention_cache_handle) as cache:
-                assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
-                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-                print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
-                hidden_states, (new_k, new_v) = self.module.forward(
-                    hidden_states, layer_past=layer_past, use_cache=True
-                )
-
-                # todo remove these asserts once we pass all tests
-                new_length = new_v.shape[1]
-                assert new_length > prefix_length
-                assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
-                assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
-                assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
-                assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
-                assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
-                cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
-                cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
-                return (hidden_states,)
-
-    def get_pools(self) -> Sequence[TaskPool]:
-        return self.forward_pool, self.backward_pool, self.inference_pool

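The `cache_metadata` tensor consumed by `inference_step` above has a fixed two-column layout; a tiny sketch of constructing it (the handle value is hypothetical and would normally come from `MemoryCache.allocate_cache`):

```python
import torch

attention_cache_handle, prefix_length = 0, 0  # hypothetical values for illustration
# [1, 2] integer tensor: column 0 is the MemoryCache handle, column 1 is the current prefix length
cache_metadata = torch.tensor([[attention_cache_handle, prefix_length]], dtype=torch.int64)
print(cache_metadata)
```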
+ 0 - 127
src/server/cache.py

@@ -1,127 +0,0 @@
-"""
-A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.
-
-For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
-
-"""
-import contextlib
-import ctypes
-import multiprocessing as mp
-import os
-from typing import AsyncContextManager, Dict, Optional, Union
-
-import hivemind
-import torch
-from hivemind import use_hivemind_log_handler
-from hivemind.utils import TensorDescriptor, get_logger
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-Handle = int
-
-
-class MemoryCache:
-    """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
-
-    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
-        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
-        self.device = device
-        self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
-        self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
-        self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
-        self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
-        self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
-        self.runtime_pid = os.getpid()
-
-        self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
-        self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
-
-    @property
-    def current_size_bytes(self) -> int:
-        return self._current_size.value
-
-    @current_size_bytes.setter
-    def current_size_bytes(self, value: int):
-        self._current_size.value = value
-
-    @property
-    def handle_counter(self) -> int:
-        return self._handle_counter.value
-
-    @handle_counter.setter
-    def handle_counter(self, value: int):
-        self._handle_counter.value = value
-
-    @contextlib.asynccontextmanager
-    async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
-        """
-        Create a handle associated with buffers on the cache's device. If the cache is full, raise AllocationFailed.
-
-        :param descr: allocate a tensor of this size, dtype, etc
-
-        :note: This function should be called by connection handlers; it can be called concurrently from multiple processes.
-        Furthermore, it can be called concurrently with at most one use_cache call in runtime.
-        """
-        assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
-        assert descr.device is None and descr
-        allocated_handle = None
-        allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
-        try:
-            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
-                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    raise AllocationFailed(
-                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
-                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
-                    )
-
-                allocated_handle = int(self.handle_counter)
-                self.current_size_bytes += allocated_size_bytes
-                self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                self._pending_messages.value += 1
-                self._pipe_send.send((allocated_handle, descr))
-
-            yield allocated_handle
-        finally:
-            if allocated_handle is not None:
-                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
-                    self._pending_messages.value += 1
-                    self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
-                    self.current_size_bytes -= allocated_size_bytes
-
-    @contextlib.contextmanager
-    def use_cache(self, handle: Handle) -> torch.Tensor:
-        """
-        Return a tensor that was previously allocated with allocate_cache.
-
-        :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
-        However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
-        """
-        assert os.getpid() == self.runtime_pid
-        # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
-
-        with self.lock_metadata:
-            if self._allocated_tensors is None:
-                self._allocated_tensors = {}
-
-            # read creation/deletion requests from connection handlers
-            for i in range(int(self._pending_messages.value)):
-                recv_handle, recv_data = self._pipe_recv.recv()
-                self._pending_messages.value -= 1
-                if isinstance(recv_data, TensorDescriptor):
-                    self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
-                elif recv_data is None:
-                    if recv_handle not in self._allocated_tensors:
-                        logger.warning(
-                            f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
-                        )
-                    self._allocated_tensors.pop(recv_handle, None)
-                else:
-                    logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
-
-        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
-        yield self._allocated_tensors[handle]
-
-
-class AllocationFailed(Exception):
-    pass
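
`MemoryCache.allocate_cache` accounts for memory as element count times bytes per element before handing out a handle. A quick sketch of the same arithmetic for a hypothetical per-block KV cache (the shape and dtype here are assumptions for illustration, not the server's actual configuration):

```python
import torch

# hypothetical cache shape: [key_or_value, batch, max_length, num_heads, head_dim]
shape = (2, 1, 2048, 32, 128)
dtype = torch.float32

numel = torch.Size(shape).numel()
size_bytes = numel * torch.finfo(dtype).bits // 8  # same formula as in allocate_cache
print(f"{size_bytes / 2**20:.1f} MiB per block")    # 64.0 MiB for these assumed sizes
```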

+ 0 - 229
src/server/handler.py

@@ -1,229 +0,0 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-import contextlib
-from typing import AsyncIterator, Dict, Sequence
-
-import torch
-from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
-from hivemind.moe.server.connection_handler import ConnectionHandler
-from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
-from hivemind.proto import runtime_pb2
-from hivemind.utils import as_aiter
-from hivemind.utils.asyncio import anext
-from hivemind.utils.streaming import split_for_streaming
-
-from src.data_structures import CHAIN_DELIMITER, ModuleUID
-from src.server.backend import MAX_LENGTH, TransformerBackend
-
-
-class TransformerConnectionHandler(ConnectionHandler):
-    """Handles three request types: forward, backward and forward-incremental (inference)"""
-
-    module_backends: Dict[ModuleUID, TransformerBackend]
-
-    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
-        super().__init__(dht, module_backends)
-        for module_backend in self.module_backends.values():
-            assert isinstance(module_backend, TransformerBackend)
-
-    async def rpc_inference(
-        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
-    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
-        """Compute a single step of inference using attention cache; update attention cache accordingly."""
-        try:
-            print("OPENED RPC_INFERENCE")
-            request = await anext(requests)
-            requested_uids = self._check_header(request)
-            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-
-            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, prefix_length]
-            prefix_length = 0
-
-            async with self._allocate_caches(requested_backends) as cache_handles:
-                assert len(cache_handles) == len(requested_backends)
-                while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-
-                    # run request tensors through all requested modules, update caches
-                    for backend, cache_handle in zip(requested_backends, cache_handles):
-                        cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
-                        assert (
-                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
-                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-
-                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
-                        assert isinstance(hidden_states, (list, tuple))
-                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
-
-                    # serialize and send last layer outputs
-                    yield runtime_pb2.ExpertResponse(
-                        tensors=[
-                            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-                            for result, proto in zip(
-                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
-                            )
-                        ]
-                    )
-
-                    # prepare for next step
-                    prefix_length += hidden_states[0].shape[1]
-                    request = await (anext(requests))
-        finally:
-            print("CLOSED RPC_INFERENCE")
-
-    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
-        # Parse request and prepare backends
-        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
-
-        # Serialize the overall output and respond
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
-        return runtime_pb2.ExpertResponse(
-            tensors=[
-                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
-            ]
-        )
-
-    async def rpc_forward_stream(
-        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
-    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
-        # Parse requests and prepare backends
-        uids_header, hidden_states = await self._gather_inputs(requests, context)
-        requested_uids = self._check_header_str(uids_header)
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
-
-        # Serialize the overall output
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
-        serialized_output = [
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
-        ]
-
-        # Split the serialized_output for streaming and respond
-        output_split = [
-            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-        ]
-        async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(tensors=[part])
-
-    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
-        # Parse requests and prepare backends
-        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its output
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a chain of requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
-
-        # Serialize the overall grad_input and respond
-        return runtime_pb2.ExpertResponse(
-            tensors=[
-                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
-            ]
-        )
-
-    async def rpc_backward_stream(
-        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
-    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
-        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
-        inputs, grads = inputs_and_grads
-        requested_uids = self._check_header_str(uids_header)
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its outputs
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a backward chain for requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
-
-        # Serialize the overall grad_inputs
-        serialized_grad_inputs = [
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
-        ]
-        # Split the serialized_grad_inputs for streaming and respond
-        output_split = [
-            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-        ]
-
-        async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(tensors=[part])
-
-    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
-        """Check that the first request to rpc_inference is valid"""
-        uids = (request.uid or "").split(CHAIN_DELIMITER)
-        if not uids:
-            raise RuntimeError("User did not provide any uids")
-        for uid in uids:
-            if uid not in self.module_backends:
-                raise RuntimeError(f"Remote peer does not serve {uid}")
-        return tuple(uids)
-
-    def _check_header_str(self, header) -> Sequence[ModuleUID]:
-        """Check that the first request to rpc_inference is valid"""
-        uids = (header or "").split(CHAIN_DELIMITER)
-        if not uids:
-            raise RuntimeError("User did not provide any uids")
-        for uid in uids:
-            if uid not in self.module_backends:
-                raise RuntimeError(f"Remote peer does not serve {uid}")
-        return tuple(uids)
-
-    @contextlib.asynccontextmanager
-    async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
-        """Allocate memory caches for each transformer block, return cache handles"""
-        async with contextlib.AsyncExitStack() as stack:
-            handles = []
-            for backend in backends:
-                num_heads = backend.module.self_attention.num_heads
-                head_dim = backend.module.self_attention.head_dim
-
-                cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
-                # [key_or_value, batch_size, max_length, num_heads, head_dim]
-
-                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
-
-            yield handles
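
The handler addresses a whole chain of blocks with a single UID header and validates each UID against the backends it serves (`_check_header` / `_check_header_str` above). A standalone sketch of that validation, assuming `CHAIN_DELIMITER` is a single space (as the chained UIDs in the tests below suggest) and using placeholder backends instead of real `TransformerBackend` objects:

```python
CHAIN_DELIMITER = " "  # assumption for illustration; the real value lives in src.data_structures
module_backends = {"bloom6b3.3": object(), "bloom6b3.4": object()}  # stand-ins for TransformerBackend

def check_header(uid_header: str):
    uids = (uid_header or "").split(CHAIN_DELIMITER)
    if not uids:
        raise RuntimeError("User did not provide any uids")
    for uid in uids:
        if uid not in module_backends:
            raise RuntimeError(f"Remote peer does not serve {uid}")
    return tuple(uids)

print(check_header("bloom6b3.3 bloom6b3.4"))  # ('bloom6b3.3', 'bloom6b3.4')
```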

+ 0 - 255
src/server/server.py

@@ -1,255 +0,0 @@
-from __future__ import annotations
-
-import multiprocessing as mp
-import threading
-from typing import Dict, Optional, Sequence, Union
-
-import torch
-from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
-from hivemind.moe.server.dht_handler import DHTHandlerThread
-from hivemind.moe.server.layers import add_custom_models_from_file
-from hivemind.moe.server.runtime import Runtime
-from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-
-from src import declare_active_modules
-from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
-from src.server.backend import TransformerBackend
-from src.server.cache import MemoryCache
-from src.server.handler import TransformerConnectionHandler
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-class Server(threading.Thread):
-    """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
-
-    def __init__(
-        self,
-        dht: DHT,
-        module_backends: Dict[str, TransformerBackend],
-        *,
-        device: torch.device,
-        num_connection_handlers: int = 8,
-        update_period: float = 30,
-        expiration: Optional[float] = None,
-        start: bool,
-        **kwargs,
-    ):
-        threading.Thread.__init__(self)
-        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
-        self.conn_handlers = [
-            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
-        ]
-        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
-        self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends, dht, update_period, expiration, daemon=True
-        )
-        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
-
-        if start:
-            self.run_in_background(await_ready=True)
-
-    def run(self):
-        """
-        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
-        runs Runtime (self.runtime) to process incoming requests.
-        """
-        logger.info(f"Serving {len(self.module_backends)} blocks:")
-        for expert_name, backend in self.module_backends.items():
-            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
-
-        if not self.dht.is_alive():
-            self.dht.run_in_background(await_ready=True)
-
-        if self.module_backends:
-            self.dht_handler_thread.start()
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.start()
-
-        for process in self.conn_handlers:
-            if not process.is_alive():
-                process.start()
-            process.ready.result()
-
-        try:
-            self.runtime.run()
-        finally:
-            self.shutdown()
-
-    # noinspection PyMethodOverriding
-    @classmethod
-    def create(
-        cls,
-        prefix: Optional[str],
-        converted_model_name_or_path: str,
-        num_blocks: Optional[int] = None,
-        block_indices: Optional[str] = None,
-        num_handlers: Optional[int] = None,
-        min_batch_size: int = 1,
-        max_batch_size: int = 4096,
-        torch_dtype: str = "auto",
-        cache_size_bytes: Optional[int] = None,
-        device: Union[str, torch.device] = None,
-        initial_peers: Sequence[str] = (),
-        compression=CompressionType.NONE,
-        stats_report_interval: Optional[int] = None,
-        custom_module_path=None,
-        update_period: float = 30,
-        expiration: Optional[float] = None,
-        use_auth_token: Optional[str] = None,
-        *,
-        start: bool,
-        **kwargs,
-    ) -> Server:
-        """Create a server with one or more bloom blocks. See run_server.py for documentation."""
-        if custom_module_path is not None:
-            add_custom_models_from_file(custom_module_path)
-        if prefix is None:
-            prefix = converted_model_name_or_path
-            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
-                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
-                f"Please specify --prefix manually when starting a server"
-            )
-            logger.info(f"Automatic dht prefix: {prefix}")
-        assert (block_indices is None) != (num_blocks is None), "please specify either num_blocks or block_indices (not both)"
-        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
-        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
-        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
-
-        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        memory_cache = MemoryCache(device, cache_size_bytes)
-
-        if isinstance(torch_dtype, str):
-            torch_dtype = DTYPE_MAP[torch_dtype]
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-
-        if block_indices is not None:
-            try:
-                first_block_index, last_block_index = block_indices.split(":")
-                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
-            except Exception as e:
-                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
-                raise
-            block_indices = range(first_block_index, last_block_index)
-        else:
-            assert num_blocks is not None
-            block_indices = range(num_blocks)  # TODO replace with proper load balancing
-
-        block_config = DistributedBloomConfig.from_pretrained(
-            converted_model_name_or_path, use_auth_token=use_auth_token
-        )
-
-        # initialize modules
-        blocks = {}
-        for block_index in block_indices:
-            module_uid = f"{prefix}.{block_index}"
-            block = load_pretrained_block(
-                converted_model_name_or_path,
-                block_index,
-                block_config,
-                torch_dtype=torch_dtype,
-                use_auth_token=use_auth_token,
-            )
-            for param in block.parameters():
-                param.requires_grad = False
-
-            blocks[module_uid] = TransformerBackend(
-                module_uid,
-                block,
-                memory_cache=memory_cache,
-                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
-                kwargs_schema={},
-                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
-                min_batch_size=min_batch_size,
-                max_batch_size=max_batch_size,
-            )
-
-        num_handlers = num_handlers if num_handlers is not None else len(blocks) * 4
-
-        return cls(
-            dht,
-            blocks,
-            num_connection_handlers=num_handlers,
-            device=device,
-            stats_report_interval=stats_report_interval,
-            update_period=update_period,
-            expiration=expiration,
-            start=start,
-        )
-
-    def run_in_background(self, await_ready=True, timeout=None):
-        """
-        Starts Server in a background thread. If await_ready, this method will wait until the background server
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
-
-    @property
-    def ready(self) -> mp.synchronize.Event:
-        """
-        An event (multiprocessing.Event) that is set when the server is ready to process requests.
-
-        Example
-        =======
-        >>> server.start()
-        >>> server.ready.wait(timeout=10)
-        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
-        """
-        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
-
-    def shutdown(self):
-        """
-        Gracefully terminate the server, process-safe.
-        Please note that terminating the server otherwise (e.g. by killing processes) may result in zombie processes.
-        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
-        """
-        self.ready.clear()
-
-        for process in self.conn_handlers:
-            process.terminate()
-            process.join()
-        logger.debug("Connection handlers terminated")
-
-        if self.module_backends:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()
-
-        self.dht.shutdown()
-        self.dht.join()
-
-        logger.debug(f"Shutting down runtime")
-
-        self.runtime.shutdown()
-        logger.info("Server shutdown succesfully")
-
-
-class ModuleAnnouncerThread(threading.Thread):
-    """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
-
-    def __init__(
-        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
-    ):
-        super().__init__(**kwargs)
-        if expiration is None:
-            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
-        self.module_backends = module_backends
-        self.dht = dht
-        self.update_period = update_period
-        self.expiration = expiration
-        self.stop = threading.Event()
-
-    def run(self) -> None:
-        declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
-        while not self.stop.wait(self.update_period):
-            declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
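
For reference, a minimal sketch of starting a server directly from Python via the `Server.create` signature above; the model name, prefix and block range are borrowed from the tests in this repo, while peers, dtype and networking options would depend on your setup:

```python
from src.server.server import Server

# a minimal sketch: serve blocks 3:5 of the test model with default DHT settings and no initial peers
server = Server.create(
    prefix="bloom6b3",
    converted_model_name_or_path="bigscience/test-bloomd-6b3",
    block_indices="3:5",
    torch_dtype="auto",
    start=True,
)
server.ready.wait(timeout=120)  # see Server.ready above
# ... later:
server.shutdown()
```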

+ 0 - 44
tests/test_block_exact_match.py

@@ -1,44 +0,0 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-import os
-
-import hivemind
-import torch
-
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.dht_utils import get_remote_module
-
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-BLOCK_UID = os.environ.get("BLOCK_UID")
-if not BLOCK_UID:
-    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1]))
-
-
-def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    remote_block = get_remote_module(dht, BLOCK_UID)
-    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-
-    inputs = torch.randn(1, 8, 4096)
-    (outputs_forward,) = remote_block(inputs)
-
-    outputs_inference = []
-    with remote_block.inference_session() as sess:
-        for i in range(inputs.shape[1]):
-            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
-    outputs_inference = torch.cat(outputs_inference, dim=1)
-
-    ref_block = load_pretrained_block(REF_NAME, REF_INDEX, torch_dtype=torch.float32)
-    (outputs_local,) = ref_block(inputs)
-
-    assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
-    assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)

+ 0 - 59
tests/test_chained_forward_backward.py

@@ -1,59 +0,0 @@
-######
-# Warning: this test is a work in progress. It will be modified soon.
-# - if you want more stable tests, see test_block_exact_match
-# - if you want to figure out chained inference, ask yozh
-
-import os
-
-import hivemind
-import torch
-from hivemind.moe.expert_uid import ExpertInfo
-
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.dht_utils import get_remote_module
-
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-BLOCK_UID = os.environ.get("BLOCK_UID")
-if not BLOCK_UID:
-    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-
-
-# seq_length > 128: rpc_forward_stream & rpc_backward_stream
-# seq_length <= 128: rpc_forward & rpc_backward
-def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    (remote_block,) = get_remote_module(dht, BLOCK_UID)
-    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-
-    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
-    remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4 bloom6b3.5", remote_block._info.peer_id)
-
-    ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
-    ]
-    inputs = torch.randn(1, seq_length, 4096, requires_grad=True)
-    outputs_rpc = remote_block.forward(inputs)[0]
-    outputs_rpc.sum().backward()
-    grads_rpc = inputs.grad
-
-    inputs.grad = None
-    hidden_states = inputs
-    for ref_block in ref_blocks:
-        hidden_states = ref_block.forward(hidden_states)[0]
-    outputs_ref = hidden_states
-    outputs_ref.sum().backward()
-    grads_ref = inputs.grad
-
-    assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
-    assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)

+ 0 - 64
tests/test_chained_inference.py

@@ -1,64 +0,0 @@
-######
-# Warning: this test is a work in progress. It will be modified soon.
-# - if you want more stable tests, see test_block_exact_match
-# - if you want to figure out chained inference, ask yozh
-
-import os
-
-import hivemind
-import torch
-from hivemind.moe.expert_uid import ExpertInfo
-
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.dht_utils import get_remote_module
-
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-BLOCK_UID = os.environ.get("BLOCK_UID")
-if not BLOCK_UID:
-    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1]))
-
-
-def test_remote_block_exact_match(atol_inference=1e-4):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    remote_block = get_remote_module(dht, BLOCK_UID)
-    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-
-    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
-    remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4", remote_block._info.peer_id)
-
-    inputs = torch.randn(1, 8, 4096)
-
-    outputs_inference = []
-    with remote_block.inference_session() as sess:
-        for i in range(inputs.shape[1]):
-            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
-    outputs_inference = torch.cat(outputs_inference, dim=1)
-
-    ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
-    ]
-    outputs_ref = []
-    caches = [None, None]
-    for i in range(inputs.shape[1]):
-        new_caches = []
-        hidden_states = inputs[:, i : i + 1, :]
-        for ref_block, cache in zip(ref_blocks, caches):
-            with torch.no_grad():
-                hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
-                new_caches.append(new_cache)
-
-        outputs_ref.append(hidden_states)
-        caches = new_caches
-    outputs_ref = torch.cat(outputs_ref, dim=1)
-    assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)

+ 0 - 60
tests/test_full_model.py

@@ -1,60 +0,0 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-import os
-
-import torch
-import transformers
-from hivemind import get_logger, use_hivemind_log_handler
-
-from src.client.remote_model import DistributedBloomForCausalLM
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-MODEL_NAME = os.environ.get("MODEL_NAME")
-if not MODEL_NAME:
-    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME")
-
-
-def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="bloom6b3"):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, prefix=prefix)
-    assert len(model.transformer.h) == model.config.n_layer
-
-    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
-    parallel_outputs = model.forward(test_inputs).logits
-    assert torch.all(torch.isfinite(parallel_outputs))
-    logger.info("Forward outputs are finite")
-
-    if REF_NAME:
-        ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
-        dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
-        # note: this creates a dummy mask to make the test compatible with older transformer versions
-        # prior to https://github.com/huggingface/transformers/pull/17837
-        ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
-        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
-    else:
-        logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
-
-    embs = model.transformer.word_embeddings(test_inputs)
-    embs = model.transformer.word_embeddings_layernorm(embs)
-    recurrent_outputs = []
-    with model.transformer.h.inference_session() as sess:
-        for t in range(embs.shape[1]):
-            recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
-    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
-    recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
-
-    dictionary = model.transformer.word_embeddings.weight.t()
-    recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
-    recurrent_outputs = (recurrent_outputs @ dictionary).float()
-    assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
-    logger.info("Inference is consistent with forward")

+ 18 - 0
your_code_here.py

@@ -0,0 +1,18 @@
+import torch
+import torch.nn as nn
+from hivemind.moe.server.layers.custom_experts import register_expert_class
+
+
+@register_expert_class("ExampleModule", lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)))
+class ExampleModule(nn.Module):
+    def __init__(self, hid_dim):
+        super().__init__()
+        self.ffn = nn.Linear(hid_dim, 4 * hid_dim)
+        self.ffn_output = nn.Linear(4 * hid_dim, hid_dim)
+        self.layer_norm = nn.LayerNorm(hid_dim, eps=1e-12)
+
+    def forward(self, x):
+        ffn_output = self.ffn(x)
+        ffn_output = torch.nn.functional.gelu(ffn_output)
+        ffn_output = self.ffn_output(ffn_output)
+        return self.layer_norm(x + ffn_output)
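
The decorator above registers `ExampleModule` under the name `"ExampleModule"` together with a sample-input factory that hivemind uses for shape inference. On its own, the module is just a residual feed-forward block; a quick sanity check (sizes are arbitrary, and this assumes `your_code_here.py` is importable from the repo root):

```python
import torch
from your_code_here import ExampleModule

module = ExampleModule(hid_dim=1024)
x = torch.randn(4, 1024)        # matches the sample-input factory: (batch_size, hid_dim)
out = module(x)
assert out.shape == x.shape     # residual FFN with layer norm preserves the hidden size
```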