Diffstat (limited to 'gen')
-rwxr-xr-x  gen  133
1 files changed, 133 insertions, 0 deletions
diff --git a/gen b/gen
new file mode 100755
index 0000000..2f9d7ab
--- /dev/null
+++ b/gen
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+
+import os
+from sys import argv
+
+import torch
+import numpy
+import imageio
+from CLIP import clip
+
+from lib import (
+ load_vqgan_model,
+ MakeCutouts,
+ parse_prompt,
+ Prompt,
+ vector_quantize,
+ clamp_with_grad,
+)
+from torch.nn import functional
+from torch import optim
+from torchvision import transforms
+
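+# Prompts come from the command line; "|" separates multiple prompts. The rest are run settings.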
+magic = " ".join(argv[1:])
+width = 300
+height = 300
+seed = None
+cycles = 100
+
+magic = [frase.strip() for frase in magic.split("|")]
+if magic == [""]:
+ magic = []
+
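+# Remaining settings: optional noise prompts, VQGAN/CLIP model files, optimizer step size, and cutout parameters.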
+noise_prompt_seeds = []
+noise_prompt_weights = []
+size = [width, height]
+init_weight = 0.0
+clip_model = "ViT-B/32"
+vqgan_config = f"vqgan_imagenet_f16_16384.yaml"
+vqgan_checkpoint = f"vqgan_imagenet_f16_16384.ckpt"
+step_size = 0.1
+cutn = 64
+cut_pow = 1.0
+
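+# Select a device and seed PyTorch so a run can be reproduced.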
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print("Using device:", device)
+if seed is None:
+    seed = torch.seed()
+torch.manual_seed(seed)
+print("Using seed:", seed)
+
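+# Load the VQGAN model and the CLIP perceptor; only the latent z is optimized, so CLIP is frozen.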
+model = load_vqgan_model(vqgan_config, vqgan_checkpoint).to(device)
+perceptor = clip.load(clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
+
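+# Sizes derived from the models: CLIP's input resolution, the codebook dimension e_dim,
+# the decoder's downsampling factor f, the token grid (toksX x toksY), and the clamp range for z.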
+cut_size = perceptor.visual.input_resolution
+e_dim = model.quantize.e_dim
+f = 2 ** (model.decoder.num_resolutions - 1)
+make_cutouts = MakeCutouts(cut_size, cutn, cut_pow=cut_pow)
+n_toks = model.quantize.n_e
+toksX, toksY = size[0] // f, size[1] // f
+sideX, sideY = toksX * f, toksY * f
+z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
+z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
+
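+# Initialize z by picking a random codebook entry for every cell of the token grid.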
+one_hot = functional.one_hot(
+ torch.randint(n_toks, [toksY * toksX], device=device), n_toks
+).float()
+z = one_hot @ model.quantize.embedding.weight
+z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
+
+z_orig = z.clone()
+z.requires_grad_(True)
+opt = optim.Adam([z], lr=step_size)
+
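+# Normalization applied to cutouts before they are fed to CLIP (CLIP's preprocessing mean/std).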
+normalize = transforms.Normalize(
+ mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
+)
+
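+# Build Prompt modules: one per text prompt, plus any seeded noise prompts.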
+pMs = []
+
+for prompt in magic:
+ txt, weight, stop = parse_prompt(prompt)
+ embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
+ pMs.append(Prompt(embed, weight, stop).to(device))
+
+for seed, weight in zip(noise_prompt_seeds, noise_prompt_weights):
+ gen = torch.Generator().manual_seed(seed)
+ embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
+ pMs.append(Prompt(embed, weight).to(device))
+
+
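+# Quantize z against the codebook, decode it with VQGAN, and map the output from [-1, 1] to [0, 1].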
+def synth(z):
+ z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(
+ 3, 1
+ )
+ return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)
+
+
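+# Render the current image, score CLIP embeddings of its cutouts against every prompt,
+# write the frame to steps/, and return the list of losses.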
+def ascend_txt():
+ global i
+ out = synth(z)
+ iii = perceptor.encode_image(normalize(make_cutouts(out))).float()
+
+ result = []
+
+ if init_weight:
+ result.append(functional.mse_loss(z, z_orig) * init_weight / 2)
+
+ for prompt in pMs:
+ result.append(prompt(iii))
+    # Convert the first image in the batch from CHW float in [0, 1] to HWC uint8 and save it.
+    img = out.mul(255).clamp(0, 255)[0].detach().cpu().numpy().astype(numpy.uint8)
+    img = numpy.transpose(img, (1, 2, 0))
+    filename = f"steps/{i:04}.png"
+    imageio.imwrite(filename, img)
+ return result
+
+
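+# One optimization step: sum the losses, backpropagate into z, and clamp z to the range
+# spanned by the codebook embeddings.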
+def train(i):
+ opt.zero_grad()
+ lossAll = ascend_txt()
+ loss = sum(lossAll)
+ loss.backward()
+ opt.step()
+ with torch.no_grad():
+ z.copy_(z.maximum(z_min).minimum(z_max))
+
+
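+# Main loop: optimize for the configured number of cycles and report progress.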
+os.makedirs("steps", exist_ok=True)  # output frames go to steps/; create it if missing
+
+i = 0
+while True:
+    train(i)
+    if i == cycles:
+        break
+    i += 1
+    print(f"{i} out of {cycles}")