justheuristic преди 3 години
родител
ревизия
f8b7aaece4
променени са 2 файла, в които са добавени 9 реда и са изтрити 4 реда
  1. 2 0
      src/bloom/model.py
  2. 7 4
      src/client/remote_sequential.py

+ 2 - 0
src/bloom/model.py

@@ -205,6 +205,8 @@ class BloomModel(BloomPreTrainedModel):
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        if position_ids is not None:
+            logger.warning("position_ids are ignored in this bloom implementation")
         elif input_ids is not None:
             input_shape = input_ids.size()
             input_ids = input_ids.view(-1, input_shape[-1])

+ 7 - 4
src/client/remote_sequential.py

@@ -20,7 +20,7 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class RemoteSequential(nn.Sequential):
+class RemoteSequential(nn.Module):
     """
     A sequence of transformer blocks hosted by the swarm.
     """
@@ -41,9 +41,9 @@ class RemoteSequential(nn.Sequential):
         self.max_retries = max_retries
         self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
 
-        self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
-        logger.debug(f"Remote block uids: {self.block_uids}")
-        self.remote_sequence_info = RemoteSequenceInfo(dht, self.block_uids)
+        block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
+        logger.debug(f"Remote block uids: {block_uids}")
+        self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@@ -72,6 +72,9 @@ class RemoteSequential(nn.Sequential):
         for block_index in range(self.config.n_layer):
             yield self[block_index]
 
+    def __len__(self):
+        return len(self.remote_sequence_info)
+
     def inference_session(self) -> RemoteSequentialInferenceSession:
         self.remote_sequence_info.update_()
         return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)