Stable Diffusion XL Video Pipeline: An Advanced Text-to-Video Model

July 21, 2024

Introduction to the Model

The Stable Diffusion XL Video Pipeline model is an extension of the Stable Diffusion XL text-to-image model, designed to generate videos of arbitrary length by using previously generated frames to condition the diffusion of subsequent frames. This approach ensures temporal consistency in the generated videos.

Main Components

- StableDiffusionXLVideoPipeline: the main class that handles the entire video generation process.
- Temporal Transformer: captures temporal dependencies between frames.
- Temporal Conditioner Transformer: performs cross-attention with past frames to ensure temporal consistency.

Model Architecture

StableDiffusionXLVideoPipeline

The StableDiffusionXLVideoPipeline class manages the entire generation process, from encoding the prompt to producing the video frames. Here is an overview of the main functions and components of the class:

Module Initialization

The __init__ function registers the necessary modules, including the VAE, the text encoders, the tokenizers, the 3D UNet, and the scheduler.

class StableDiffusionXLVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    def __init__(self, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, unet, scheduler):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.default_sample_size = self.unet.config.sample_size

Prompt Encoding

The encode_prompt function handles the encoding of the text prompt, including negative prompts if present. It uses the tokenizers and text encoders to convert the text into embeddings usable by the model.

def encode_prompt(self, prompt, device=None, do_classifier_free_guidance=True, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None):
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    prompt_embeds_list = []

    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        text_inputs = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_input_ids = text_inputs.input_ids.to(device)
        # Use the penultimate hidden state of each text encoder, as in SDXL
        prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True).hidden_states[-2]
        prompt_embeds_list.append(prompt_embeds)

    # Concatenate the embeddings from the two text encoders along the feature dimension
    prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)

    if do_classifier_free_guidance and negative_prompt is not None:
        negative_prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            negative_inputs = tokenizer(negative_prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
            negative_input_ids = negative_inputs.input_ids.to(device)
            negative_prompt_embeds = text_encoder(negative_input_ids, output_hidden_states=True).hidden_states[-2]
            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.cat(negative_prompt_embeds_list, dim=-1)
    else:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)

    return prompt_embeds, negative_prompt_embeds

Preparing Latents

The prepare_latents function prepares the latents for the diffusion process, handling the generation of initial noise if necessary.

def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None):
    shape = (batch_size, num_channels_latents, num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if latents is None:
        latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)
    # Scale the initial noise by the scheduler's standard deviation
    latents = latents * self.scheduler.init_noise_sigma
    return latents

Main Call Function

The __call__ function handles the video generation process, including prompt encoding, latent preparation, and the denoising loop.

@torch.no_grad()
def __call__(self, prompt, negative_prompt=None, height=None, width=None, latent_image=None, device=None, num_frames=16, num_inference_steps=50, guidance_scale=9.0, conditioning_strength=0.0, eta=0.0, generator=None, latents=None, previous_latents=None):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    do_classifier_free_guidance = guidance_scale > 1.0
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, device, do_classifier_free_guidance, negative_prompt)
    if do_classifier_free_guidance:
        # Stack unconditional and conditional embeddings for a single batched forward pass
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    self.scheduler.set_timesteps(num_inference_steps, device=device)

    if latents is None:
        latents = self.prepare_latents(
            batch_size=1,
            num_channels_latents=self.unet.config.in_channels,
            num_frames=num_frames,
            height=height,
            width=width,
            dtype=prompt_embeds.dtype,
            device=device,
            generator=generator
        )

    for t in self.scheduler.timesteps:
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            conditioning_hidden_states=previous_latents
        )

        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = self.scheduler.step(noise_pred, t, latents).prev_sample

    return latents
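
In every call, previous_latents carries the last latent frames of the previously generated chunk; inside the UNet these are the frames that the Temporal Conditioner Transformer cross-attends to, which is how each new chunk stays consistent with the one generated before it.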

Temporal Transformer

The Temporal Transformer captures the temporal dependencies between frames, allowing the model to maintain temporal consistency in the generated video. It uses sinusoidal embeddings to encode the temporal position of frames.

class TemporalTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_attention_heads=16, attention_head_dim=88, in_channels=None, out_channels=None, num_layers=1, dropout=0.0):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.in_channels = in_channels
        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)
        self.positional_encoding = SinusoidalEmbeddings()
        self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, dropout=dropout) for _ in range(num_layers)])
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, return_dict=True):
        # Nothing to attend to when there is a single frame
        if hidden_states.size(2) <= 1:
            if not return_dict:
                return (hidden_states,)
            return TemporalTransformerOutput(sample=hidden_states)

        batch, channel, frames, height, width = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        # Fold the spatial dimensions into the batch so self-attention runs along the frame axis
        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch * height * width, frames, channel)
        hidden_states = self.proj_in(hidden_states)
        hidden_states = hidden_states + self.positional_encoding(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states=hidden_states)

        hidden_states = self.proj_out(hidden_states).view(batch, height, width, frames, channel).permute(0, 4, 3, 1, 2)
        hidden_states = hidden_states + residual

        return TemporalTransformerOutput(sample=hidden_states)
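
The SinusoidalEmbeddings module and the TemporalTransformerOutput container are referenced above but not shown. A minimal sketch of what they could look like, assuming standard sine/cosine positional encodings over the frame axis (the class names come from the code above, but their implementations here are assumptions):

import math
from dataclasses import dataclass

import torch
import torch.nn as nn

@dataclass
class TemporalTransformerOutput:
    sample: torch.Tensor

class SinusoidalEmbeddings(nn.Module):
    # Hypothetical implementation: classic sine/cosine positional encodings,
    # broadcast over the folded (batch * height * width) dimension.
    def forward(self, x):
        # x: (batch * height * width, frames, dim)
        _, frames, dim = x.shape
        position = torch.arange(frames, device=x.device, dtype=x.dtype).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2, device=x.device, dtype=x.dtype) * (-math.log(10000.0) / dim))
        pe = torch.zeros(frames, dim, device=x.device, dtype=x.dtype)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)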

Temporal Conditioner Transformer

The Temporal Conditioner Transformer uses cross-attention with past frames to ensure that each new frame is consistent with the previous ones, maintaining a smooth and natural transition in the video. This module is crucial for ensuring temporal consistency in long-duration videos.

class TemporalConditionerTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_attention_heads=16, attention_head_dim=88, in_channels=None, cross_attention_dim=2560, num_layers=1, only_cross_attention=True, dropout=0.0):
        super().__init__()
        self.conv_in = nn.Conv3d(4, cross_attention_dim, kernel_size=(1, 1, 1))
        self.ln = nn.LayerNorm(in_channels)
        self.proj_in = nn.Linear(in_channels, cross_attention_dim)
        self.positional_encoding = ExponentialEmbeddings()
        self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(cross_attention_dim, num_attention_heads, attention_head_dim, dropout=dropout) for _ in range(num_layers)])
        self.proj_out = nn.Linear(cross_attention_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states, return_dict=True):
        h_b, h_c, h_f, h_h, h_w = hidden_states.shape
        residual = hidden_states
        # Resize the past latents to the spatial resolution of the current latents
        encoder_hidden_states = torch.nn.functional.interpolate(encoder_hidden_states, size=(encoder_hidden_states.shape[2], h_h, h_w), mode='trilinear', align_corners=False).repeat_interleave(repeats=h_b, dim=0)
        encoder_hidden_states = self.conv_in(encoder_hidden_states)
        # Fold the spatial dimensions into the batch so cross-attention runs along the frame axis
        encoder_hidden_states = encoder_hidden_states.permute(0, 3, 4, 2, 1).reshape(h_b * h_h * h_w, -1, encoder_hidden_states.shape[1])
        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(h_b * h_h * h_w, h_f, h_c)
        hidden_states = self.ln(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        encoder_hidden_states = encoder_hidden_states + self.positional_encoding(encoder_hidden_states)
        hidden_states = hidden_states + self.positional_encoding(hidden_states, reverse=True)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states)

        hidden_states = self.proj_out(hidden_states).view(h_b, h_h, h_w, h_f, h_c).permute(0, 4, 3, 1, 2)
        hidden_states = hidden_states + residual

        return TemporalConditionerTransformerOutput(sample=hidden_states)
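
ExponentialEmbeddings and TemporalConditionerTransformerOutput are likewise referenced but not defined here. One plausible reading, and it is only an assumption, is a positional encoding whose magnitude decays exponentially along the frame axis, with reverse flipping the direction of the decay. A minimal sketch under that assumption:

from dataclasses import dataclass

import torch
import torch.nn as nn

@dataclass
class TemporalConditionerTransformerOutput:
    sample: torch.Tensor

class ExponentialEmbeddings(nn.Module):
    # Hypothetical implementation: a weight that decays exponentially with the
    # frame index, broadcast across the feature dimension; reverse flips the decay.
    def __init__(self, decay=0.5):
        super().__init__()
        self.decay = decay

    def forward(self, x, reverse=False):
        # x: (batch * height * width, frames, dim)
        _, frames, dim = x.shape
        positions = torch.arange(frames, device=x.device, dtype=x.dtype)
        if reverse:
            positions = positions.flip(0)
        weights = torch.exp(-self.decay * positions)  # (frames,)
        return weights.view(1, frames, 1).expand(1, frames, dim)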

Usage Example

Here is an example of how to use the Stable Diffusion XL Video Pipeline model to generate a video from a text prompt.

import torch

def generate_video():
    model_path = "path/to/pretrained/model"
    pipeline = StableDiffusionXLVideoPipeline.from_pretrained(pretrained_model_name_or_path=model_path).to("cuda")
    prompt = "A beautiful sunset over the sea"
    video = pipeline(prompt=prompt, num_frames=24, num_inference_steps=50, guidance_scale=9.0)
    save_video(video, "output_video.mp4")

def save_video(video_tensor, filename):
    # Saves the generated frames to a video file (see the full implementation below)
    ...

if __name__ == "__main__":
    generate_video()

Pipeline Preparation

The initialize_pipeline function loads the pre-trained model and sets up the DPMSolverMultistepScheduler with specified parameters to ensure efficient inference.

from diffusers import DPMSolverMultistepScheduler

def initialize_pipeline(model, device="cuda"):
    pipeline = StableDiffusionXLVideoPipeline.from_pretrained(pretrained_model_name_or_path=model).to(device=device)
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
        pipeline.scheduler.config,
        timestep_spacing="trailing",
        beta_schedule="scaled_linear",
        beta_start=0.00082,
        beta_end=0.014,
        algorithm_type="sde-dpmsolver++"
    )
    pipeline.vae.enable_slicing()
    return pipeline

Inference

The inference function runs the entire video generation process: it generates chunks of latent frames with the pipeline, normalizes and concatenates them into a single sequence, decodes the sequence with the VAE, and saves the resulting video to a file.

import os

import torch
from PIL import Image
from torchvision import transforms

@torch.inference_mode()
def inference(
        pretrained_model_path,
        prompt,
        negative_prompt=None,
        width=768,
        height=768,
        init_image=None,
        use_init_image=False,
        num_frames=4,
        num_conditioning_frames=4,
        conditioning_strength=0.0,
        num_repeats=4,
        fps=6,
        num_inference_steps=50,
        guidance_scale=15,
        device="cuda",
        output_folder="videos",
        seed=None
    ):

    if seed is not None:
        set_seed(seed)
    os.makedirs(output_folder, exist_ok=True)
    with torch.autocast(device, dtype=torch.float16):
        pipeline = initialize_pipeline(pretrained_model_path, device)
        latent_image = None
        if init_image is not None and use_init_image:
            image = Image.open(init_image).convert('RGB')
            transform = transforms.Compose([
                transforms.Resize((height, width)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            image = transform(image).to(device=device).unsqueeze(0).unsqueeze(0).permute(0, 2, 1, 3, 4)
            latent_image = encode(image, pipeline.vae).repeat_interleave(repeats=num_frames, dim=2)

        out_file = f"{output_folder}/{prompt}.mp4"
        encoded_out_file = f"{output_folder}/{prompt}_encoded.mp4"
        previous_latents, stacked_latents = None, None

        for i in range(num_repeats + 1):
            with torch.no_grad():
                latents = pipeline(
                    prompt,
                    negative_prompt=negative_prompt,
                    device=device,
                    width=width,
                    height=height,
                    latent_image=latent_image,
                    conditioning_strength=conditioning_strength,
                    num_frames=num_frames,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    previous_latents=previous_latents
                )
            stacked_latents = torch.cat((stacked_latents, latents), dim=2) if i > 0 else latents
            stacked_latents = normalize_latents(stacked_latents)
            previous_latents = stacked_latents[:, :, -num_conditioning_frames:, :, :]

        stacked_latents = stacked_latents[:, :, num_frames:, :, :]
        tensor = decode(stacked_latents, pipeline.vae)
        save_video(tensor, out_file, fps)

        try:
            encode_video(out_file, encoded_out_file, get_video_height(out_file))
            os.remove(out_file)
        except:
            pass
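
The encode_video and get_video_height helpers used in the try block are not shown in the excerpt. A possible sketch, assuming they shell out to ffprobe and ffmpeg to re-encode the raw OpenCV output as H.264 (the exact flags here are an assumption, not necessarily the project's implementation):

import subprocess

def get_video_height(path):
    # Query the height of the first video stream with ffprobe
    result = subprocess.run(
        ["ffprobe", "-v", "error", "-select_streams", "v:0",
         "-show_entries", "stream=height", "-of", "csv=p=0", path],
        capture_output=True, text=True, check=True
    )
    return int(result.stdout.strip())

def encode_video(input_path, output_path, height):
    # Re-encode the mp4v output as H.264 at the given height
    subprocess.run(
        ["ffmpeg", "-y", "-i", input_path,
         "-vf", f"scale=-2:{height}",
         "-c:v", "libx264", "-pix_fmt", "yuv420p", output_path],
        check=True
    )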

Encoding and Decoding Latents

The encode and decode functions convert image tensors into latents and back using the VAE, folding the frame dimension into the batch so the 2D VAE can process each frame independently.

def encode(tensor, vae):
    batch, channels, frames, height, width = tensor.shape
    # Fold the frame dimension into the batch so the 2D VAE can process each frame
    tensor = tensor.float().permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
    latents = vae.encode(tensor).latent_dist.sample() * vae.config.scaling_factor
    return latents.view(batch, frames, latents.shape[1], latents.shape[2], latents.shape[3]).permute(0, 2, 1, 3, 4)

def decode(latents, vae):
    batch, channels, frames, height, width = latents.shape
    latents = (1 / vae.config.scaling_factor) * latents.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
    image = vae.decode(latents).sample
    return image.view(batch, frames, -1, image.shape[2], image.shape[3]).permute(0, 2, 1, 3, 4).float()

Saving the Video

The save_video function denormalizes the generated frames and writes them to a video file with OpenCV.

import cv2

def save_video(normalized_tensor, output_path, fps=30):
    denormalized_frames = denormalize(normalized_tensor)
    height, width = denormalized_frames.shape[1:3]
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    for frame in denormalized_frames:
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()

Normalizing and Denormalizing

Utility functions to normalize and denormalize tensors.

def normalize_latents(latents):
    return (latents - latents.mean()) / (latents.std() + 1e-8)

def denormalize(normalized_tensor):
    denormalized = (normalized_tensor + 1.0) * 127.5
    return torch.clamp(denormalized, 0, 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy()

Setting the Seed

Function to set the seed for reproducibility of results.

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Conclusion

The Stable Diffusion XL Video Pipeline model represents a significant advancement in text-to-video generation, using advanced techniques such as temporal and conditioning transformers. This approach enables the creation of long, temporally consistent videos, opening up new possibilities for creative applications of artificial intelligence.

For further details and to explore the complete code, you can consult the project repository.