summaryrefslogtreecommitdiff
path: root/lib.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib.py')
-rw-r--r--lib.py179
1 files changed, 179 insertions, 0 deletions
diff --git a/lib.py b/lib.py
new file mode 100644
index 0000000..91553fe
--- /dev/null
+++ b/lib.py
@@ -0,0 +1,179 @@
+import math
+from sys import path
+
+path.append("./taming-transformers")
+from taming.models import cond_transformer, vqgan
+
+import torch
+from torch import nn
+from torch.nn import functional
+
+from omegaconf import OmegaConf
+import kornia.augmentation as K
+
+
+def sinc(x):
+ return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
+
+
+def lanczos(x, a):
+ cond = torch.logical_and(-a < x, x < a)
+ out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
+ return out / out.sum()
+
+
+def ramp(ratio, width):
+ n = math.ceil(width / ratio + 1)
+ out = torch.empty([n])
+ cur = 0
+ for i in range(out.shape[0]):
+ out[i] = cur
+ cur += ratio
+ return torch.cat([-out[1:].flip([0]), out])[1:-1]
+
+
+def resample(input, size, align_corners=True):
+ n, c, h, w = input.shape
+ dh, dw = size
+
+ input = input.view([n * c, 1, h, w])
+
+ if dh < h:
+ kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
+ pad_h = (kernel_h.shape[0] - 1) // 2
+ input = functional.pad(input, (0, 0, pad_h, pad_h), "reflect")
+ input = functional.conv2d(input, kernel_h[None, None, :, None])
+
+ if dw < w:
+ kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
+ pad_w = (kernel_w.shape[0] - 1) // 2
+ input = functional.pad(input, (pad_w, pad_w, 0, 0), "reflect")
+ input = functional.conv2d(input, kernel_w[None, None, None, :])
+
+ input = input.view([n, c, h, w])
+ return functional.interpolate(
+ input, size, mode="bicubic", align_corners=align_corners
+ )
+
+
+class ReplaceGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x_forward, x_backward):
+ ctx.shape = x_backward.shape
+ return x_forward
+
+ @staticmethod
+ def backward(ctx, grad_in):
+ return None, grad_in.sum_to_size(ctx.shape)
+
+
+replace_grad = ReplaceGrad.apply
+
+
+class ClampWithGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, min, max):
+ ctx.min = min
+ ctx.max = max
+ ctx.save_for_backward(input)
+ return input.clamp(min, max)
+
+ @staticmethod
+ def backward(ctx, grad_in):
+ (input,) = ctx.saved_tensors
+ return (
+ grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0),
+ None,
+ None,
+ )
+
+
+clamp_with_grad = ClampWithGrad.apply
+
+
+def vector_quantize(x, codebook):
+ d = (
+ x.pow(2).sum(dim=-1, keepdim=True)
+ + codebook.pow(2).sum(dim=1)
+ - 2 * x @ codebook.T
+ )
+ indices = d.argmin(-1)
+ x_q = functional.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
+ return replace_grad(x_q, x)
+
+
+class Prompt(nn.Module):
+ def __init__(self, embed, weight=1.0, stop=float("-inf")):
+ super().__init__()
+ self.register_buffer("embed", embed)
+ self.register_buffer("weight", torch.as_tensor(weight))
+ self.register_buffer("stop", torch.as_tensor(stop))
+
+ def forward(self, input):
+ input_normed = functional.normalize(input.unsqueeze(1), dim=2)
+ embed_normed = functional.normalize(self.embed.unsqueeze(0), dim=2)
+ dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
+ dists = dists * self.weight.sign()
+ return (
+ self.weight.abs()
+ * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
+ )
+
+
+def parse_prompt(prompt):
+ vals = prompt.rsplit(":", 2)
+ vals = vals + ["", "1", "-inf"][len(vals) :]
+ return vals[0], float(vals[1]), float(vals[2])
+
+
+class MakeCutouts(nn.Module):
+ def __init__(self, cut_size, cutn, cut_pow=1.0):
+ super().__init__()
+ self.cut_size = cut_size
+ self.cutn = cutn
+ self.cut_pow = cut_pow
+ self.augs = nn.Sequential(
+ K.RandomHorizontalFlip(p=0.5),
+ # K.RandomSolarize(0.01, 0.01, p=0.7),
+ K.RandomSharpness(0.3, p=0.4),
+ K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
+ K.RandomPerspective(0.2, p=0.4),
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
+ )
+ self.noise_fac = 0.1
+
+ def forward(self, input):
+ sideY, sideX = input.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ cutouts = []
+ for _ in range(self.cutn):
+ size = int(
+ torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
+ )
+ offsetx = torch.randint(0, sideX - size + 1, ())
+ offsety = torch.randint(0, sideY - size + 1, ())
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
+ batch = self.augs(torch.cat(cutouts, dim=0))
+ if self.noise_fac:
+ facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
+ batch = batch + facs * torch.randn_like(batch)
+ return batch
+
+
+def load_vqgan_model(config_path, checkpoint_path):
+ config = OmegaConf.load(config_path)
+ if config.model.target == "taming.models.vqgan.VQModel":
+ model = vqgan.VQModel(**config.model.params)
+ model.eval().requires_grad_(False)
+ model.init_from_ckpt(checkpoint_path)
+ elif config.model.target == "taming.models.cond_transformer.Net2NetTransformer":
+ parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
+ parent_model.eval().requires_grad_(False)
+ parent_model.init_from_ckpt(checkpoint_path)
+ model = parent_model.first_stage_model
+ else:
+ raise ValueError(f"unknown model type: {config.model.target}")
+ del model.loss
+ return model