In [3]:
import asyncio
from typing import Sequence, Tuple, Iterable, List
from tqdm.auto import trange

import torch
import hivemind
import petals

from petals.server.handler import TransformerConnectionHandler, split_for_streaming
from petals.client import RemoteSequenceManager, ClientConfig
from petals.client.remote_forward_backward import DEFAULT_MAX_MSG_SIZE, iter_as_aiter, aiter_with_timeout, deserialize_tensor_stream
from petals.data_structures import ModuleUID, PeerID, CHAIN_DELIMITER, UID_DELIMITER
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

from hivemind.compression import serialize_torch_tensor
from hivemind.utils import MSGPackSerializer, nested_flatten
from hivemind.proto import runtime_pb2

# Metadata key marking the final message of one logical tensor stream.
# pack_as_expert_requests() sets it to True in the metadata of the last request it builds, and
# run_remote_forward_backward() stops reading a reply segment once a received message carries it.
_END_OF_STREAM_KEY = "_EOS"


async def pack_as_expert_requests(uid, flat_tensors, codecs, metadata):
    """
    Serialize tensors and wrap the resulting pieces into ``runtime_pb2.ExpertRequest`` messages.

    Every tensor is serialized with its positionally matching codec in the event loop's default
    executor, then passed through ``split_for_streaming(..., DEFAULT_MAX_MSG_SIZE)``. Each resulting
    piece becomes one request holding exactly that piece plus msgpack-serialized ``metadata``.
    The metadata of the final request additionally contains ``_END_OF_STREAM_KEY: True``.

    :param uid: value stored in the ``uid`` field of every request
    :param flat_tensors: tensors to send, paired with ``codecs`` via ``zip`` (extra items are dropped)
    :param codecs: compression type for each tensor
    :param metadata: mapping attached (serialized) to every request
    :returns: list with one request per tensor piece; empty when there are no pieces
    """
    event_loop = asyncio.get_running_loop()

    # Schedule all serialization jobs up front so they can run concurrently in worker threads
    serialization_jobs = [
        event_loop.run_in_executor(None, serialize_torch_tensor, tensor, codec)
        for tensor, codec in zip(flat_tensors, codecs)
    ]
    serialized = await asyncio.gather(*serialization_jobs)

    pieces = []
    for serialized_tensor in serialized:
        pieces.extend(split_for_streaming(serialized_tensor, DEFAULT_MAX_MSG_SIZE))

    final_index = len(pieces) - 1
    # Plain metadata is only ever attached to non-final pieces, so skip serializing it otherwise
    regular_metadata = MSGPackSerializer.dumps(metadata) if final_index > 0 else None
    final_metadata = MSGPackSerializer.dumps(dict(metadata, **{_END_OF_STREAM_KEY: True}))

    requests = []
    for index, piece in enumerate(pieces):
        piece_metadata = final_metadata if index == final_index else regular_metadata
        requests.append(runtime_pb2.ExpertRequest(uid=uid, tensors=[piece], metadata=piece_metadata))
    return requests
    
async def run_remote_forward_backward(
    sequence_manager: RemoteSequenceManager,
    peer_id: PeerID,
    span_uids: Sequence[ModuleUID],
    *args: torch.Tensor,
    **kwargs: torch.Tensor,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """
    Serializes input tensors and calls "rpc_forward_backward_stream" on a remote server.
    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.

    A single bidirectional stream is used for both passes:
      1. the forward requests are sent;
      2. replies are read until one carries ``_END_OF_STREAM_KEY`` in its metadata (or the stream ends);
      3. only then are the gradient requests sent;
      4. the remaining replies on the same stream are read as gradients w.r.t. inputs.

    NOTE: this is a prototype - the "gradients" sent to the server are random placeholders
    (``torch.randn_like`` of the first flattened input tensor), and no compression is applied.

    :param sequence_manager: provides the p2p instance, request metadata and timeouts (``config``)
    :param peer_id: server to call
    :param span_uids: block uids, joined with CHAIN_DELIMITER to form the request uid
    :param args: positional inputs; at least one is required since ``args[0]`` is indexed directly
        (presumably the hidden states - TODO confirm)
    :param kwargs: keyword inputs, flattened together with ``args`` by ``pack_args_kwargs``
    :returns: ``(output_tensors, grad_inputs)`` - the deserialized forward outputs (with ``requires_grad``
        set if any input required grad) and the deserialized second reply segment
    :raises IndexError: if there are no flattened tensors or no positional arguments
    """
    merged_uid = CHAIN_DELIMITER.join(span_uids)
    stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)
    flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
    metadata = sequence_manager.get_request_metadata(
        "rpc_forward", args_structure, *args, uids=span_uids, peer_id=peer_id, **kwargs
    )  # TODO fix metadata api
    flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)
    metadata.setdefault("args_structure", args_structure)

    # One codec per *flattened* tensor (not per positional arg), so tensors passed via kwargs
    # and non-tensor positional args are accounted for correctly.
    # TODO replace with proper compression, e.g. sequence_manager.get_compression_codecs(...)
    codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)

    ### HERE BEGINS INLINED REQUEST SENDER
    # used to look like this:
    # output_tensors = await _run_forward_part(
    #     merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=metadata
    # )
    config = sequence_manager.config
    # The key is reserved for marking the last message; user metadata must not collide with it
    assert _END_OF_STREAM_KEY not in metadata
    forward_requests = await pack_as_expert_requests(merged_uid, flat_tensors, codecs, metadata)

    backward_codecs = [runtime_pb2.CompressionType.NONE]  # TODO replace with proper compression
    fake_grad_outputs = torch.randn_like(flat_tensors[0])  # placeholder gradients for the prototype
    _, backward_args_structure = pack_args_kwargs(args[0], fake_grad_outputs, *args[1:], **kwargs)
    backward_metadata = dict(metadata, args_structure=backward_args_structure)
    grad_requests = await pack_as_expert_requests(
        merged_uid, (fake_grad_outputs,), backward_codecs, backward_metadata
    )

    received_outputs = asyncio.Event()

    async def iterate_inputs():
        """Yield forward requests, wait until forward outputs arrived, then yield gradient requests."""
        for request in forward_requests:
            yield request
        print("WAITING FOR OUTPUTS")
        await received_outputs.wait()
        print("RECEIVED OUTPUTS - SENDING GRADS")
        for request in grad_requests:
            yield request
        print("SENT GRADS")

    async def _iter_until_eos(stream):
        """Re-yield messages from ``stream``, stopping after the one whose metadata has the EOS flag."""
        async for expert_response in stream:
            yield expert_response
            if not expert_response.metadata:
                continue  # TODO write more generally
            response_metadata = MSGPackSerializer.loads(expert_response.metadata)
            print(response_metadata)
            if response_metadata.get(_END_OF_STREAM_KEY):
                # Leaving the loop does not close ``stream``, so a later reader can continue from here
                break

    # Create the request generator once: the same object is logged and handed to the stub
    inputs_stream = iterate_inputs()
    print("CALLING stub.rpc_forward_stream on serialized inputs", inputs_stream)
    outputs_stream = await asyncio.wait_for(stub.rpc_forward_backward_stream(inputs_stream), config.connect_timeout)
    outputs_stream = aiter_with_timeout(outputs_stream, config.request_timeout)

    # First segment of the reply stream: forward outputs
    output_hidden_states = await deserialize_tensor_stream(
        msg.tensors async for msg in _iter_until_eos(outputs_stream)
    )
    received_outputs.set()  # unblocks iterate_inputs() so that it starts sending gradients

    # Second segment of the same reply stream: gradients w.r.t. inputs
    grad_inputs = await deserialize_tensor_stream(msg.tensors async for msg in _iter_until_eos(outputs_stream))
    print("RECEIVED GRAD INPUTS")

    ####

    # backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591
    requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
    output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_hidden_states]
    return output_tensors, grad_inputs


In [5]:
INITIAL_PEERS = ['/ip4/127.0.0.1/tcp/1337/p2p/QmRTdR9XmTHNXKiwtqRJ4i7tNofnmFrxkufBefguZUyXej']
peer_id_string = INITIAL_PEERS[0].split("/")[-1]
model_name = "Maykeye/TinyLLama-v0"

model_config = petals.DistributedLlamaConfig.from_pretrained(model_name)
block_uids = [
    f"{model_config.dht_prefix}{UID_DELIMITER}{i}"
    for i in range(model_config.num_hidden_layers)
]

block_in_use = block_uids[0:2]

try:
    dht = hivemind.DHT(start=True, client_mode=True, initial_peers=INITIAL_PEERS)
    sequence_manager = petals.RemoteSequenceManager(model_config, block_uids,  dht=dht)
    sequence_manager.rpc_info
    p2p = await dht.replicate_p2p()
    
    dummy_inputs = [
        torch.rand(1, 128, model_config.hidden_size, dtype=model_config.torch_dtype),
        torch.empty(0, dtype=model_config.torch_dtype),
    ]
    peer_id = hivemind.PeerID.from_base58(peer_id_string)
    for i in trange(1):
        (outputs,), grads = await run_remote_forward_backward(sequence_manager, peer_id, block_in_use, *dummy_inputs)
        print('outputs:', repr(outputs)[:50], '...')
    print("It works!")

finally:
    print("shutting down")
    await p2p.shutdown()
    dht.shutdown()  # it is okay to remove this clause, but you will be summoning a horde of daemons as you debug

Mar 17 18:37:25.661 [[1m[34mINFO[0m] Make sure you follow the LLaMA's terms of use: https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1
Mar 17 18:37:25.661 [[1m[34mINFO[0m] Using DHT prefix: TinyLLama-v0-hf
100%|██████████| 1/1 [00:00<00:00, 26.19it/s]

CALLING stub.rpc_forward_stream on serialized inputs <async_generator object run_remote_forward_backward.<locals>.iterate_inputs at 0x75eb8d134d60>
WAITING FOR OUTPUTS
{'_EOS': True}
RECEIVED OUTPUTS - SENDING GRADS
SENT GRADS
RECEIVED GRAD INPUTS
outputs: tensor([[[-0.0835,  0.3027,  0.2217,  ...,  1.1719 ...
It works!
shutting down



