Stable Diffusion XL Video Pipeline: Un Modello Text-to-Video Avanzato

July 21, 2024

Il modello Stable Diffusion XL Video Pipeline è un'estensione del modello text-to-image, progettata per generare video infiniti utilizzando i frame precedenti per condizionare la diffusione dei frame successivi. Questo approccio consente di mantenere la coerenza temporale nei video generati.

Componenti Principali

StableDiffusionXLVideoPipeline: La classe principale che gestisce l'intero processo di generazione del video. Temporal Transformer: Utilizzato per catturare le dipendenze temporali tra i frame. Temporal Conditioner Transformer: Effettua il cross-attention con i frame passati per assicurare la coerenza temporale.

Architettura del Modello

StableDiffusionXLVideoPipeline

La classe StableDiffusionXLVideoPipeline gestisce l'intero processo di generazione, dall'encoding del prompt alla produzione dei frame del video. Ecco una panoramica delle principali funzioni e componenti della classe:

Inizializzazione del Modulo

La funzione init registra i moduli necessari, tra cui il VAE, i text encoder, i tokenizer, il modello UNet 3D e il scheduler.

class StableDiffusionXLVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    def **init**(self, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, unet, scheduler):
        super().**init**()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler
        )
        self.vae_scale_factor = 2 \*\* (len(self.vae.config.block_out_channels) - 1)
        self.default_sample_size = self.unet.config.sample_size

Encoding del Prompt

La funzione encode_prompt gestisce l'encoding del prompt di testo, inclusi i prompt negativi, se presenti. Utilizza i tokenizer e i text encoder per convertire il testo in embedding utilizzabili dal modello.

def encode_prompt(self, prompt, device=None, do_classifier_free_guidance=True, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None):
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    prompt_embeds_list = []

    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        text_inputs = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True).hidden_states[-2]
        prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)

    if do_classifier_free_guidance and negative_prompt is not None:
        negative_prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            negative_inputs = tokenizer(negative_prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
            negative_input_ids = negative_inputs.input_ids.to(device)
            negative_prompt_embeds = text_encoder(negative_input_ids, output_hidden_states=True).hidden_states[-2]
            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.cat(negative_prompt_embeds_list, dim=-1)
    else:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)

    return prompt_embeds, negative_prompt_embeds

Preparazione dei Latents

La funzione prepare_latents prepara i latents per il processo di diffusione, gestendo la generazione di rumore iniziale se necessario.

def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None):
    shape = (batch_size, num_channels_latents, num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if latents is None:
        latents = torch.randn(shape, device=device, dtype=dtype)
    else:
        latents = latents.to(device)
        latents = latents \* self.scheduler.init_noise_sigma
        return latents

Funzione di Chiamata Principale

La funzione call gestisce il processo di generazione del video, includendo la codifica del prompt, la preparazione dei latents e il loop di diffusione.

@torch.no_grad()
def **call**(self, prompt, negative_prompt=None, height=None, width=None, latent_image=None, device=None, num_frames=16, num_inference_steps=50, guidance_scale=9.0, conditioning_strength=0.0, eta=0.0, generator=None, latents=None, previous_latents=None):
    do_classifier_free_guidance = guidance_scale > 1.0
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, device, do_classifier_free_guidance, negative_prompt)

    if latents is None:
        latents = self.prepare_latents(
            batch_size=1,
            num_channels_latents=self.unet.config.in_channels,
            num_frames=num_frames,
            height=height,
            width=width,
            dtype=prompt_embeds.dtype,
            device=device,
            generator=generator
        )

    for t in self.scheduler.timesteps:
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            conditioning_hidden_states=previous_latents
        )

        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

    return latents

Temporal Transformer

Il Temporal Transformer cattura le dipendenze temporali tra i frame, permettendo al modello di mantenere la coerenza temporale nel video generato. Utilizza embeddings sinusoidali per codificare la posizione temporale dei frame.

class TemporalTransformer(ModelMixin, ConfigMixin):
@register*to_config
    def **init**(self, num_attention_heads=16, attention_head_dim=88, in_channels=None, out_channels=None, num_layers=1, dropout=0.0):
        super().**init**()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads \* attention_head_dim
        self.in_channels = in_channels
        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)
        self.positional_encoding = SinusoidalEmbeddings()
        self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, dropout=dropout) for * in range(num_layers)])
        self.proj_out = nn.Linear(inner_dim, in_channels)

        def forward(self, hidden_states, return_dict=True):
            if hidden_states.size(2) <= 1:
                if not return_dict:
                    return (hidden_states,)
                return TemporalTransformerOutput(sample=hidden_states)

            batch, channel, frames, height, width = hidden_states.shape
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
            hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch * frames, channel, height, width)
            hidden_states = self.proj_in(hidden_states)
            hidden_states = hidden_states + self.positional_encoding(hidden_states)

            for block in self.transformer_blocks:
                hidden_states = block(hidden_states=hidden_states)

            hidden_states = self.proj_out(hidden_states).view(batch, frames, channel, height, width).permute(0, 2, 1, 3, 4)
            hidden_states += residual

            return TemporalTransformerOutput(sample=hidden_states)

Temporal Conditioner Transformer

Il Temporal Conditioner Transformer utilizza il cross-attention con i frame passati per assicurare che ogni nuovo frame sia coerente con i precedenti, mantenendo una transizione fluida e naturale nel video. Questo modulo è cruciale per garantire la coerenza temporale nei video di lunga durata.

class TemporalConditionerTransformer(ModelMixin, ConfigMixin):
@register*to_config
    def **init**(self, num_attention_heads=16, attention_head_dim=88, in_channels=None, cross_attention_dim=2560, num_layers=1, only_cross_attention=True, dropout=0.0):
        super().**init**()
        self.conv_in = nn.Conv3d(4, cross_attention_dim, kernel_size=(1, 1, 1))
        self.ln = nn.LayerNorm(in_channels)
        self.proj_in = nn.Linear(in_channels, cross_attention_dim)
        self.positional_encoding = ExponentialEmbeddings()
        self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(cross_attention_dim, num_attention_heads, attention_head_dim, dropout=dropout) for * in range(num_layers)])
        self.proj_out = nn.Linear(cross_attention_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states, return_dict=True):
        h_b, h_c, h_f, h_h, h_w = hidden_states.shape
        residual = hidden_states
        encoder_hidden_states = torch.nn.functional.interpolate(encoder_hidden_states, size=(encoder_hidden_states.shape[2], h_h, h_w), mode='trilinear', align_corners=False).repeat_interleave(repeats=h_b, dim=0)
        encoder_hidden_states = self.conv_in(encoder_hidden_states).permute(0, 2, 1, 3, 4).reshape(h_b * h_h * h_w, h_f, encoder_hidden_states.shape[1])
        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(h_b * h_h * h_w, h_f, h_c)
        hidden_states = self.ln(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        encoder_hidden_states += self.positional_encoding(encoder_hidden_states)
        hidden_states += self.positional_encoding(hidden_states, reverse=True)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states)

        hidden_states = self.proj_out(hidden_states).view(h_b, h_h, h_w, h_f, h_c).permute(0, 4, 3, 1, 2)
        hidden_states += residual

        return TemporalConditionerTransformerOutput(sample=hidden_states)

Esempio di Utilizzo

Ecco un esempio di come utilizzare il modello Stable Diffusion XL Video Pipeline per generare un video a partire da un prompt testuale.

import torch

def generate_video():
    model_path = "path/to/pretrained/model"
    pipeline = StableDiffusionXLVideoPipeline.from_pretrained(pretrained_model_name_or_path=model_path).to("cuda")
    prompt = "Un bellissimo tramonto sul mare"
    video = pipeline(prompt=prompt, num_frames=24, num_inference_steps=50, guidance_scale=9.0)
    save_video(video, "output_video.mp4")

def save_video(video_tensor, filename): # Funzione per salvare il video su file # ...

if **name** == "**main**":
    generate_video()

Preparazione del Pipeline

La funzione initialize_pipeline carica il modello pre-addestrato e imposta il DPMSolverMultistepScheduler con i parametri specificati per garantire un'inferenza efficiente.

def initialize_pipeline(model, device="cuda"):
    pipeline = StableDiffusionXLVideoPipeline.from_pretrained(pretrained_model_name_or_path=model).to(device=device)
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config,
    timestep_spacing="trailing",
    beta_schedule="scaled_linear",
    beta_start=0.00082,
    beta_end=0.014,
    algorithm_type="sde-dpmsolver++"
)
    pipeline.vae.enable_slicing()
    return pipeline

Inferenza

La funzione inference esegue l'intero processo di generazione del video. Utilizza il pipeline per creare i frame, li normalizza e li concatena per formare un video coerente. Successivamente, salva il video generato in un file.

@torch.inference_mode()
def inference(
        pretrained_model_path,
        prompt,
        negative_prompt=None,
        width=768,
        height=768,
        init_image=None,
        use_init_image=False,
        num_frames=4,
        num_conditioning_frames=4,
        conditioning_strength=0.0,
        num_repeats=4,
        fps=6,
        num_inference_steps=50,
        guidance_scale=15,
        device="cuda",
        output_folder="videos",
        seed=None
    ):

    if seed is not None:
    set_seed(seed)
    os.makedirs(output_folder, exist_ok=True)
    with torch.autocast(device, dtype=torch.float16):
        pipeline = initialize_pipeline(pretrained_model_path, device)
        latent_image = None
        if init_image is not None and use_init_image:
            image = Image.open(init_image).convert('RGB')
            transform = transforms.Compose([
                transforms.Resize((height, width)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            image = transform(image).to(device=device).unsqueeze(0).unsqueeze(0).permute(0, 2, 1, 3, 4)
            latent_image = encode(image, pipeline.vae).repeat_interleave(repeats=num_frames, dim=2)

        out_file = f"{output_folder}/{prompt}.mp4"
        encoded_out_file = f"{output_folder}/{prompt}_encoded.mp4"
        previous_latents, stacked_latents = None, None

        for i in range(num_repeats + 1):
            with torch.no_grad():
                latents = pipeline(
                    prompt,
                    negative_prompt=negative_prompt,
                    device=device,
                    width=width,
                    height=height,
                    latent_image=latent_image,
                    conditioning_strength=conditioning_strength,
                    num_frames=num_frames,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    previous_latents=previous_latents
                )
            stacked_latents = torch.cat((stacked_latents, latents), dim=2) if i > 0 else latents
            stacked_latents = normalize_latents(stacked_latents)
            previous_latents = stacked_latents[:, :, -num_conditioning_frames:, :, :]

        stacked_latents = stacked_latents[:, :, num_frames:, :, :]
        tensor = decode(stacked_latents, pipeline.vae)
        save_video(tensor, out_file, fps)

        try:
            encode_video(out_file, encoded_out_file, get_video_height(out_file))
            os.remove(out_file)
        except:
            pass

Codifica e Decodifica dei Latents

Funzioni encode e decode per convertire i tensor in latents e viceversa utilizzando il VAE.

def encode(tensor, vae):
    tensor = tensor.float().permute(0, 2, 1, 3, 4).reshape(tensor.shape[0] _ tensor.shape[2], tensor.shape[1], tensor.shape[3], tensor.shape[4])
    latents = vae.encode(tensor).latent_dist.sample() _ vae.config.scaling_factor
    return latents.view(tensor.shape[0], -1, tensor.shape[1], tensor.shape[3], tensor.shape[4]).permute(0, 2, 1, 3, 4)

def decode(latents, vae):
    latents = (1 / vae.config.scaling_factor) _ latents.permute(0, 2, 1, 3, 4).reshape(latents.shape[0] _ latents.shape[2], latents.shape[1], latents.shape[3], latents.shape[4])
    image = vae.decode(latents).sample
    return image.view(latents.shape[0] // latents.shape[2], latents.shape[2], -1, image.shape[2], image.shape[3]).permute(0, 2, 1, 3, 4).float()

Salvataggio del Video

La funzione save_video normalizza e salva il video generato su file.

def save_video(normalized_tensor, output_path, fps=30):
    denormalized_frames = denormalize(normalized_tensor)
    height, width = denormalized_frames.shape[1:3]
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(\*'mp4v'), fps, (width, height))
    for frame in denormalized_frames:
    out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()

Normalizzazione e Denormalizzazione

Funzioni di utilità per normalizzare e denormalizzare i tensor.

def normalize_latents(latents):
    return (latents - latents.mean()) / (latents.std() + 1e-8)

def denormalize(normalized_tensor):
    denormalized = (normalized_tensor + 1.0) \* 127.5
    return torch.clamp(denormalized, 0, 255).to(torch.uint8).permute(1, 2, 3, 0).numpy()

Impostazione del Seed

Funzione per impostare il seed per la riproducibilità dei risultati.

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Conclusione

Il modello Stable Diffusion XL Video Pipeline rappresenta un significativo passo avanti nella generazione di video da testo, utilizzando tecniche avanzate come i trasformatori temporali e condizionali. Questo approccio permette di creare video di lunga durata e coerenti, aprendo nuove possibilità per la creatività e l'applicazione dell'intelligenza artificiale.

Per ulteriori dettagli e per esplorare il codice completo, puoi consultare il repository del progetto.