class AlbertModel(AlbertPreTrainedModel):

    config_class = AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)

        self.config = config
        self.embeddings = AlbertEmbeddings(config)
        self.encoder = AlbertTransformer(config)
        if add_pooling_layer:
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.pooler_activation = nn.Tanh()
        else:
            self.pooler = None
            self.pooler_activation = None

        self.init_weights()
    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings
    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.

        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}

        ALBERT has a different architecture in that its layers are shared across groups, which then have inner
        groups. If an ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups, there is a
        total of 4 different layers.

        These layers are flattened: the indices [0, 1] correspond to the two inner groups of the first hidden
        group, while [2, 3] correspond to the two inner groups of the second hidden group.

        Any layer with an index other than [0, 1, 2, 3] will result in an error. See the base class
        PreTrainedModel for more information about head pruning.
        """
        for layer, heads in heads_to_prune.items():
            group_idx = int(layer / self.config.inner_group_num)
            inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
            self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
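    # Worked index example (a sketch, assuming the configuration described in the docstring
    # above: 12 hidden layers, 2 hidden groups, inner_group_num=2, i.e. 4 distinct layers):
    # the entry {3: [0, 1]} in heads_to_prune targets flattened layer 3, which maps to
    # group_idx = int(3 / 2) = 1 and inner_group_idx = 3 - 1 * 2 = 1, so heads 0 and 1 of
    # the second inner layer of the second layer group are pruned.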
    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
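        # Build a broadcastable additive mask of shape [batch_size, 1, 1, seq_length]:
        # positions to attend to become 0.0 and padded positions become -10000.0, so the
        # padded positions are effectively zeroed out after the attention softmax.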
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
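        # The pooled output is a tanh-activated linear projection of the hidden state of
        # the first token ([CLS]), typically consumed by sequence-level classification heads.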
        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
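

# A minimal usage sketch (not part of the model definition above). It assumes the public
# "albert-base-v2" checkpoint referenced in the docstring decorator and the AlbertTokenizer
# shipped with the same library version:
#
#     from transformers import AlbertModel, AlbertTokenizer
#
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertModel.from_pretrained("albert-base-v2")
#
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     outputs = model(**inputs, return_dict=True)
#
#     outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)
#     outputs.pooler_output      # (batch_size, hidden_size)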