Il modello Stable Diffusion XL Video Pipeline è un'estensione del modello text-to-image, progettata per generare video infiniti utilizzando i frame precedenti per condizionare la diffusione dei frame successivi. Questo approccio consente di mantenere la coerenza temporale nei video generati.
Componenti Principali
StableDiffusionXLVideoPipeline: La classe principale che gestisce l'intero processo di generazione del video. Temporal Transformer: Utilizzato per catturare le dipendenze temporali tra i frame. Temporal Conditioner Transformer: Effettua il cross-attention con i frame passati per assicurare la coerenza temporale.
Architettura del Modello
StableDiffusionXLVideoPipeline
La classe StableDiffusionXLVideoPipeline gestisce l'intero processo di generazione, dall'encoding del prompt alla produzione dei frame del video. Ecco una panoramica delle principali funzioni e componenti della classe:
Inizializzazione del Modulo
La funzione init registra i moduli necessari, tra cui il VAE, i text encoder, i tokenizer, il modello UNet 3D e il scheduler.
class StableDiffusionXLVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
def **init**(self, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, unet, scheduler):
super().**init**()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler
)
self.vae_scale_factor = 2 \*\* (len(self.vae.config.block_out_channels) - 1)
self.default_sample_size = self.unet.config.sample_size
Encoding del Prompt
La funzione encode_prompt gestisce l'encoding del prompt di testo, inclusi i prompt negativi, se presenti. Utilizza i tokenizer e i text encoder per convertire il testo in embedding utilizzabili dal modello.
def encode_prompt(self, prompt, device=None, do_classifier_free_guidance=True, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None):
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
text_inputs = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
text_input_ids = text_inputs.input_ids.to(device)
prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True).hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
if do_classifier_free_guidance and negative_prompt is not None:
negative_prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
negative_inputs = tokenizer(negative_prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
negative_input_ids = negative_inputs.input_ids.to(device)
negative_prompt_embeds = text_encoder(negative_input_ids, output_hidden_states=True).hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.cat(negative_prompt_embeds_list, dim=-1)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
return prompt_embeds, negative_prompt_embeds
Preparazione dei Latents
La funzione prepare_latents prepara i latents per il processo di diffusione, gestendo la generazione di rumore iniziale se necessario.
def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
latents = torch.randn(shape, device=device, dtype=dtype)
else:
latents = latents.to(device)
latents = latents \* self.scheduler.init_noise_sigma
return latents
Funzione di Chiamata Principale
La funzione call gestisce il processo di generazione del video, includendo la codifica del prompt, la preparazione dei latents e il loop di diffusione.
@torch.no_grad()
def **call**(self, prompt, negative_prompt=None, height=None, width=None, latent_image=None, device=None, num_frames=16, num_inference_steps=50, guidance_scale=9.0, conditioning_strength=0.0, eta=0.0, generator=None, latents=None, previous_latents=None):
do_classifier_free_guidance = guidance_scale > 1.0
prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, device, do_classifier_free_guidance, negative_prompt)
if latents is None:
latents = self.prepare_latents(
batch_size=1,
num_channels_latents=self.unet.config.in_channels,
num_frames=num_frames,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator
)
for t in self.scheduler.timesteps:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
conditioning_hidden_states=previous_latents
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
return latents
Temporal Transformer
Il Temporal Transformer cattura le dipendenze temporali tra i frame, permettendo al modello di mantenere la coerenza temporale nel video generato. Utilizza embeddings sinusoidali per codificare la posizione temporale dei frame.
class TemporalTransformer(ModelMixin, ConfigMixin):
@register*to_config
def **init**(self, num_attention_heads=16, attention_head_dim=88, in_channels=None, out_channels=None, num_layers=1, dropout=0.0):
super().**init**()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads \* attention_head_dim
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.positional_encoding = SinusoidalEmbeddings()
self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, dropout=dropout) for * in range(num_layers)])
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, return_dict=True):
if hidden_states.size(2) <= 1:
if not return_dict:
return (hidden_states,)
return TemporalTransformerOutput(sample=hidden_states)
batch, channel, frames, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch * frames, channel, height, width)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states + self.positional_encoding(hidden_states)
for block in self.transformer_blocks:
hidden_states = block(hidden_states=hidden_states)
hidden_states = self.proj_out(hidden_states).view(batch, frames, channel, height, width).permute(0, 2, 1, 3, 4)
hidden_states += residual
return TemporalTransformerOutput(sample=hidden_states)
Temporal Conditioner Transformer
Il Temporal Conditioner Transformer utilizza il cross-attention con i frame passati per assicurare che ogni nuovo frame sia coerente con i precedenti, mantenendo una transizione fluida e naturale nel video. Questo modulo è cruciale per garantire la coerenza temporale nei video di lunga durata.
class TemporalConditionerTransformer(ModelMixin, ConfigMixin):
@register*to_config
def **init**(self, num_attention_heads=16, attention_head_dim=88, in_channels=None, cross_attention_dim=2560, num_layers=1, only_cross_attention=True, dropout=0.0):
super().**init**()
self.conv_in = nn.Conv3d(4, cross_attention_dim, kernel_size=(1, 1, 1))
self.ln = nn.LayerNorm(in_channels)
self.proj_in = nn.Linear(in_channels, cross_attention_dim)
self.positional_encoding = ExponentialEmbeddings()
self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(cross_attention_dim, num_attention_heads, attention_head_dim, dropout=dropout) for * in range(num_layers)])
self.proj_out = nn.Linear(cross_attention_dim, in_channels)
def forward(self, hidden_states, encoder_hidden_states, return_dict=True):
h_b, h_c, h_f, h_h, h_w = hidden_states.shape
residual = hidden_states
encoder_hidden_states = torch.nn.functional.interpolate(encoder_hidden_states, size=(encoder_hidden_states.shape[2], h_h, h_w), mode='trilinear', align_corners=False).repeat_interleave(repeats=h_b, dim=0)
encoder_hidden_states = self.conv_in(encoder_hidden_states).permute(0, 2, 1, 3, 4).reshape(h_b * h_h * h_w, h_f, encoder_hidden_states.shape[1])
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(h_b * h_h * h_w, h_f, h_c)
hidden_states = self.ln(hidden_states)
hidden_states = self.proj_in(hidden_states)
encoder_hidden_states += self.positional_encoding(encoder_hidden_states)
hidden_states += self.positional_encoding(hidden_states, reverse=True)
for block in self.transformer_blocks:
hidden_states = block(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = self.proj_out(hidden_states).view(h_b, h_h, h_w, h_f, h_c).permute(0, 4, 3, 1, 2)
hidden_states += residual
return TemporalConditionerTransformerOutput(sample=hidden_states)
Esempio di Utilizzo
Ecco un esempio di come utilizzare il modello Stable Diffusion XL Video Pipeline per generare un video a partire da un prompt testuale.
import torch
def generate_video():
model_path = "path/to/pretrained/model"
pipeline = StableDiffusionXLVideoPipeline.from_pretrained(pretrained_model_name_or_path=model_path).to("cuda")
prompt = "Un bellissimo tramonto sul mare"
video = pipeline(prompt=prompt, num_frames=24, num_inference_steps=50, guidance_scale=9.0)
save_video(video, "output_video.mp4")
def save_video(video_tensor, filename): # Funzione per salvare il video su file # ...
if **name** == "**main**":
generate_video()
Preparazione del Pipeline
La funzione initialize_pipeline carica il modello pre-addestrato e imposta il DPMSolverMultistepScheduler con i parametri specificati per garantire un'inferenza efficiente.
def initialize_pipeline(model, device="cuda"):
pipeline = StableDiffusionXLVideoPipeline.from_pretrained(pretrained_model_name_or_path=model).to(device=device)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config,
timestep_spacing="trailing",
beta_schedule="scaled_linear",
beta_start=0.00082,
beta_end=0.014,
algorithm_type="sde-dpmsolver++"
)
pipeline.vae.enable_slicing()
return pipeline
Inferenza
La funzione inference esegue l'intero processo di generazione del video. Utilizza il pipeline per creare i frame, li normalizza e li concatena per formare un video coerente. Successivamente, salva il video generato in un file.
@torch.inference_mode()
def inference(
pretrained_model_path,
prompt,
negative_prompt=None,
width=768,
height=768,
init_image=None,
use_init_image=False,
num_frames=4,
num_conditioning_frames=4,
conditioning_strength=0.0,
num_repeats=4,
fps=6,
num_inference_steps=50,
guidance_scale=15,
device="cuda",
output_folder="videos",
seed=None
):
if seed is not None:
set_seed(seed)
os.makedirs(output_folder, exist_ok=True)
with torch.autocast(device, dtype=torch.float16):
pipeline = initialize_pipeline(pretrained_model_path, device)
latent_image = None
if init_image is not None and use_init_image:
image = Image.open(init_image).convert('RGB')
transform = transforms.Compose([
transforms.Resize((height, width)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = transform(image).to(device=device).unsqueeze(0).unsqueeze(0).permute(0, 2, 1, 3, 4)
latent_image = encode(image, pipeline.vae).repeat_interleave(repeats=num_frames, dim=2)
out_file = f"{output_folder}/{prompt}.mp4"
encoded_out_file = f"{output_folder}/{prompt}_encoded.mp4"
previous_latents, stacked_latents = None, None
for i in range(num_repeats + 1):
with torch.no_grad():
latents = pipeline(
prompt,
negative_prompt=negative_prompt,
device=device,
width=width,
height=height,
latent_image=latent_image,
conditioning_strength=conditioning_strength,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
previous_latents=previous_latents
)
stacked_latents = torch.cat((stacked_latents, latents), dim=2) if i > 0 else latents
stacked_latents = normalize_latents(stacked_latents)
previous_latents = stacked_latents[:, :, -num_conditioning_frames:, :, :]
stacked_latents = stacked_latents[:, :, num_frames:, :, :]
tensor = decode(stacked_latents, pipeline.vae)
save_video(tensor, out_file, fps)
try:
encode_video(out_file, encoded_out_file, get_video_height(out_file))
os.remove(out_file)
except:
pass
Codifica e Decodifica dei Latents
Funzioni encode e decode per convertire i tensor in latents e viceversa utilizzando il VAE.
def encode(tensor, vae):
tensor = tensor.float().permute(0, 2, 1, 3, 4).reshape(tensor.shape[0] _ tensor.shape[2], tensor.shape[1], tensor.shape[3], tensor.shape[4])
latents = vae.encode(tensor).latent_dist.sample() _ vae.config.scaling_factor
return latents.view(tensor.shape[0], -1, tensor.shape[1], tensor.shape[3], tensor.shape[4]).permute(0, 2, 1, 3, 4)
def decode(latents, vae):
latents = (1 / vae.config.scaling_factor) _ latents.permute(0, 2, 1, 3, 4).reshape(latents.shape[0] _ latents.shape[2], latents.shape[1], latents.shape[3], latents.shape[4])
image = vae.decode(latents).sample
return image.view(latents.shape[0] // latents.shape[2], latents.shape[2], -1, image.shape[2], image.shape[3]).permute(0, 2, 1, 3, 4).float()
Salvataggio del Video
La funzione save_video normalizza e salva il video generato su file.
def save_video(normalized_tensor, output_path, fps=30):
denormalized_frames = denormalize(normalized_tensor)
height, width = denormalized_frames.shape[1:3]
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(\*'mp4v'), fps, (width, height))
for frame in denormalized_frames:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
Normalizzazione e Denormalizzazione
Funzioni di utilità per normalizzare e denormalizzare i tensor.
def normalize_latents(latents):
return (latents - latents.mean()) / (latents.std() + 1e-8)
def denormalize(normalized_tensor):
denormalized = (normalized_tensor + 1.0) \* 127.5
return torch.clamp(denormalized, 0, 255).to(torch.uint8).permute(1, 2, 3, 0).numpy()
Impostazione del Seed
Funzione per impostare il seed per la riproducibilità dei risultati.
def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Conclusione
Il modello Stable Diffusion XL Video Pipeline rappresenta un significativo passo avanti nella generazione di video da testo, utilizzando tecniche avanzate come i trasformatori temporali e condizionali. Questo approccio permette di creare video di lunga durata e coerenti, aprendo nuove possibilità per la creatività e l'applicazione dell'intelligenza artificiale.
Per ulteriori dettagli e per esplorare il codice completo, puoi consultare il repository del progetto.