
HoLeeFaak t1_irvnlrf wrote

That's a pretty hard problem, because text generation involves argmax or sampling, neither of which is differentiable. That makes it hard to optimize a model whose generated text is then fed into a text2img model to reproduce a given image. I guess you could try something like swapping CLIP for Stable Diffusion and changing the objective a bit, but I think it will be hard to optimize.
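To make the non-differentiability point concrete, here is a minimal sketch in plain Python. The logits are made-up numbers, and the finite-difference helper is only an illustration, not how a real framework computes gradients. It compares how a one-hot argmax and a softmax respond to a tiny change in the logits.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax_one_hot(logits):
    # Hard token choice: a one-hot vector for the highest logit.
    best = max(range(len(logits)), key=lambda i: logits[i])
    return [1.0 if i == best else 0.0 for i in range(len(logits))]

def finite_difference(fn, logits, index, eps=1e-4):
    """Central-difference estimate of d fn(logits)[index] / d logits[index]."""
    up = list(logits)
    up[index] += eps
    down = list(logits)
    down[index] -= eps
    return (fn(up)[index] - fn(down)[index]) / (2 * eps)

logits = [2.0, 1.0, 0.5]  # made-up scores for a 3-token vocabulary
grad_argmax = finite_difference(argmax_one_hot, logits, 0)
grad_softmax = finite_difference(softmax, logits, 0)
print("argmax sensitivity:", grad_argmax)
print("softmax sensitivity:", grad_softmax)
```

The argmax output does not move at all under a small nudge, so there is no gradient signal to pass back to whatever produced the logits, while the softmax output does move.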


MohamedRashad OP t1_irvolp8 wrote

I thought about self-supervision for this task: feed the image whose prompt I want into an image-to-text model, then feed the resulting text to a diffusion model (DALL-E, Stable Diffusion) whose weights are frozen so they don't change.

The output image is then compared to the original image, and the loss is backpropagated to the image-to-text model so it can learn. In my humble opinion, this approach has two problems:

  1. Training such a system won't be easy, and I will need a lot of resources I currently don't have.
  2. Even if I succeed, the resulting model won't generalize well enough.

This is, of course, assuming I manage to overcome the non-differentiable parts.
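A toy version of this loop, as a rough sketch only: everything here is a hypothetical stand-in. Images are two-element lists, the "captioner" is a tiny linear scorer over a 3-token vocabulary, and the frozen "text-to-image" model is a fixed lookup table. It shows the structure of the loss and where the discrete token choice sits.

```python
# Frozen stand-in for the text-to-image model: token id -> "image".
FROZEN_GENERATOR = {0: [0.0, 0.0], 1: [1.0, 0.0], 2: [1.0, 1.0]}

def captioner_logits(image, weights):
    # One logit per token: dot product of the image with that token's weight row.
    return [sum(w * x for w, x in zip(row, image)) for row in weights]

def cycle_loss(image, weights):
    logits = captioner_logits(image, weights)
    token = max(range(len(logits)), key=lambda i: logits[i])  # discrete argmax step
    reconstruction = FROZEN_GENERATOR[token]                  # frozen, never updated
    # Mean squared error between the original and the regenerated image.
    return sum((a - b) ** 2 for a, b in zip(image, reconstruction)) / len(image)

image = [0.9, 0.8]
weights = [[-1.0, -1.0], [1.0, -1.0], [0.5, 0.5]]  # hypothetical captioner parameters

loss = cycle_loss(image, weights)
# Nudge every captioner weight slightly: the chosen token stays the same,
# so the loss does not change and gives no direction to improve in.
nudged = [[w + 1e-3 for w in row] for row in weights]
loss_after_nudge = cycle_loss(image, nudged)
print(loss, loss_after_nudge)
```

In this toy setup the loss is piecewise constant in the captioner weights, which is the obstacle: the image comparison is easy to compute, but it carries no gradient back through the token choice.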


HoLeeFaak t1_irvoxe5 wrote

What you propose is a cycle loss. It's valid, but the biggest problem is the non-differentiable parts, and I haven't found a solution to that.


samb-t t1_irvsicm wrote

If you have enough resources to train an autoregressive model, then you could take advantage of knowing that these big text-to-image models are conditioned on CLIP embeddings and instead train an autoregressive model to predict prompts conditioned on CLIP image embeddings. That way there are no non-differentiable parts to bypass, and the CLIP embeddings should be a pretty great descriptor of both the input image and the prompt.
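A minimal sketch of why this route avoids the problem, assuming you have (image embedding, prompt) pairs to train on. The decoder is a hypothetical toy (a linear map of the embedding plus a bias indexed by the previous token), not a real architecture, and the embedding is a made-up two-element vector rather than an actual CLIP output. The training loss is the log-likelihood of the known prompt under teacher forcing, so nothing is sampled during training.

```python
import math

def log_softmax(logits):
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def next_token_logits(image_embedding, prev_token, w_img, b_prev):
    # Toy decoder: logits[v] = w_img[v] . image_embedding + b_prev[prev_token][v]
    return [
        sum(w * e for w, e in zip(w_img[v], image_embedding)) + b_prev[prev_token][v]
        for v in range(len(w_img))
    ]

def teacher_forced_nll(image_embedding, prompt_tokens, w_img, b_prev, bos=0):
    """Mean negative log-likelihood of the known prompt, one step per token."""
    total, prev = 0.0, bos
    for target in prompt_tokens:
        log_probs = log_softmax(next_token_logits(image_embedding, prev, w_img, b_prev))
        total -= log_probs[target]
        prev = target  # teacher forcing: feed the ground-truth token, not a sample
    return total / len(prompt_tokens)

image_embedding = [1.0, 0.0]            # made-up stand-in for a CLIP image embedding
prompt_tokens = [1, 2]                  # the known prompt as token ids (0 is BOS)
w_img = [[0.0, 0.0] for _ in range(3)]
flat_bias = [[0.0] * 3 for _ in range(3)]
# Hand-set bias that favours 0 -> 1 -> 2, mimicking what training would move towards.
better_bias = [[0.0, 2.0, 0.0], [0.0, 0.0, 2.0], [0.0, 0.0, 0.0]]

untrained_nll = teacher_forced_nll(image_embedding, prompt_tokens, w_img, flat_bias)
better_nll = teacher_forced_nll(image_embedding, prompt_tokens, w_img, better_bias)
print(untrained_nll, better_nll)
```

Every operation in the loss is a smooth function of the parameters, so an autodiff framework could backpropagate through it directly. Sampling only happens at inference time, when prompts are generated for new images.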

If you don't have enough resources then (just thinking out loud; there's probably a better way, but this might give some ideas) you could again use a pretrained CLIP model:

  1. Embed the input image.
  2. Using the CLIP text embedding network, optimise the input text to get an embedding close to the image embedding.

The problem there is again that text is discrete, so you can't backprop. You could use Gumbel-softmax to approximate the discrete text values, though (annealing down how continuous it is). Alternatively, you could treat the embedding distance loss as an energy function and use discrete MCMC, something like Gibbs-with-gradients. But both of those options still probably aren't great; it's a horrible optimisation space.
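A rough sketch of the Gumbel-softmax idea for a single token position, in plain Python with a hand-written gradient. The "text encoder" is a hypothetical stand-in (a weighted mix of made-up token embeddings) rather than CLIP, and the learning rate and temperature schedule are arbitrary choices.

```python
import math
import random

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def sample_gumbel(rng):
    u = max(rng.random(), 1e-12)  # guard against log(0)
    return -math.log(-math.log(u))

def optimise_token(token_embeddings, target, steps=500, lr=0.2,
                   tau_start=2.0, tau_end=0.5, seed=0):
    rng = random.Random(seed)
    vocab, dim = len(token_embeddings), len(target)
    logits = [0.0] * vocab
    for step in range(steps):
        # Anneal the temperature: soft mixtures early, closer to one-hot later.
        tau = tau_start * (tau_end / tau_start) ** (step / (steps - 1))
        y = softmax([(l + sample_gumbel(rng)) / tau for l in logits])  # relaxed one-hot
        # Stand-in encoder: embedding of the relaxed token is a mix of embeddings.
        emb = [sum(y[i] * token_embeddings[i][d] for i in range(vocab)) for d in range(dim)]
        diff = [emb[d] - target[d] for d in range(dim)]
        # d loss / d y_i for loss = ||emb - target||^2
        g = [2.0 * sum(diff[d] * token_embeddings[i][d] for d in range(dim))
             for i in range(vocab)]
        g_mean = sum(y[i] * g[i] for i in range(vocab))
        # Chain rule through the softmax and the 1/tau scaling.
        grad = [y[i] * (g[i] - g_mean) / tau for i in range(vocab)]
        logits = [l - lr * gr for l, gr in zip(logits, grad)]
    return logits

token_embeddings = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]  # made-up vocabulary
target = [0.0, 1.0]  # stand-in for the image embedding; equals token 1's embedding here
logits = optimise_token(token_embeddings, target)
best_token = max(range(len(logits)), key=lambda i: logits[i])
print(best_token, logits)
```

This toy is far kinder than the real setting: one position, four tokens, and a target that exactly matches one token's embedding. With a full vocabulary and a multi-token prompt passed through an actual text encoder, the same recipe faces the difficult optimisation space described above.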