add: train to text image with sdxl script. by sayakpaul · Pull Request #4505 · huggingface/diffusers (original) (raw)
here is a naive implementation that uses multiprocessing (lol, it does not, i was thinking of my data loader) but it even has a progress bar! :D
import hashlib, os, torch, logging from tqdm import tqdm from PIL import Image import torchvision.transforms as transforms
logger = logging.getLogger("VAECache") logger.setLevel("INFO")
class VAECache: def init(self, vae, accelerator, cache_dir="vae_cache", resolution: int = 1024): self.vae = vae self.vae.enable_slicing() self.accelerator = accelerator self.cache_dir = cache_dir self.resolution = resolution os.makedirs(self.cache_dir, exist_ok=True)
def create_hash(self, filename):
# Create a sha256 hash
sha256_hash = hashlib.sha256()
# Feed the hash function with the filename
sha256_hash.update(filename.encode())
# Get the hexadecimal representation of the hash
return sha256_hash.hexdigest()
def save_to_cache(self, filename, embeddings):
torch.save(embeddings, filename)
def load_from_cache(self, filename):
return torch.load(filename)
def encode_image(self, pixel_values, filepath: str):
file_hash = self.create_hash(filepath)
filename = os.path.join(self.cache_dir, file_hash + ".pt")
logger.debug(f'Created file_hash {file_hash} from filepath {filepath} for resulting .pt filename.')
if os.path.exists(filename):
latents = self.load_from_cache(filename)
logger.debug(
f"Loading latents of shape {latents.shape} from existing cache file: {filename}"
)
else:
with torch.no_grad():
latents = self.vae.encode(
pixel_values.unsqueeze(0).to(
self.accelerator.device, dtype=torch.bfloat16
)
).latent_dist.sample()
logger.debug(
f"Using shape {latents.shape}, creating new latent cache: {filename}"
)
latents = latents * self.vae.config.scaling_factor
logger.debug(f"Latent shape after re-scale: {latents.shape}")
self.save_to_cache(filename, latents.squeeze())
output_latents = latents.squeeze().to(
self.accelerator.device, dtype=self.vae.dtype
)
logger.debug(f"Output latents shape: {output_latents.shape}")
return output_latents
def process_directory(self, directory):
# Define a transform to convert the image to tensor
transform = transforms.ToTensor()
# Get a list of all the files to process (customize as needed)
files_to_process = []
logger.debug(f"Beginning processing of VAECache directory {directory}")
for subdir, _, files in os.walk(directory):
for file in files:
if file.endswith((".png", ".jpg", ".jpeg")):
logger.debug(f"Discovered image: {os.path.join(subdir, file)}")
files_to_process.append(os.path.join(subdir, file))
# Iterate through the files, displaying a progress bar
for filepath in tqdm(files_to_process, desc="Processing images"):
# Create a hash based on the filename
file_hash = self.create_hash(filepath)
filename = os.path.join(self.cache_dir, file_hash + ".pt")
# If processed file already exists, skip processing for this image
if os.path.exists(filename):
logger.debug(
f"Skipping processing for {filepath} as cached file {filename} already exists."
)
continue
# Open the image using PIL
try:
logger.debug(f"Loading image: {filepath}")
image = Image.open(filepath)
image = image.convert("RGB")
image = self._resize_for_condition_image(image, self.resolution)
except Exception as e:
logger.error(f"Encountered error opening image: {e}")
os.remove(filepath)
continue
# Convert the image to a tensor
try:
pixel_values = transform(image).to(
self.accelerator.device, dtype=self.vae.dtype
)
except OSError as e:
logger.error(f"Encountered error converting image to tensor: {e}")
continue
# Process the image with the VAE
self.encode_image(pixel_values, filepath)
logger.debug(f"Processed image {filepath}")
def _resize_for_condition_image(self, input_image: Image, resolution: int):
input_image = input_image.convert("RGB")
W, H = input_image.size
aspect_ratio = round(W / H, 3)
msg = f"Inspecting image of aspect {aspect_ratio} and size {W}x{H} to "
if W < H:
W = resolution
H = int(resolution / aspect_ratio) # Calculate the new height
elif H < W:
H = resolution
W = int(resolution * aspect_ratio) # Calculate the new width
if W == H:
W = resolution
H = resolution
msg = f"{msg} {W}x{H}."
logger.debug(msg)
img = input_image.resize((W, H), resample=Image.BICUBIC)
return img