huggingface/diffusers

Describe the bug

Running diffusers.utils.export_to_video() on the output of HunyuanVideoPipeline results in

/app/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")
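For context, this warning is NumPy's generic complaint about casting NaN (or inf) values to an integer dtype; a minimal standalone illustration, independent of diffusers (behavior as on recent NumPy versions):

import numpy as np

# Casting NaN to an integer dtype emits "RuntimeWarning: invalid value
# encountered in cast" on recent NumPy, and the resulting uint8 values
# are undefined garbage.
images = np.array([[np.nan, 0.5, 1.0]])
frames = (images * 255).round().astype("uint8")
print(frames)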

After adding some checks to numpy_to_pil() in image_processor.py, I have confirmed that the output contains NaN values:

  File "/app/pipeline.py", line 37, in <module>
    output = pipe(
             ^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py", line 677, in __call__
    video = self.video_processor.postprocess_video(video, output_type=output_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/video_processor.py", line 103, in postprocess_video
    batch_output = self.postprocess(batch_vid, output_type)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/image_processor.py", line 823, in postprocess
    return self.numpy_to_pil(image)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/image_processor.py", line 158, in numpy_to_pil
    raise ValueError("Image array contains NaN values")
ValueError: Image array contains NaN values
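For reference, the check mentioned above is a local modification, not part of upstream diffusers; a minimal sketch of the kind of guard that raises this ValueError, assuming it is placed in numpy_to_pil() just before the uint8 cast:

import numpy as np

def assert_no_nan(images: np.ndarray) -> None:
    # Hypothetical guard, equivalent to the local check added to
    # diffusers' numpy_to_pil(); not part of the upstream library.
    if np.isnan(images).any():
        raise ValueError("Image array contains NaN values")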

Reproduction

import os
import time

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download
from torch.profiler import ProfilerActivity, profile, record_function

os.environ["TOKENIZERS_PARALLELISM"] = "false"

MODEL_ID = "tencent/HunyuanVideo"
PROMPT = "a whale shark floating through outer space"
profile_dir = os.environ.get("PROFILE_OUT_PATH", "./")
profile_file_name = os.environ.get("PROFILE_OUT_FILE_NAME", "hunyuan_profile.json")
profile_path = os.path.join(profile_dir, profile_file_name)

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    MODEL_ID, subfolder="transformer", torch_dtype=torch.float16, revision="refs/pr/18"
)
pipe = HunyuanVideoPipeline.from_pretrained(
    MODEL_ID, transformer=transformer, torch_dtype=torch.float16, revision="refs/pr/18"
)
pipe.vae.enable_tiling()
pipe.to("cuda")

print(f"\nStarting profiling of {MODEL_ID}\n")

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True
) as prof:
    with record_function("model_inference"):
        output = pipe(
            prompt=PROMPT,
            height=320,
            width=512,
            num_frames=61,
            num_inference_steps=30,
        )

# Export and print profiling results

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace(profile_path)
print(f"{profile_file_name} ready")

# export video

video = output.frames[0]

print(" ====== raw video matrix =====") print(video) print()

print(" ====== Exporting video =====") export_to_video(video, "hunyuan_example.mp4", fps=15) print()

Logs

No response

System Info

GPU: AMD MI300X

ARG BASE_IMAGE=python:3.11-slim
FROM ${BASE_IMAGE}

ENV PYTHONBUFFERED=true
ENV CUDA_VISIBLE_DEVICES=0

WORKDIR /app

# Install tools

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    git \
    libgl1-mesa-glx \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libfontconfig1 \
    ffmpeg \
    build-essential && \
    rm -rf /var/lib/apt/lists/*

# install ROCm pytorch and python dependencies

RUN python -m pip install --no-cache-dir \
    torch torchvision --index-url https://download.pytorch.org/whl/rocm6.2 && \
    python -m pip install --no-cache-dir \
    accelerate transformers sentencepiece protobuf opencv-python imageio imageio-ffmpeg

# install diffusers from source to include newest pipeline classes

COPY diffusers diffusers
RUN cd diffusers && \
    python -m pip install -e .

# Copy the profiling script

ARG PIPELINE_FILE
COPY ${PIPELINE_FILE} pipeline.py

# run the script

CMD ["python", "pipeline.py"]

Who can help?

@DN6 @a-r-r-o-w