LaVin-DiT

LaVin-DiT: Large Vision Diffusion Transformer[1]

作者是来自悉尼大学、NUS等机构的Zhaoqing Wang等人。论文引用[1]:Wang, Zhaoqing et al. “LaVin-DiT: Large Vision Diffusion Transformer.” ArXiv abs/2411.11505 (2024): n. pag.

Time

  • 2025.Mar

Key Words

  • Diffusion Transformer
  • 将ST-VAE 和Diffusion transformer结合起来,有效地处理高维vision data,通过in-context learning, LaVin-ViT能够适应多个tasks不需要fine-tuning。

总结

  1. 本文提出了Large Vision Diffusion Transformer(LaVin-DiT),是一个可扩展的、统一的foundation model,以generative框架的形式,处理超过20种cv tasks,不同于现有的large vision models,它们直接从NLP的架构修改而来,很少依赖于高效的自回归方式,扰乱了对于vision data很重要的spatial relationships。LaVin-DiT引入了key innovations来优化generative performance for CV tasks,首先:为了解决visual data的高维问题,作者引入了一个spatial-temporal variational autoencoder,将data编码到一个连续的latent space;其次,对于generative modeling,作者开发了一个joint diffusion transformer,能够progressively 产生vision outputs;第三,对于统一的多任务训练,执行in-context learning,input-target pairs作为task context,指导diffusion transformer在latent space中,将outputs和specific tasks进行对齐,在推理的时候,一个task-specific context set和test data作为queries,使得LaVin-DiT不需要fine-tune就能泛化到多个tasks,在大量的vision datasets上训练,这个model从0.1B扩展到3.4B,展示出了scalability和SOTA的性能

  2. LLMs例如GPT和LLaMA收到了广泛的关注,展示出了在一个统一的框架中处理多个language tasks的能力,这个将多个language tasks集成到single large model的突破,引发了为CV开发类似的large model的时刻。开发能够处理多个vision tasks的LVMs表示一个通往vision-based AI的更加versatile, scalable和高效的路径。然而,构建LVMs比LLMs更加的复杂,因为内在的diverse和vision data的高维本质,还有处理多个不同scale,perspective的需求。为了处理这个问题,最近的工作开发了一种sequential modeling方法,通过representing images、videos、annotations in a unified visual sentences format,来学习purely vision data。这个方法使得model来预测sequential vision tokens from a vast dataset,独立于language-based input,尽管这个方法在多个vision tasks上展示出了promising results,它面临两个挑战:**第一个问题是自回归sequence modeling的效率问题,它需要token-by-token的预测,对于高维数据是computational intensive,第二个问题是将vision data转换成sequential format的时候,spatial coherence的discruption,在对于vision tasks很重要的spatial dependencies上妥协了。

  3. 在本文中,作者引入了一个large visino diffusion transformer,来推动下一代LVMs的发展,LaVin-DiT 有更好的计算效率,有效地保留了vision data中的spatial relationships,在多个vision tasks上,是心啊了良好的performance。为了处理高维的vision data,作者引入了spatial-temporal variational autoencoder将data编码到一个latent space,能够用compact representation保留temporal和spatial features,这降低了计算需求,不需要牺牲模型捕捉complex patterns的能力,来提高效率。另外,对于生成modeling,作者增强了现有的diffusion transformer,提出了一个联合的diffusion transformer with full-sequences joint attention,这个module通过parallel denoising steps,综合了visuual outputs,有效地降低了sequential dependencies,增强了处理效率,同时保持了spatial coherence。另外,为了支持统一的多任务训练,作者引入了一个in-context,input-target pairs指导diffusion transformer,将outputs和specific tasks进行对齐。在推理的时候,LaVin-DiT利用task-specific context sets和test data作为queries,来适应多个tasks,而不需要微调,这使得LaVin-DiT实现了在多个任务上的泛化性。

  4. 开发一个针对多个任务的框架是deep learning的一个和藏起的目标,NLP已经用ChatGPT实现那了,相比之下,CV缺乏统一的framework,很大程度上是因为visual data和tasks的复杂性和多样性,现有的统一的vision frameworks有两个路径:image-resembling generation和sequential modeling。 image-resembling generation方法是将visual tasks表述为一个image generation 问题,使得model横沟通过inpainting和reconstruction来处理dense visual predictions,例如:Painter将dense prediction tasks转换为一个masked image inpainting,在多个vision tasks上,展示出了in-context 能力,通过利用预训练的diffusion models,一些方法利用visual 或者textual instruction来指导generation和增强adaptability,这个sequential modeling方法主要是受LLMs突破的启发,将sequence-to-sequence 框架用于visual data,对于这些方法,visual data被quantized为离散的tokens的sequences,这个model通过next-token prediction进行优化,最近,有人引入了一个framework,将这个concept扩展到vision,不需要依赖于语言数据,将visual data视为visual sequences,通过将images和videos视为one-dimensional sequences,这个方法使得unified transformer能够处理image和video tasks within 一个单个framework,扩大了sequential modeling在CV中的scope。

    在本文中,从image-resembling generation的角度,作者提出了一个universal diffusion framework,用一个对visual data特别处理的transformer架构,保留了spatial-temporal structure,最小化了信息的丢失,在大量的visual data上的实验表明,作者的framework,统一了image和video tasks,推动了CV中的generalist model

  5. 借助于ViT,最近在generative modeling中的进展,在scalability和performace上都取得了很大的进展,U-ViT将所有的input视为tokens,通过将Transformer block和U-Net架构结合,展示出了scalability和diffusion transfomer的versatility,MDT和MaskDiT通过用一个masking策略,增强了DiT的训练效率,另外,Stable Diffusion 3引入了新的transformer-based架构,用于text-to-image generation,使得image和text的交互成为可能。另外,DiT在video generation中,展示出了robust的spatial-temporal modeling的能力,之前的方法利用分开的spatial和temporal attention机制降低intensive computational cost,另外,最近的工作提出用3D full attention来捕捉spatial-temporal 信息,确保large-moving objects的一致性,同时,diffusion transformer在visual content genration上展示出了潜力,它们作为一个large vision model,统一多个vision tasks的能力还没有被探索,本文中,作者引入了一个新的joint diffusion transformer,带有full-sequences joint attention,有效地集成了多个vision tasks,将diffusion transformer提高到了一个新的level。

  6. In-context learning最开始是由GPT-3概念化的,它革新了task-specific model training,通过允许model直接基于contextual examples进行推理和执行任务,这个范式变化使得model不需要在这些specific tasks上进行训练,就能执行complex reasoning和novel pattern recognition。在文本之外,Flamingo引入了visual inputs,将in-context learning扩展到multi-modal tasks例如image captioning, visual question answering,这展示了model的能力,能够理解textual和visual data,在多个不同的domains增强了它的应用。在CV中,这个in-context learning的概念通过visual prompting被explored,直接从concatenated image examples和queries中infer tasks,在本文中,作者基于这个idea,一组examples被sample,作为task definitions,然后和input query进行concantenated for the model,来得都啊predictions。

  7. CV包括一系列的任务,例如Object detection等,这些通常被specialized models进行处理,这个specialization限制了model的adaptability和scalability across multiple tasks,为了克服这个Limitation,作者设计了一个conditional generative framework,用单个cohesive model统一了多个vision tasks,特别地,给定query x,这个framework产生对应的prediction \(\hat{y}\),来近似在input-target pairs \(s\) 条件下的 y,这个conditional pairs提供了task definitions和guidance,使得model能够灵活地适应多个tasks,这个objective 就是建模 conditional distribution \(p(y|x, s)\)

  8. 如图所示,提出的Large Vision Diffusion Transformer(LaVin-DiT) 框架,集成了spatial-temporal variational autoencoder,,用一个joint diffusion transformer来统一多个vision tasks,给定一个vision task,例如全景分割,首先采样一组input-target pairs作为tasks definitions,之后,这个set和其它的visual examples给到ST-VAE,被编码成latent representations,之后,编码的representations进行patchify,展开成一个sequential format,这个set和input visual data构成了conditional latent presentation \(z_c\),target被随机的高斯噪声进行perturbe,产生了noisy latent representation \(z_t\)\(z_c\)\(z_t\) 给到joint diffusion transformer,对 \(z_t\) 进行去噪,恢复clean latent representation with shared latent space,最后,恢复的latent representation通过ST-VAE decoder在raw pixel space重建target

  9. ST-VAE能够高效地压缩spatial和temporal information,将它们从pixel space编码到compact latent space。ST-VAE用了causal 3D conv和deconv来压缩和recontruct visual data,整体上包括一个encoder,一个decoder和一个latent regularization layer,这些components被组织成4个对称的stages,交替进行downsampling和upsampling,前两个stages对spatial和temporal dimensions进行操作,同时last stage只影响spaital dimension,是心啊了有效的 \(4 \times 8 \times 8\) 的压缩,降低了计算负担,另外,作者应用了KL constraint来正则化Gaussian Latent space。为了防止信息泄露和它对temporal predictions的反面影响,在temporal conv space的start,对所有的locations进行pad,另外,为了支持image和video processing,将输入video的第一帧进行单独处理,只对它进行spatially compress,保持temporal independence。随后的frames在spatial和temporal dimensions进行压缩,ST-VAE的encoder将input压缩到一个低维的latent space,通过一个decoding实现reconstructino。用两阶段训练ST-VAE:首先单独训练images,然后联合images和videos,在每个stage,作者用mean squared error, perceptual loss和adversarial loss的结合来优化model。

  10. J-DiT:Diffusion transformers 是一个powerful 方法用于generative modeling,作者的diffusion transformer建立在DiT上,但是做了修正, 来支持task-conditioned generation,和原始DiT的不同是两个不同的latent representations,这个condition latent representation 是clean,同时target latent representation是被高斯噪声扰动了,导致潜在distinct value ranges for two。为了处理task-specific和visual information之间的不同和提高alignment,作者对于condition和target latents sperate的patch embeddings,每个embedding layer拥有给patch size \(2 \times 2\),使得对于每个latent type,对representations进行定制。这个采样的timestep t,还有condition和target sequences,给到diffusion transformer layers,建立在MM-DiT架构上,作者引入了condition- 和target-specific adaptive RMS normalization(AdaRN) 来独立地modulate每个representation space,这是通过在 AdaRN 层中对条件和目标使用不同的timestep embeddings来实现的。

  11. Full-sequence joint attention是 transformer layers的key,一起处理condition和noisy target sequences,来增强task-specific alignment,这个condition和target sequences进行线性projected,concatenated,然后被双向attention module进行处理,从而允许每个序列在各自的独立space中运作,同时又能相互参考。为了提高速度和memory的效率,作者用grouped-query attention代替MHA,该方法通过将查询头(query heads)进行分组,使每组共享一组键值头(key-value heads)。这个方法降低了参数同时保持了expressiveness,性能接近标准的多头注意力机制,另外,为了稳定训练larger models和longer sequences,在query-key dot products之前,加入了QK-Norm,来控制attention entropy growth,作者也应用了sandwich normalization,在每个attention和FFN layer之后,来保持activation magnitudes。

  12. 作者argue:将visual data建模为one-dimensional sequences是sub-optimal,因为1D positional embedding在捕捉精确的spatial-temporal positions上是有限的,相反,将多个image-annotation pairs或者video clips作为single continuous sequence,作者用3D Rotary Position Encoding(3D RoPE) 来表示spatial-temporal relationships,然后,视频中的每个location能够被一个3D 坐标所表达,有了3D RoPE,作者提供一个统一的、精确的spatial-temporal representation of positional encoding,用于多种vision tasks。

  13. J-DiT的训练过程:作者用flow matching训练J-DiT,特别地,给定一个representation \(z_0\) 和noise \(z_1 ~ N(0,1)\), flow matching基于前向过程定义一个线性插值,\(z_t = tz_0 + (1-t)z_1\),timestep \(t \in [0,1]\),这个前向过程引出了time-dependent velocity field \(v(z_t,t)\),沿着 \(z_0-z_1\) 的方向drive flow,这个velocity field定义了一个oridnary differential equation(ODE): \(dz_t = v(z_t,t_dt)\),作者采用了J-DiT,被 \(\theta\) 进行参数化,来预测velocity field,将noise转换成clean latent representation,这个flow matching的训练目标是直接回归velocity field,得到Conditional Flow Matching(CFM) loss。

  14. 在完成J-DiT training之后,我们利用它通过将噪声分布向representation分布进行整合,从而生成新的表征。特别地,从noise $z_1 ~ N(0,1), $ at t=1,作者将学习到的J-DiT backward to t=0,来得到表征 $z_0$,例如,用欧拉方法,作者将time interval[0,1] 离散成 N steps,一个negative steps size \(\deltat = -1/N\),表示backward integration in time,在每个step k = 0, 1/N,..., (N-1)/N,作者用下列的方法更新time和generated representation: \[ \begin{aligned} t^{(k+1/N)} &= t^{(k)} + \Delta t, \\ \mathbf{z}^{(k+1/N)} &= \mathbf{z}^{(k)} + v_\theta(\mathbf{z}^{(k)}, t^{(k)}) \Delta t, \end{aligned} \] \(t^{(0)}=1, t^{(1)}=0, z^{(0)}=z`, z^{(1)}=z`_0\),通过iteratively 应用这些updates,作者得到了一个新的presentation,用于following decoding process。

  15. LaVin-DiT:在完成LaVin-DiT的训练之后,这个model变得versatile,能够用于多个下游任务,特别地,当给定一个query(image或者video)for any chosen task,作者随机采样一组input-target pairs,定义这个task,这些pairs,和visual input,一个高斯噪声component一起,给到J-DiT,在J-DiT内,这些elements被处理,产生一个latent representation,最后,这个latent representation通过ST-VAE decoder,将其转换成raw pixel space,来产生desired prediction。

Overview \(Fig.1^{[1]}\). model开始将visual data从pixel space压缩到latent space,多个input-target pairs作为task context,一个target通过diffusion process被高斯噪声所扰动,受task context和query的guide,J-DiT通过N timesteps来iteratively denoise 则会个noisy target,来恢复一个clean latent representation,这个prediction然后通过ST-VAE decoder来生成。b和c提供了ST-VAE和J-DiT的细节