Dmitry Baranchuk 3 роки тому
батько
коміт
fd0bf064f3
1 змінених файлів з 5 додано та 5 видалено
  1. 5 5
      src/bloom/model.py

+ 5 - 5
src/bloom/model.py

@@ -460,12 +460,12 @@ class LMHead(nn.Module):
         assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
 
         word_embeddings = self.word_embeddings.weight
+        num_embeddings = self.word_embeddings.num_embeddings
 
-        hidden_states = hidden_states.float()
-        num_embeddings = word_embeddings.shape[0]        
+        hidden_states = hidden_states.float()    
         output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
 
         for i in range(0, num_embeddings, self.chunk_size):
-            chunk = word_embeddings[i:i+self.chunk_size].float()
-            output[..., i:i+self.chunk_size] = F.linear(hidden_states, chunk)
-        return output
+            chunk = word_embeddings[i: i + self.chunk_size].float()
+            output[..., i: i + self.chunk_size] = F.linear(hidden_states, chunk)
+        return output