@@ -460,12 +460,12 @@ class LMHead(nn.Module):
         assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"

         word_embeddings = self.word_embeddings.weight
+        num_embeddings = self.word_embeddings.num_embeddings

-        hidden_states = hidden_states.float()
-        num_embeddings = word_embeddings.shape[0]
+        hidden_states = hidden_states.float()
         output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)

         for i in range(0, num_embeddings, self.chunk_size):
-            chunk = word_embeddings[i:i+self.chunk_size].float()
-            output[..., i:i+self.chunk_size] = F.linear(hidden_states, chunk)
-            return output
+            chunk = word_embeddings[i : i + self.chunk_size].float()
+            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
+        return output
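For reference, a minimal self-contained sketch of the chunked projection this hunk converges on. The wrapper class, its constructor, the `__main__` usage, and the toy sizes (`1000` vocab entries, hidden size `64`, `chunk_size=256`) are illustrative assumptions, not part of the PR; only the body of `chunked_forward` mirrors the post-change hunk.

```python
# Sketch (assumed scaffolding, not the PR's full class) of the chunked LM-head
# projection. Chunking bounds peak memory: only one (chunk_size, hidden_size)
# slice of the embedding matrix is upcast to float32 at a time, instead of
# materializing the entire vocab projection in one matmul.
import torch
import torch.nn as nn
import torch.nn.functional as F


class LMHead(nn.Module):
    def __init__(self, word_embeddings: nn.Embedding, chunk_size: int = 256):
        super().__init__()
        self.word_embeddings = word_embeddings  # weights tied with input embeddings
        self.chunk_size = chunk_size

    def chunked_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"

        word_embeddings = self.word_embeddings.weight
        num_embeddings = self.word_embeddings.num_embeddings

        hidden_states = hidden_states.float()
        output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)

        for i in range(0, num_embeddings, self.chunk_size):
            # Project onto one vocab slice at a time; .float() upcasts only this chunk.
            chunk = word_embeddings[i : i + self.chunk_size].float()
            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
        return output  # returned after the loop, once every chunk is filled


# Toy usage: the chunked logits match the unchunked projection.
if __name__ == "__main__":
    embeddings = nn.Embedding(num_embeddings=1000, embedding_dim=64)
    head = LMHead(embeddings, chunk_size=256)
    hidden = torch.randn(2, 5, 64)  # (batch, seq_len, hidden_size)
    logits = head.chunked_forward(hidden)
    assert logits.shape == (2, 5, 1000)
    assert torch.allclose(logits, F.linear(hidden, embeddings.weight), atol=1e-5)
```

Note the last slice is handled for free: both `word_embeddings[i : i + self.chunk_size]` and `output[..., i : i + self.chunk_size]` clip to the vocabulary size, so no special-casing is needed when `num_embeddings` is not a multiple of `chunk_size`.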