author    Tom Barrett <tom@tombarrett.xyz>  2021-07-05 17:08:51 +0200
committer Tom Barrett <tom@tombarrett.xyz>  2021-07-05 17:08:51 +0200
commit    e729c74265e7f8a9db85b85579ccac3f9eab8bed (patch)
tree      bc036f4af2816645b8001f81fe5b910e00f34d80
starting out
-rw-r--r--  .gitignore    6
-rwxr-xr-x  bootstrap    21
-rwxr-xr-x  gen         133
-rw-r--r--  lib.py      179
4 files changed, 339 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..4e5e677
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+CLIP/
+__pycache__/
+taming-transformers/
+taming/
+vqgan_imagenet_f16_16384.ckpt
+vqgan_imagenet_f16_16384.yaml
diff --git a/bootstrap b/bootstrap
new file mode 100755
index 0000000..2121798
--- /dev/null
+++ b/bootstrap
@@ -0,0 +1,21 @@
+#!/bin/bash
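+# Fetch the CLIP and taming-transformers sources that gen and lib.py import.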
+git clone https://github.com/openai/CLIP
+git clone https://github.com/CompVis/taming-transformers
+
+pip install ftfy regex tqdm omegaconf pytorch-lightning
+pip install kornia
+pip install einops
+pip install torchvision
+pip install imageio
+
+pip install stegano
+pip install python-xmp-toolkit
+pip install imgtag
+pip install pillow
+pip install imageio-ffmpeg
+
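+# exempi is the system library that python-xmp-toolkit wraps.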
+#apt install exempi
+pacman -Sy exempi
+
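+# Grab the pretrained ImageNet VQGAN (f16, 16384-entry codebook) config and weights.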
+curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml'
+curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt'
diff --git a/gen b/gen
new file mode 100755
index 0000000..2f9d7ab
--- /dev/null
+++ b/gen
@@ -0,0 +1,133 @@
+#!/bin/python
+
+from sys import argv
+from os import makedirs
+import torch
+import numpy
+import imageio
+from CLIP import clip
+
+from lib import (
+ load_vqgan_model,
+ MakeCutouts,
+ parse_prompt,
+ Prompt,
+ vector_quantize,
+ clamp_with_grad,
+)
+from torch.nn import functional
+from torch import optim
+from torchvision import transforms
+
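+# The prompt is taken from the command line; "|" separates multiple prompts.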
+magic = " ".join(argv[1:])
+width = 300
+height = 300
+seed = None
+cycles = 100
+
+magic = [phrase.strip() for phrase in magic.split("|")]
+if magic == [""]:
+ magic = []
+
+noise_prompt_seeds = []
+noise_prompt_weights = []
+size = [width, height]
+init_weight = 0.0
+clip_model = "ViT-B/32"
+vqgan_config = "vqgan_imagenet_f16_16384.yaml"
+vqgan_checkpoint = "vqgan_imagenet_f16_16384.ckpt"
+step_size = 0.1
+cutn = 64
+cut_pow = 1.0
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print("Using device:", device)
+if seed is None:
+    seed = torch.seed()
+torch.manual_seed(seed)
+print("Using seed:", seed)
+
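+# Load the VQGAN generator and the CLIP perceptor.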
+model = load_vqgan_model(vqgan_config, vqgan_checkpoint).to(device)
+perceptor = clip.load(clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
+
+cut_size = perceptor.visual.input_resolution
+e_dim = model.quantize.e_dim
+f = 2 ** (model.decoder.num_resolutions - 1)
+make_cutouts = MakeCutouts(cut_size, cutn, cut_pow=cut_pow)
+n_toks = model.quantize.n_e
+toksX, toksY = size[0] // f, size[1] // f
+sideX, sideY = toksX * f, toksY * f
+z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
+z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
+
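+# Start from a random latent: one random codebook entry per token position.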
+one_hot = functional.one_hot(
+ torch.randint(n_toks, [toksY * toksX], device=device), n_toks
+).float()
+z = one_hot @ model.quantize.embedding.weight
+z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
+
+z_orig = z.clone()
+z.requires_grad_(True)
+opt = optim.Adam([z], lr=step_size)
+
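+# Normalisation constants expected by CLIP.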
+normalize = transforms.Normalize(
+ mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
+)
+
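+# Build the prompt losses: CLIP-encoded text prompts plus any noise prompts.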
+pMs = []
+
+for prompt in magic:
+ txt, weight, stop = parse_prompt(prompt)
+ embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
+ pMs.append(Prompt(embed, weight, stop).to(device))
+
+for seed, weight in zip(noise_prompt_seeds, noise_prompt_weights):
+ gen = torch.Generator().manual_seed(seed)
+ embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
+ pMs.append(Prompt(embed, weight).to(device))
+
+
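+# Quantize z against the codebook and decode it to an image in [0, 1].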
+def synth(z):
+ z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(
+ 3, 1
+ )
+ return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)
+
+
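+# Compute the CLIP losses for the current image and write the frame to steps/.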
+def ascend_txt():
+ global i
+ out = synth(z)
+ iii = perceptor.encode_image(normalize(make_cutouts(out))).float()
+
+ result = []
+
+ if init_weight:
+ result.append(functional.mse_loss(z, z_orig) * init_weight / 2)
+
+ for prompt in pMs:
+ result.append(prompt(iii))
+    img = out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(numpy.uint8)
+    img = numpy.transpose(img, (1, 2, 0))
+    filename = f"steps/{i:04}.png"
+    imageio.imwrite(filename, img)
+ return result
+
+
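+# One optimisation step: backprop the summed losses, then clamp z to the codebook range.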
+def train(i):
+ opt.zero_grad()
+ lossAll = ascend_txt()
+ loss = sum(lossAll)
+ loss.backward()
+ opt.step()
+ with torch.no_grad():
+ z.copy_(z.maximum(z_min).minimum(z_max))
+
+
+# Frames are written to steps/, so make sure the directory exists.
+makedirs("steps", exist_ok=True)
+
+i = 0
+while True:
+ train(i)
+ if i == cycles:
+ break
+ i += 1
+    print(f"{i} out of {cycles}")
diff --git a/lib.py b/lib.py
new file mode 100644
index 0000000..91553fe
--- /dev/null
+++ b/lib.py
@@ -0,0 +1,179 @@
+import math
+from sys import path
+
+path.append("./taming-transformers")
+from taming.models import cond_transformer, vqgan
+
+import torch
+from torch import nn
+from torch.nn import functional
+
+from omegaconf import OmegaConf
+import kornia.augmentation as K
+
+
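+# Windowed-sinc (Lanczos) kernel helpers used by resample().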
+def sinc(x):
+ return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
+
+
+def lanczos(x, a):
+ cond = torch.logical_and(-a < x, x < a)
+ out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
+ return out / out.sum()
+
+
+def ramp(ratio, width):
+ n = math.ceil(width / ratio + 1)
+ out = torch.empty([n])
+ cur = 0
+ for i in range(out.shape[0]):
+ out[i] = cur
+ cur += ratio
+ return torch.cat([-out[1:].flip([0]), out])[1:-1]
+
+
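+# Lanczos-filter when downscaling, then bicubic-resize to the exact target size.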
+def resample(input, size, align_corners=True):
+ n, c, h, w = input.shape
+ dh, dw = size
+
+ input = input.view([n * c, 1, h, w])
+
+ if dh < h:
+ kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
+ pad_h = (kernel_h.shape[0] - 1) // 2
+ input = functional.pad(input, (0, 0, pad_h, pad_h), "reflect")
+ input = functional.conv2d(input, kernel_h[None, None, :, None])
+
+ if dw < w:
+ kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
+ pad_w = (kernel_w.shape[0] - 1) // 2
+ input = functional.pad(input, (pad_w, pad_w, 0, 0), "reflect")
+ input = functional.conv2d(input, kernel_w[None, None, None, :])
+
+ input = input.view([n, c, h, w])
+ return functional.interpolate(
+ input, size, mode="bicubic", align_corners=align_corners
+ )
+
+
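+# Straight-through estimator: forward returns x_forward, gradients flow back to x_backward.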
+class ReplaceGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x_forward, x_backward):
+ ctx.shape = x_backward.shape
+ return x_forward
+
+ @staticmethod
+ def backward(ctx, grad_in):
+ return None, grad_in.sum_to_size(ctx.shape)
+
+
+replace_grad = ReplaceGrad.apply
+
+
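+# Clamp that is the identity inside [min, max]; outside, it only passes gradients that pull the value back into range.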
+class ClampWithGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, min, max):
+ ctx.min = min
+ ctx.max = max
+ ctx.save_for_backward(input)
+ return input.clamp(min, max)
+
+ @staticmethod
+ def backward(ctx, grad_in):
+ (input,) = ctx.saved_tensors
+ return (
+ grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0),
+ None,
+ None,
+ )
+
+
+clamp_with_grad = ClampWithGrad.apply
+
+
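+# Snap each latent vector to its nearest codebook entry; gradients pass straight through to x.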
+def vector_quantize(x, codebook):
+ d = (
+ x.pow(2).sum(dim=-1, keepdim=True)
+ + codebook.pow(2).sum(dim=1)
+ - 2 * x @ codebook.T
+ )
+ indices = d.argmin(-1)
+ x_q = functional.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
+ return replace_grad(x_q, x)
+
+
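+# Loss term: squared spherical distance between normalised CLIP embeddings, scaled by the prompt weight (negative weights push the image away).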
+class Prompt(nn.Module):
+ def __init__(self, embed, weight=1.0, stop=float("-inf")):
+ super().__init__()
+ self.register_buffer("embed", embed)
+ self.register_buffer("weight", torch.as_tensor(weight))
+ self.register_buffer("stop", torch.as_tensor(stop))
+
+ def forward(self, input):
+ input_normed = functional.normalize(input.unsqueeze(1), dim=2)
+ embed_normed = functional.normalize(self.embed.unsqueeze(0), dim=2)
+ dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
+ dists = dists * self.weight.sign()
+ return (
+ self.weight.abs()
+ * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
+ )
+
+
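+# Prompts use "text:weight:stop"; missing fields default to weight 1 and stop -inf.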
+def parse_prompt(prompt):
+ vals = prompt.rsplit(":", 2)
+ vals = vals + ["", "1", "-inf"][len(vals) :]
+ return vals[0], float(vals[1]), float(vals[2])
+
+
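+# Take cutn random crops of varying size, resample each to CLIP's input resolution, then augment and add noise.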
+class MakeCutouts(nn.Module):
+ def __init__(self, cut_size, cutn, cut_pow=1.0):
+ super().__init__()
+ self.cut_size = cut_size
+ self.cutn = cutn
+ self.cut_pow = cut_pow
+ self.augs = nn.Sequential(
+ K.RandomHorizontalFlip(p=0.5),
+ # K.RandomSolarize(0.01, 0.01, p=0.7),
+ K.RandomSharpness(0.3, p=0.4),
+ K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
+ K.RandomPerspective(0.2, p=0.4),
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
+ )
+ self.noise_fac = 0.1
+
+ def forward(self, input):
+ sideY, sideX = input.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ cutouts = []
+ for _ in range(self.cutn):
+ size = int(
+ torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
+ )
+ offsetx = torch.randint(0, sideX - size + 1, ())
+ offsety = torch.randint(0, sideY - size + 1, ())
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
+ batch = self.augs(torch.cat(cutouts, dim=0))
+ if self.noise_fac:
+ facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
+ batch = batch + facs * torch.randn_like(batch)
+ return batch
+
+
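+# Build a taming-transformers VQGAN (or the first stage of a Net2NetTransformer) from its config and checkpoint.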
+def load_vqgan_model(config_path, checkpoint_path):
+ config = OmegaConf.load(config_path)
+ if config.model.target == "taming.models.vqgan.VQModel":
+ model = vqgan.VQModel(**config.model.params)
+ model.eval().requires_grad_(False)
+ model.init_from_ckpt(checkpoint_path)
+ elif config.model.target == "taming.models.cond_transformer.Net2NetTransformer":
+ parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
+ parent_model.eval().requires_grad_(False)
+ parent_model.init_from_ckpt(checkpoint_path)
+ model = parent_model.first_stage_model
+ else:
+ raise ValueError(f"unknown model type: {config.model.target}")
+ del model.loss
+ return model