Optimize memory consumption when calculating the loss · DAMO-NLP-SG/VideoLLaMA3@2126866

@@ -23,6 +23,7 @@
 from transformers.generation.utils import GenerateOutput
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
+from videollama3.constants import IGNORE_INDEX
 from .videollama3_arch import Videollama3MetaForCausalLM, Videollama3MetaModel
 
 

@@ -98,20 +99,56 @@ def forward(
             modals=modals,
         )
 
-        return super().forward(
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
-            labels=labels,
             use_cache=use_cache,
             output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
-            num_logits_to_keep=num_logits_to_keep,
-            **loss_kwargs,
+        )
+
+        hidden_states = outputs[0]
+
+        loss = None
+        if labels is not None:
+            shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            mask = shift_labels != IGNORE_INDEX
+            shift_hidden_states = shift_hidden_states[mask]
+            shift_labels = shift_labels[mask]
+            logits = self.lm_head(shift_hidden_states)
+            if "num_items_in_batch" in loss_kwargs:
+                loss = nn.functional.cross_entropy(logits, shift_labels, ignore_index=IGNORE_INDEX, reduction="sum")
+                loss = loss / loss_kwargs["num_items_in_batch"]
+            else:
+                loss = nn.functional.cross_entropy(logits, shift_labels, ignore_index=IGNORE_INDEX)
+
+        else:
+            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
     @torch.no_grad()
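
For context, a minimal self-contained sketch of the trick this commit applies is shown below. The toy `lm_head`, the `HIDDEN`/`VOCAB` sizes, and the `masked_lm_loss` helper are illustrative stand-ins, not code from the repository; only the masking logic mirrors the diff above.

```python
import torch
import torch.nn as nn

IGNORE_INDEX = -100       # same sentinel value used for unsupervised tokens
HIDDEN, VOCAB = 64, 1000  # toy sizes, purely for illustration

lm_head = nn.Linear(HIDDEN, VOCAB, bias=False)  # stand-in for the model's LM head

def masked_lm_loss(hidden_states, labels, num_items_in_batch=None):
    """Next-token cross-entropy, projecting only supervised positions through lm_head."""
    # Shift so that the hidden state at position t predicts the token at t+1.
    shift_hidden = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Drop every position whose target is IGNORE_INDEX *before* the projection,
    # so logits are only materialized for tokens that contribute to the loss.
    mask = shift_labels != IGNORE_INDEX
    shift_hidden = shift_hidden[mask]   # [num_supervised, HIDDEN]
    shift_labels = shift_labels[mask]   # [num_supervised]

    logits = lm_head(shift_hidden)      # [num_supervised, VOCAB]

    if num_items_in_batch is not None:
        # Sum, then normalize by the global token count (the num_items_in_batch
        # convention used with gradient accumulation).
        loss = nn.functional.cross_entropy(logits, shift_labels, reduction="sum")
        return loss / num_items_in_batch
    return nn.functional.cross_entropy(logits, shift_labels)

# Toy usage: random "decoder" states, loss supervised on the last 8 tokens only.
hidden_states = torch.randn(2, 16, HIDDEN)
labels = torch.full((2, 16), IGNORE_INDEX)
labels[:, 8:] = torch.randint(0, VOCAB, (2, 8))
print(masked_lm_loss(hidden_states, labels))
print(masked_lm_loss(hidden_states, labels, num_items_in_batch=16))
```

Because only the supervised positions are pushed through the LM head, the logits tensor shrinks from `[batch, seq_len, vocab]` to `[num_supervised_tokens, vocab]`, which is where the memory saving comes from; the `num_items_in_batch` path sums the loss and divides by the global token count so that gradient accumulation does not skew the per-token average.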