|
@@ -187,8 +187,12 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig):
|
|
|
- super().__init__(config)
|
|
|
- self.transformer = DistributedBloomModel(config)
|
|
|
+ BloomPreTrainedModel.__init__(self, config)
|
|
|
+ self.num_labels = config.num_labels
|
|
|
|
|
|
+ self.transformer = DistributedBloomModel(config)
|
|
|
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
|
|
+
|
|
|
# Initialize weights and apply final processing
|
|
|
self.post_init()
|
|
|
+
|