Diffstat (limited to 'gen')
-rwxr-xr-x  gen  133
1 file changed, 133 insertions, 0 deletions
@@ -0,0 +1,133 @@
#!/usr/bin/env python3

from sys import argv
import os

import torch
import numpy
import imageio
from CLIP import clip

from lib import (
    load_vqgan_model,
    MakeCutouts,
    parse_prompt,
    Prompt,
    vector_quantize,
    clamp_with_grad,
)
from torch.nn import functional
from torch import optim
from torchvision import transforms

# The prompt is everything passed on the command line; "|" separates
# multiple prompts.
magic = " ".join(argv[1:])
width = 300
height = 300
seed = None  # None means pick a random seed below
cycles = 100

magic = [frase.strip() for frase in magic.split("|")]
if magic == [""]:
    magic = []

noise_prompt_seeds = []
noise_prompt_weights = []
size = [width, height]
init_weight = 0.0  # weight of the MSE pull toward the initial latent
clip_model = "ViT-B/32"
vqgan_config = "vqgan_imagenet_f16_16384.yaml"
vqgan_checkpoint = "vqgan_imagenet_f16_16384.ckpt"
step_size = 0.1
cutn = 64
cut_pow = 1.0

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if seed is None:
    seed = torch.seed()
torch.manual_seed(seed)
print("Using seed:", seed)

model = load_vqgan_model(vqgan_config, vqgan_checkpoint).to(device)
perceptor = clip.load(clip_model, jit=False)[0].eval().requires_grad_(False).to(device)

cut_size = perceptor.visual.input_resolution
e_dim = model.quantize.e_dim
f = 2 ** (model.decoder.num_resolutions - 1)  # VQGAN downsampling factor
make_cutouts = MakeCutouts(cut_size, cutn, cut_pow=cut_pow)
n_toks = model.quantize.n_e
toksX, toksY = size[0] // f, size[1] // f
sideX, sideY = toksX * f, toksY * f  # pixel dimensions actually generated
z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]

# Start from a random latent: one random codebook entry per token position.
one_hot = functional.one_hot(
    torch.randint(n_toks, [toksY * toksX], device=device), n_toks
).float()
z = one_hot @ model.quantize.embedding.weight
z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)

z_orig = z.clone()
z.requires_grad_(True)
opt = optim.Adam([z], lr=step_size)

# CLIP's expected input normalization.
normalize = transforms.Normalize(
    mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
)

pMs = []

for prompt in magic:
    txt, weight, stop = parse_prompt(prompt)
    embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
    pMs.append(Prompt(embed, weight, stop).to(device))

for noise_seed, weight in zip(noise_prompt_seeds, noise_prompt_weights):
    gen = torch.Generator().manual_seed(noise_seed)
    embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
    pMs.append(Prompt(embed, weight).to(device))


def synth(z):
    """Decode the latent z into an image in [0, 1]."""
    z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(
        3, 1
    )
    return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)


def ascend_txt(i):
    """Compute the per-prompt losses and save the current image to steps/."""
    out = synth(z)
    iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

    result = []

    if init_weight:
        result.append(functional.mse_loss(z, z_orig) * init_weight / 2)

    for prompt in pMs:
        result.append(prompt(iii))

    img = out.mul(255).clamp(0, 255)[0].detach().cpu().numpy().astype(numpy.uint8)
    img = numpy.transpose(img, (1, 2, 0))  # CHW -> HWC
    imageio.imwrite(f"steps/{i:04}.png", img)
    return result


def train(i):
    opt.zero_grad()
    lossAll = ascend_txt(i)
    loss = sum(lossAll)
    loss.backward()
    opt.step()
    with torch.no_grad():
        # Keep each latent channel within the range of the codebook.
        z.copy_(z.maximum(z_min).minimum(z_max))


os.makedirs("steps", exist_ok=True)  # imageio will not create the directory
for i in range(cycles + 1):
    train(i)
    print(f"{i} out of {cycles}")
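
Usage note (inferred from the argument handling above, not stated elsewhere on this page): the script joins all command-line arguments into a single string and splits it on "|", so several prompts can be passed in one quoted argument. Since the file is committed with execute permission (-rwxr-xr-x), an invocation might look like:

    ./gen "a fishing village at dawn | watercolor"

The prompt text here is only an illustration. The VQGAN config/checkpoint pair named in the script (vqgan_imagenet_f16_16384.yaml/.ckpt) and the CLIP and lib modules must be resolvable from the working directory; output frames are written to steps/, one PNG per optimization step.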