diff options
author | Tom Barrett <tom@tombarrett.xyz> | 2021-07-05 17:08:51 +0200 |
---|---|---|
committer | Tom Barrett <tom@tombarrett.xyz> | 2021-07-05 17:08:51 +0200 |
commit | e729c74265e7f8a9db85b85579ccac3f9eab8bed (patch) | |
tree | bc036f4af2816645b8001f81fe5b910e00f34d80 /lib.py |
starting out
Diffstat (limited to 'lib.py')
-rw-r--r-- | lib.py | 179 |
1 files changed, 179 insertions, 0 deletions
@@ -0,0 +1,179 @@ +import math +from sys import path + +path.append("./taming-transformers") +from taming.models import cond_transformer, vqgan + +import torch +from torch import nn +from torch.nn import functional + +from omegaconf import OmegaConf +import kornia.augmentation as K + + +def sinc(x): + return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) + + +def lanczos(x, a): + cond = torch.logical_and(-a < x, x < a) + out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([])) + return out / out.sum() + + +def ramp(ratio, width): + n = math.ceil(width / ratio + 1) + out = torch.empty([n]) + cur = 0 + for i in range(out.shape[0]): + out[i] = cur + cur += ratio + return torch.cat([-out[1:].flip([0]), out])[1:-1] + + +def resample(input, size, align_corners=True): + n, c, h, w = input.shape + dh, dw = size + + input = input.view([n * c, 1, h, w]) + + if dh < h: + kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) + pad_h = (kernel_h.shape[0] - 1) // 2 + input = functional.pad(input, (0, 0, pad_h, pad_h), "reflect") + input = functional.conv2d(input, kernel_h[None, None, :, None]) + + if dw < w: + kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) + pad_w = (kernel_w.shape[0] - 1) // 2 + input = functional.pad(input, (pad_w, pad_w, 0, 0), "reflect") + input = functional.conv2d(input, kernel_w[None, None, None, :]) + + input = input.view([n, c, h, w]) + return functional.interpolate( + input, size, mode="bicubic", align_corners=align_corners + ) + + +class ReplaceGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, x_forward, x_backward): + ctx.shape = x_backward.shape + return x_forward + + @staticmethod + def backward(ctx, grad_in): + return None, grad_in.sum_to_size(ctx.shape) + + +replace_grad = ReplaceGrad.apply + + +class ClampWithGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, input, min, max): + ctx.min = min + ctx.max = max + ctx.save_for_backward(input) + return input.clamp(min, max) + + @staticmethod + def backward(ctx, grad_in): + (input,) = ctx.saved_tensors + return ( + grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), + None, + None, + ) + + +clamp_with_grad = ClampWithGrad.apply + + +def vector_quantize(x, codebook): + d = ( + x.pow(2).sum(dim=-1, keepdim=True) + + codebook.pow(2).sum(dim=1) + - 2 * x @ codebook.T + ) + indices = d.argmin(-1) + x_q = functional.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook + return replace_grad(x_q, x) + + +class Prompt(nn.Module): + def __init__(self, embed, weight=1.0, stop=float("-inf")): + super().__init__() + self.register_buffer("embed", embed) + self.register_buffer("weight", torch.as_tensor(weight)) + self.register_buffer("stop", torch.as_tensor(stop)) + + def forward(self, input): + input_normed = functional.normalize(input.unsqueeze(1), dim=2) + embed_normed = functional.normalize(self.embed.unsqueeze(0), dim=2) + dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2) + dists = dists * self.weight.sign() + return ( + self.weight.abs() + * replace_grad(dists, torch.maximum(dists, self.stop)).mean() + ) + + +def parse_prompt(prompt): + vals = prompt.rsplit(":", 2) + vals = vals + ["", "1", "-inf"][len(vals) :] + return vals[0], float(vals[1]), float(vals[2]) + + +class MakeCutouts(nn.Module): + def __init__(self, cut_size, cutn, cut_pow=1.0): + super().__init__() + self.cut_size = cut_size + self.cutn = cutn + self.cut_pow = cut_pow + self.augs = nn.Sequential( + K.RandomHorizontalFlip(p=0.5), + # K.RandomSolarize(0.01, 0.01, p=0.7), + K.RandomSharpness(0.3, p=0.4), + K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"), + K.RandomPerspective(0.2, p=0.4), + K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), + ) + self.noise_fac = 0.1 + + def forward(self, input): + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(self.cutn): + size = int( + torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size + ) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] + cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) + batch = self.augs(torch.cat(cutouts, dim=0)) + if self.noise_fac: + facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) + batch = batch + facs * torch.randn_like(batch) + return batch + + +def load_vqgan_model(config_path, checkpoint_path): + config = OmegaConf.load(config_path) + if config.model.target == "taming.models.vqgan.VQModel": + model = vqgan.VQModel(**config.model.params) + model.eval().requires_grad_(False) + model.init_from_ckpt(checkpoint_path) + elif config.model.target == "taming.models.cond_transformer.Net2NetTransformer": + parent_model = cond_transformer.Net2NetTransformer(**config.model.params) + parent_model.eval().requires_grad_(False) + parent_model.init_from_ckpt(checkpoint_path) + model = parent_model.first_stage_model + else: + raise ValueError(f"unknown model type: {config.model.target}") + del model.loss + return model |