Optimize memory consumption when calculating the loss · DAMO-NLP-SG/VideoLLaMA3@2126866
```diff
@@ -23,6 +23,7 @@
 from transformers.generation.utils import GenerateOutput
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
+from videollama3.constants import IGNORE_INDEX
 from .videollama3_arch import Videollama3MetaForCausalLM, Videollama3MetaModel
 
 
@@ -98,20 +99,56 @@ def forward(
             modals=modals,
         )
 
-        return super().forward(
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
-            labels=labels,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
-            num_logits_to_keep=num_logits_to_keep,
-            **loss_kwargs,
+        )
+
+        hidden_states = outputs[0]
+
+        loss = None
+        if labels is not None:
+            shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            mask = shift_labels != IGNORE_INDEX
+            shift_hidden_states = shift_hidden_states[mask]
+            shift_labels = shift_labels[mask]
+            logits = self.lm_head(shift_hidden_states)
+            if "num_items_in_batch" in loss_kwargs:
+                loss = nn.functional.cross_entropy(logits, shift_labels, ignore_index=IGNORE_INDEX, reduction="sum")
+                loss = loss / loss_kwargs["num_items_in_batch"]
+            else:
+                loss = nn.functional.cross_entropy(logits, shift_labels, ignore_index=IGNORE_INDEX)
+
+        else:
+            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
        )
 
    @torch.no_grad()
```
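The memory saving comes from no longer delegating to `super().forward()`, which projects every position's hidden state to full-vocabulary logits before computing the loss. Instead, the new `forward` masks out positions whose label is `IGNORE_INDEX` first, so `lm_head` only produces logits for the supervised tokens, and the `num_items_in_batch` branch keeps the loss correctly normalized under gradient accumulation. Below is a minimal, self-contained sketch of the same idea; the helper name `masked_lm_loss`, the toy shapes, and the `IGNORE_INDEX = -100` value are illustrative assumptions, not code taken from the repository.

```python
# Sketch of the masked-loss memory optimization (standalone, not the repo's classes):
# project hidden states to vocabulary logits only for supervised tokens.
import torch
import torch.nn as nn

IGNORE_INDEX = -100  # assumed value, following the common Hugging Face convention


def masked_lm_loss(hidden_states, labels, lm_head, num_items_in_batch=None):
    """hidden_states: (batch, seq, hidden); labels: (batch, seq) with IGNORE_INDEX on unsupervised tokens."""
    # Standard causal shift: token t is predicted from hidden state t-1.
    shift_hidden = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Keep only supervised positions. This is where memory is saved: lm_head is
    # applied to a (num_kept, hidden) tensor rather than the full sequence, so
    # only (num_kept, vocab) logits are ever materialized.
    mask = shift_labels != IGNORE_INDEX
    kept_hidden = shift_hidden[mask]
    kept_labels = shift_labels[mask]

    logits = lm_head(kept_hidden)
    if num_items_in_batch is not None:
        # Sum-reduce, then divide by the global token count so that gradient
        # accumulation over several micro-batches averages correctly.
        loss = nn.functional.cross_entropy(logits, kept_labels, reduction="sum")
        return loss / num_items_in_batch
    return nn.functional.cross_entropy(logits, kept_labels)


if __name__ == "__main__":
    # Toy usage with random data.
    batch, seq, hidden, vocab = 2, 16, 32, 100
    lm_head = nn.Linear(hidden, vocab, bias=False)
    hidden_states = torch.randn(batch, seq, hidden)
    labels = torch.randint(0, vocab, (batch, seq))
    labels[:, : seq // 2] = IGNORE_INDEX  # e.g. prompt tokens carry no loss
    print(masked_lm_loss(hidden_states, labels, lm_head))
```

The saving is largest when the vocabulary is big and most tokens are unsupervised (long video/image prompts with short answers), since full-sequence logits of shape (batch, seq, vocab) dominate activation memory in that regime; when no labels are given, the patched `forward` keeps the original `num_logits_to_keep` fast path for generation.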