diff options
author | Tom Barrett <tom@tombarrett.xyz> | 2021-07-05 17:08:51 +0200 |
---|---|---|
committer | Tom Barrett <tom@tombarrett.xyz> | 2021-07-05 17:08:51 +0200 |
commit | e729c74265e7f8a9db85b85579ccac3f9eab8bed (patch) | |
tree | bc036f4af2816645b8001f81fe5b910e00f34d80 |
starting out
-rw-r--r-- | .gitignore | 6 | ||||
-rwxr-xr-x | bootstrap | 21 | ||||
-rwxr-xr-x | gen | 133 | ||||
-rw-r--r-- | lib.py | 179 |
4 files changed, 339 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4e5e677 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +CLIP/ +__pycache__/ +taming-transformers/ +taming/ +vqgan_imagenet_f16_16384.ckpt +vqgan_imagenet_f16_16384.yaml diff --git a/bootstrap b/bootstrap new file mode 100755 index 0000000..2121798 --- /dev/null +++ b/bootstrap @@ -0,0 +1,21 @@ +#!/bin/bash +git clone https://github.com/openai/CLIP +git clone https://github.com/CompVis/taming-transformers + +pip install ftfy regex tqdm omegaconf pytorch-lightning +pip install kornia +pip install einops +pip install torchvision +pip install imageio + +pip install stegano +pip install python-xmp-toolkit +pip install imgtag +pip install pillow +pip install imageio-ffmpeg + +#apt install exempi +pacman -Sy emempi + +curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml' +curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt' @@ -0,0 +1,133 @@ +#!/bin/python + +from sys import argv +import torch +import numpy +import imageio +from CLIP import clip + +from lib import ( + load_vqgan_model, + MakeCutouts, + parse_prompt, + Prompt, + vector_quantize, + clamp_with_grad, +) +from torch.nn import functional +from torch import optim +from torchvision import transforms + +magic = " ".join(argv[1:]) +width = 300 +height = 300 +seed = None +cycles = 100 + +magic = [frase.strip() for frase in magic.split("|")] +if magic == [""]: + magic = [] + +noise_prompt_seeds = [] +noise_prompt_weights = [] +size = [width, height] +init_weight = 0.0 +clip_model = "ViT-B/32" +vqgan_config = f"vqgan_imagenet_f16_16384.yaml" +vqgan_checkpoint = f"vqgan_imagenet_f16_16384.ckpt" +step_size = 0.1 +cutn = 64 +cut_pow = 1.0 +seed = seed + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print("Using device:", device) +seed = torch.seed() +torch.manual_seed(seed) +print("Using seed:", seed) + +model = load_vqgan_model(vqgan_config, vqgan_checkpoint).to(device) +perceptor = clip.load(clip_model, jit=False)[0].eval().requires_grad_(False).to(device) + +cut_size = perceptor.visual.input_resolution +e_dim = model.quantize.e_dim +f = 2 ** (model.decoder.num_resolutions - 1) +make_cutouts = MakeCutouts(cut_size, cutn, cut_pow=cut_pow) +n_toks = model.quantize.n_e +toksX, toksY = size[0] // f, size[1] // f +sideX, sideY = toksX * f, toksY * f +z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None] +z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None] + +one_hot = functional.one_hot( + torch.randint(n_toks, [toksY * toksX], device=device), n_toks +).float() +z = one_hot @ model.quantize.embedding.weight +z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2) + +z_orig = z.clone() +z.requires_grad_(True) +opt = optim.Adam([z], lr=step_size) + +normalize = transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711] +) + +pMs = [] + +for prompt in magic: + txt, weight, stop = parse_prompt(prompt) + embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float() + pMs.append(Prompt(embed, weight, stop).to(device)) + +for seed, weight in zip(noise_prompt_seeds, noise_prompt_weights): + gen = torch.Generator().manual_seed(seed) + embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen) + pMs.append(Prompt(embed, weight).to(device)) + + +def synth(z): + z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim( + 3, 1 + ) + return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1) + + +def ascend_txt(): + global i + out = synth(z) + iii = perceptor.encode_image(normalize(make_cutouts(out))).float() + + result = [] + + if init_weight: + result.append(functional.mse_loss(z, z_orig) * init_weight / 2) + + for prompt in pMs: + result.append(prompt(iii)) + img = numpy.array( + out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(numpy.uint8) + )[:, :, :] + img = numpy.transpose(img, (1, 2, 0)) + filename = f"steps/{i:04}.png" + imageio.imwrite(filename, numpy.array(img)) + return result + + +def train(i): + opt.zero_grad() + lossAll = ascend_txt() + loss = sum(lossAll) + loss.backward() + opt.step() + with torch.no_grad(): + z.copy_(z.maximum(z_min).minimum(z_max)) + + +i = 0 +while True: + train(i) + if i == cycles: + break + i += 1 + print(str(i) + " out of " + str(cycles)) @@ -0,0 +1,179 @@ +import math +from sys import path + +path.append("./taming-transformers") +from taming.models import cond_transformer, vqgan + +import torch +from torch import nn +from torch.nn import functional + +from omegaconf import OmegaConf +import kornia.augmentation as K + + +def sinc(x): + return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) + + +def lanczos(x, a): + cond = torch.logical_and(-a < x, x < a) + out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([])) + return out / out.sum() + + +def ramp(ratio, width): + n = math.ceil(width / ratio + 1) + out = torch.empty([n]) + cur = 0 + for i in range(out.shape[0]): + out[i] = cur + cur += ratio + return torch.cat([-out[1:].flip([0]), out])[1:-1] + + +def resample(input, size, align_corners=True): + n, c, h, w = input.shape + dh, dw = size + + input = input.view([n * c, 1, h, w]) + + if dh < h: + kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) + pad_h = (kernel_h.shape[0] - 1) // 2 + input = functional.pad(input, (0, 0, pad_h, pad_h), "reflect") + input = functional.conv2d(input, kernel_h[None, None, :, None]) + + if dw < w: + kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) + pad_w = (kernel_w.shape[0] - 1) // 2 + input = functional.pad(input, (pad_w, pad_w, 0, 0), "reflect") + input = functional.conv2d(input, kernel_w[None, None, None, :]) + + input = input.view([n, c, h, w]) + return functional.interpolate( + input, size, mode="bicubic", align_corners=align_corners + ) + + +class ReplaceGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, x_forward, x_backward): + ctx.shape = x_backward.shape + return x_forward + + @staticmethod + def backward(ctx, grad_in): + return None, grad_in.sum_to_size(ctx.shape) + + +replace_grad = ReplaceGrad.apply + + +class ClampWithGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, input, min, max): + ctx.min = min + ctx.max = max + ctx.save_for_backward(input) + return input.clamp(min, max) + + @staticmethod + def backward(ctx, grad_in): + (input,) = ctx.saved_tensors + return ( + grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), + None, + None, + ) + + +clamp_with_grad = ClampWithGrad.apply + + +def vector_quantize(x, codebook): + d = ( + x.pow(2).sum(dim=-1, keepdim=True) + + codebook.pow(2).sum(dim=1) + - 2 * x @ codebook.T + ) + indices = d.argmin(-1) + x_q = functional.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook + return replace_grad(x_q, x) + + +class Prompt(nn.Module): + def __init__(self, embed, weight=1.0, stop=float("-inf")): + super().__init__() + self.register_buffer("embed", embed) + self.register_buffer("weight", torch.as_tensor(weight)) + self.register_buffer("stop", torch.as_tensor(stop)) + + def forward(self, input): + input_normed = functional.normalize(input.unsqueeze(1), dim=2) + embed_normed = functional.normalize(self.embed.unsqueeze(0), dim=2) + dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2) + dists = dists * self.weight.sign() + return ( + self.weight.abs() + * replace_grad(dists, torch.maximum(dists, self.stop)).mean() + ) + + +def parse_prompt(prompt): + vals = prompt.rsplit(":", 2) + vals = vals + ["", "1", "-inf"][len(vals) :] + return vals[0], float(vals[1]), float(vals[2]) + + +class MakeCutouts(nn.Module): + def __init__(self, cut_size, cutn, cut_pow=1.0): + super().__init__() + self.cut_size = cut_size + self.cutn = cutn + self.cut_pow = cut_pow + self.augs = nn.Sequential( + K.RandomHorizontalFlip(p=0.5), + # K.RandomSolarize(0.01, 0.01, p=0.7), + K.RandomSharpness(0.3, p=0.4), + K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"), + K.RandomPerspective(0.2, p=0.4), + K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), + ) + self.noise_fac = 0.1 + + def forward(self, input): + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(self.cutn): + size = int( + torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size + ) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] + cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) + batch = self.augs(torch.cat(cutouts, dim=0)) + if self.noise_fac: + facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) + batch = batch + facs * torch.randn_like(batch) + return batch + + +def load_vqgan_model(config_path, checkpoint_path): + config = OmegaConf.load(config_path) + if config.model.target == "taming.models.vqgan.VQModel": + model = vqgan.VQModel(**config.model.params) + model.eval().requires_grad_(False) + model.init_from_ckpt(checkpoint_path) + elif config.model.target == "taming.models.cond_transformer.Net2NetTransformer": + parent_model = cond_transformer.Net2NetTransformer(**config.model.params) + parent_model.eval().requires_grad_(False) + parent_model.init_from_ckpt(checkpoint_path) + model = parent_model.first_stage_model + else: + raise ValueError(f"unknown model type: {config.model.target}") + del model.loss + return model |