# Copyright 2022 Baler Contributors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from torch.nn import functional as F

from . import flows
class AE(nn.Module):
# This class is a modified version of the original class by George Dialektakis found at
# https://github.com/Autoencoders-compression-anomaly/Deep-Autoencoders-Data-Compression-GSoC-2021
# Released under the Apache License 2.0 found at https://www.apache.org/licenses/LICENSE-2.0.txt
# Copyright 2021 George Dialektakis
def __init__(self, in_shape, z_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
self.activations = {}
self.n_features = in_shape[-1] * in_shape[-2]
self.z_dim = z_dim
self.in_shape = in_shape
# encoder
self.en1 = nn.Linear(self.n_features, 200)
self.dr1 = nn.Dropout(p=0.5)
self.bn1 = nn.BatchNorm1d(200)
self.en2 = nn.Linear(200, 100)
self.dr2 = nn.Dropout(p=0.4)
self.bn2 = nn.BatchNorm1d(100)
self.en3 = nn.Linear(100, 50)
self.dr3 = nn.Dropout(p=0.3)
self.bn3 = nn.BatchNorm1d(50)
self.en4 = nn.Linear(50, self.z_dim)
self.dr4 = nn.Dropout(p=0.2)
self.bn4 = nn.BatchNorm1d(self.z_dim)
self.bn5 = nn.BatchNorm1d(self.n_features)
self.leaky_relu = nn.LeakyReLU()
self.flatten = nn.Flatten()
# decoder
self.de1 = nn.Linear(self.z_dim, 50)
self.de2 = nn.Linear(50, 100)
self.de3 = nn.Linear(100, 200)
self.de4 = nn.Linear(200, self.n_features)
def encode(self, x):
h1 = F.leaky_relu(self.en1(x))
h2 = F.leaky_relu(self.en2(h1))
h3 = F.leaky_relu(self.en3(h2))
return self.en4(h3)
def decode(self, z):
h4 = F.leaky_relu(self.de1(z))
h5 = F.leaky_relu(self.de2(h4))
h6 = F.leaky_relu(self.de3(h5))
return self.de4(h6)
def forward(self, x):
x = self.flatten(x)
z = self.encode(x)
out = self.decode(z)
out = out.view(self.in_shape)
        # Return a 6-tuple so the plain AE matches the interface of the flow-based VAEs
        return out, z, z, z, z, z
    # Activation extraction via PyTorch forward hooks
def get_hook(self, layer_name):
def hook(model, input, output):
self.activations[layer_name] = output.detach()
return hook
def get_layers(self) -> list:
return [self.en1, self.en2, self.en3, self.de1, self.de2, self.de3]
def store_hooks(self) -> list:
layers = self.get_layers()
hooks = []
        for i, layer in enumerate(layers):
            hooks.append(layer.register_forward_hook(self.get_hook(str(i))))
return hooks
    def get_activations(self) -> dict:
        # Hooks capture the raw Linear outputs, so apply the activation here
        for key in self.activations:
            self.activations[key] = F.leaky_relu(self.activations[key])
        return self.activations
def detach_hooks(self, hooks: list) -> None:
for hook in hooks:
hook.remove()
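# Usage sketch (added; not part of the original module). The shapes are illustrative
# assumptions: `in_shape` is the full batch shape, so n_features = H * W here.
#
#   model = AE(in_shape=(16, 1, 28, 4), z_dim=10)
#   hooks = model.store_hooks()                   # register forward hooks
#   recon, z, *_ = model(torch.rand(16, 1, 28, 4))
#   activations = model.get_activations()         # activations keyed "0".."5"
#   model.detach_hooks(hooks)                     # remove hooks when done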
class AE_Dropout_BN(AE):
def __init__(self, in_shape, z_dim, *args, **kwargs):
super().__init__(in_shape, z_dim, *args, **kwargs)
# encoder
self.enc_nn = nn.Sequential(
self.bn5,
self.en1,
self.dr1,
self.bn1,
self.leaky_relu,
self.en2,
self.dr2,
self.bn2,
self.leaky_relu,
self.en3,
self.dr3,
self.bn3,
self.leaky_relu,
self.en4,
self.dr4,
self.bn4,
)
# decoder
self.dec_nn = nn.Sequential(
self.bn4,
self.de1,
self.leaky_relu,
self.bn3,
self.de2,
self.leaky_relu,
self.bn2,
self.de3,
self.leaky_relu,
self.bn1,
self.de4,
self.bn5,
)
def enc_bn(self, x):
return self.enc_nn(x)
def dec_bn(self, z):
return self.dec_nn(z)
def forward(self, x):
x = self.flatten(x)
z = self.enc_bn(x)
out = self.dec_bn(z)
out = out.view(self.in_shape)
return out, z, z, z, z, z
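# Note (added): encoder and decoder share the same BatchNorm1d modules, so a single
# forward pass updates each set of running statistics twice, and train() mode needs a
# batch size greater than 1. A minimal sketch under the same shape assumptions as above:
#
#   model = AE_Dropout_BN(in_shape=(16, 1, 28, 4), z_dim=10)
#   recon, z, *_ = model(torch.rand(16, 1, 28, 4))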
class ConvAE(nn.Module):
def __init__(self, in_shape, z_dim, *args, **kwargs):
super().__init__(*args, **kwargs)
self.q_z_mid_dim = 100
self.conv_op_shape = None
self.z_dim = z_dim
self.in_shape = in_shape
# Encoder
# Conv Layers
self.q_z_conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 4), stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.Conv2d(32, 16, kernel_size=(3, 1), stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(),
            nn.Conv2d(16, 8, kernel_size=(3, 1), stride=1, padding=1),
nn.BatchNorm2d(8),
)
# Flatten
self.flatten = nn.Flatten(start_dim=1)
# Get size after flattening
self.q_z_output_dim = self._get_qzconv_output(self.in_shape)
# Linear layers
self.q_z_lin = nn.Sequential(
nn.Linear(self.q_z_output_dim, self.q_z_mid_dim),
nn.BatchNorm1d(self.q_z_mid_dim),
nn.LeakyReLU(),
)
self.q_z_latent = nn.Sequential(
nn.Linear(self.q_z_mid_dim, self.z_dim),
nn.BatchNorm1d(self.z_dim),
)
# Decoder
# Linear layers
self.p_x_lin = nn.Sequential(
nn.Linear(z_dim, self.q_z_mid_dim),
nn.BatchNorm1d(self.q_z_mid_dim),
nn.LeakyReLU(),
nn.Linear(self.q_z_mid_dim, self.q_z_output_dim),
nn.BatchNorm1d(self.q_z_output_dim),
)
# Conv Layers
self.p_x_conv = nn.Sequential(
nn.BatchNorm2d(8),
            nn.ConvTranspose2d(8, 16, kernel_size=(3, 1), stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(16, 32, kernel_size=(3, 1), stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=(3, 4), stride=1, padding=1),
)
    def _get_qzconv_output(self, shape):
        # Pass a dummy tensor through the conv stack to infer the flattened size
        dummy = torch.rand(shape)
        output_feat = self._forward_features(dummy)
        return output_feat.size(1)
def _forward_features(self, x):
qz = self.q_z_conv(x)
return self.flatten(qz)
def encode(self, x):
# Conv
out = self.q_z_conv(x)
self.conv_op_shape = out.shape
# Flatten
out = self.flatten(out)
# Dense
out = self.q_z_lin(out)
# Latent
out = self.q_z_latent(out)
return out
def decode(self, z):
# Dense
out = self.p_x_lin(z)
# Unflatten
        out = out.view(self.conv_op_shape)
# Conv transpose
out = self.p_x_conv(out)
return out
def forward(self, x):
z = self.encode(x)
out = self.decode(z)
return out, z, z, z, z, z
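# Usage sketch (added; not part of the original module). The conv stack expects a
# 4D batch (N, 1, H, W); the shapes below are illustrative assumptions:
#
#   model = ConvAE(in_shape=(16, 1, 28, 4), z_dim=10)
#   recon, z, *_ = model(torch.rand(16, 1, 28, 4))   # recon matches the input shape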
class ConvVAE(ConvAE):
def __init__(self, in_shape, z_dim, *args, **kwargs):
super().__init__(in_shape, z_dim, *args, **kwargs)
# Latent distribution parameters
self.q_z_mean = nn.Linear(self.q_z_mid_dim, self.z_dim)
self.q_z_logvar = nn.Linear(self.q_z_mid_dim, self.z_dim)
# log-det-jacobian = 0 without flows
self.ldj = 0
self.z_size = z_dim
def encode(self, x):
# Conv
out = self.q_z_conv(x)
self.conv_op_shape = out.shape
# Flatten
out = self.flatten(out)
# Dense
out = self.q_z_lin(out)
# Latent
mean = self.q_z_mean(out)
logvar = self.q_z_logvar(out)
return out, mean, logvar
def decode(self, z):
# Dense
out = self.p_x_lin(z)
# Unflatten
        out = out.view(self.conv_op_shape)
# Conv transpose
out = self.p_x_conv(out)
return out
    def reparameterize(self, mean, logvar):
        # Reparameterization trick: z = mu + eps * sigma with sigma = exp(0.5 * logvar)
        z = mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)
        return z
def forward(self, x):
out, mean, logvar = self.encode(x)
z = self.reparameterize(mean, logvar)
out = self.decode(z)
return out, mean, logvar, self.ldj, z, z
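# Worked example (added, illustrative): with mean = 0.0 and logvar = -2.0 the sample
# is z = 0.0 + eps * exp(-1.0) ~= 0.368 * eps with eps ~ N(0, 1), so more negative
# logvar values shrink the stochastic part of the latent sample.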
class Planar_ConvVAE(ConvVAE):
"""
Variational auto-encoder with planar flows in the decoder.
"""
def __init__(self, in_shape, z_dim, *args, **kwargs):
super().__init__(in_shape, z_dim, *args, **kwargs)
# Initialize log-det-jacobian to zero
self.log_det_j = 0
# Flow parameters
flow = flows.Planar
self.num_flows = 6 # args.num_flows
# Amortized flow parameters
self.amor_u = nn.Linear(self.q_z_mid_dim, self.num_flows * self.z_size)
self.amor_w = nn.Linear(self.q_z_mid_dim, self.num_flows * self.z_size)
self.amor_b = nn.Linear(self.q_z_mid_dim, self.num_flows)
# Normalizing flow layers
for k in range(self.num_flows):
flow_k = flow()
self.add_module("flow_" + str(k), flow_k)
def forward(self, x):
self.log_det_j = 0
out, z_mu, z_var = self.encode(x)
batch_size = x.size(0)
        # Amortized u, w and b parameters for all flows
u = self.amor_u(out).view(batch_size, self.num_flows, self.z_size, 1)
w = self.amor_w(out).view(batch_size, self.num_flows, 1, self.z_size)
b = self.amor_b(out).view(batch_size, self.num_flows, 1, 1)
# Sample z_0
z = [self.reparameterize(z_mu, z_var)]
# Normalizing flows
for k in range(self.num_flows):
            flow_k = getattr(self, "flow_" + str(k))
z_k, log_det_jacobian = flow_k(
z[k], u[:, k, :, :], w[:, k, :, :], b[:, k, :, :]
)
z.append(z_k)
self.log_det_j += log_det_jacobian
x_decoded = self.decode(z[-1])
return x_decoded, z_mu, z_var, self.log_det_j, z[0], z[-1]
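# Planar flow (added note): each step applies z_k = z_{k-1} + u * tanh(w^T z_{k-1} + b)
# and adds log|1 + u^T (1 - tanh^2(w^T z_{k-1} + b)) w| to the log-det-Jacobian
# (Rezende & Mohamed, 2015); the exact form used here is defined in flows.Planar.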
class OrthogonalSylvester_ConvVAE(ConvVAE):
"""
Variational auto-encoder with orthogonal flows in the decoder.
"""
def __init__(self, in_shape, z_dim, *args, **kwargs):
super().__init__(in_shape, z_dim, *args, **kwargs)
# Initialize log-det-jacobian to zero
self.log_det_j = 0
# Flow parameters
flow = flows.Sylvester
self.num_flows = 4 # args.num_flows
self.num_ortho_vecs = 5 # args.num_ortho_vecs
assert (self.num_ortho_vecs <= self.z_size) and (self.num_ortho_vecs > 0)
# Orthogonalization parameters
if self.num_ortho_vecs == self.z_size:
self.cond = 1.0e-5
else:
self.cond = 1.0e-6
self.steps = 100
identity = torch.eye(self.num_ortho_vecs, self.num_ortho_vecs)
# Add batch dimension
identity = identity.unsqueeze(0)
        # Register as a buffer so it moves with the module (e.g. on .to(device))
        self.register_buffer("_eye", identity)
# Masks needed for triangular R1 and R2.
triu_mask = torch.triu(
torch.ones(self.num_ortho_vecs, self.num_ortho_vecs), diagonal=1
)
triu_mask = triu_mask.unsqueeze(0).unsqueeze(3)
diag_idx = torch.arange(0, self.num_ortho_vecs).long()
self.register_buffer("triu_mask", Variable(triu_mask))
self.triu_mask.requires_grad = False
self.register_buffer("diag_idx", diag_idx)
# Amortized flow parameters
        # Diagonal elements of R1 and R2 must satisfy diag(R1) * diag(R2) > -1
        # for the flow to be invertible
self.diag_activation = nn.Tanh()
self.amor_d = nn.Linear(
self.q_z_mid_dim,
self.num_flows * self.num_ortho_vecs * self.num_ortho_vecs,
)
self.amor_diag1 = nn.Sequential(
nn.Linear(self.q_z_mid_dim, self.num_flows * self.num_ortho_vecs),
self.diag_activation,
)
self.amor_diag2 = nn.Sequential(
nn.Linear(self.q_z_mid_dim, self.num_flows * self.num_ortho_vecs),
self.diag_activation,
)
self.amor_q = nn.Linear(
self.q_z_mid_dim, self.num_flows * self.z_size * self.num_ortho_vecs
)
self.amor_b = nn.Linear(self.q_z_mid_dim, self.num_flows * self.num_ortho_vecs)
# Normalizing flow layers
for k in range(self.num_flows):
flow_k = flow(self.num_ortho_vecs)
self.add_module("flow_" + str(k), flow_k)
def batch_construct_orthogonal(self, q):
# Reshape to shape (num_flows * batch_size, z_size * num_ortho_vecs)
q = q.view(-1, self.z_size * self.num_ortho_vecs)
norm = torch.norm(q, p=2, dim=1, keepdim=True)
amat = torch.div(q, norm)
dim0 = amat.size(0)
        amat = amat.view(dim0, self.z_size, self.num_ortho_vecs)
max_norm = 0
# Iterative orthogonalization
for s in range(self.steps):
tmp = torch.bmm(amat.transpose(2, 1), amat)
tmp = self._eye - tmp
tmp = self._eye + 0.5 * tmp
amat = torch.bmm(amat, tmp)
# Testing for convergence
test = torch.bmm(amat.transpose(2, 1), amat) - self._eye
norms2 = torch.sum(torch.norm(test, p=2, dim=2) ** 2, dim=1)
norms = torch.sqrt(norms2)
            max_norm = torch.max(norms).item()
if max_norm <= self.cond:
break
if max_norm > self.cond:
print("\nWARNING: orthogonalization not complete")
print("\t Final max norm =", max_norm)
# Reshaping: first dimension is batch_size
amat = amat.view(-1, self.num_flows, self.z_size, self.num_ortho_vecs)
amat = amat.transpose(0, 1)
return amat
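    # Note on batch_construct_orthogonal (added): amat <- amat @ (I + 0.5 * (I - amat^T amat))
    # is the first-order Bjorck orthogonalization step; starting from column-normalized
    # input it converges to the nearest matrix with orthonormal columns, and the loop
    # stops once ||amat^T amat - I|| drops below self.cond.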
def forward(self, x):
self.log_det_j = 0
out, z_mu, z_var = self.encode(x)
batch_size = x.size(0)
# Amortized r1, r2, q, b for all flows
full_d = self.amor_d(out)
diag1 = self.amor_diag1(out)
diag2 = self.amor_diag2(out)
        full_d = full_d.view(
            batch_size, self.num_ortho_vecs, self.num_ortho_vecs, self.num_flows
        )
        diag1 = diag1.view(batch_size, self.num_ortho_vecs, self.num_flows)
        diag2 = diag2.view(batch_size, self.num_ortho_vecs, self.num_flows)
r1 = full_d * self.triu_mask
r2 = full_d.transpose(2, 1) * self.triu_mask
r1[:, self.diag_idx, self.diag_idx, :] = diag1
r2[:, self.diag_idx, self.diag_idx, :] = diag2
q = self.amor_q(out)
b = self.amor_b(out)
# Resize flow parameters to divide over K flows
        b = b.view(batch_size, 1, self.num_ortho_vecs, self.num_flows)
# Orthogonalize all q matrices
q_ortho = self.batch_construct_orthogonal(q)
# Sample z_0
z = [self.reparameterize(z_mu, z_var)]
# Normalizing flows
for k in range(self.num_flows):
flow_k = getattr(self, "flow_" + str(k))
z_k, log_det_jacobian = flow_k(
z[k], r1[:, :, :, k], r2[:, :, :, k], q_ortho[k, :, :, :], b[:, :, :, k]
)
z.append(z_k)
self.log_det_j += log_det_jacobian
x_decoded = self.decode(z[-1])
return x_decoded, z_mu, z_var, self.log_det_j, z[0], z[-1]
class HouseholderSylvester_ConvVAE(ConvVAE):
"""
Variational auto-encoder with householder sylvester flows in the decoder.
"""
def __init__(self, in_shape, z_dim, *args, **kwargs):
super().__init__(in_shape, z_dim, *args, **kwargs)
# Initialize log-det-jacobian to zero
self.log_det_j = 0
# Flow parameters
flow = flows.Sylvester
self.num_flows = 4 # args.num_flows
self.num_householder = 8 # args.num_householder
assert self.num_householder > 0
identity = torch.eye(self.z_size, self.z_size)
# Add batch dimension
identity = identity.unsqueeze(0)
        # Register as a buffer so it moves with the module (e.g. on .to(device))
        self.register_buffer("_eye", identity)
# Masks needed for triangular r1 and r2.
triu_mask = torch.triu(torch.ones(self.z_size, self.z_size), diagonal=1)
triu_mask = triu_mask.unsqueeze(0).unsqueeze(3)
diag_idx = torch.arange(0, self.z_size).long()
self.register_buffer("triu_mask", Variable(triu_mask))
self.triu_mask.requires_grad = False
self.register_buffer("diag_idx", diag_idx)
# Amortized flow parameters
        # Diagonal elements of r1 and r2 must satisfy diag(r1) * diag(r2) > -1
        # for the flow to be invertible
self.diag_activation = nn.Tanh()
self.amor_d = nn.Linear(
self.q_z_mid_dim, self.num_flows * self.z_size * self.z_size
)
self.amor_diag1 = nn.Sequential(
nn.Linear(self.q_z_mid_dim, self.num_flows * self.z_size),
self.diag_activation,
)
self.amor_diag2 = nn.Sequential(
nn.Linear(self.q_z_mid_dim, self.num_flows * self.z_size),
self.diag_activation,
)
self.amor_q = nn.Linear(
self.q_z_mid_dim, self.num_flows * self.z_size * self.num_householder
)
self.amor_b = nn.Linear(self.q_z_mid_dim, self.num_flows * self.z_size)
# Normalizing flow layers
for k in range(self.num_flows):
flow_k = flow(self.z_size)
self.add_module("flow_" + str(k), flow_k)
def batch_construct_orthogonal(self, q):
# Reshape to shape (num_flows * batch_size * num_householder, z_size)
q = q.view(-1, self.z_size)
norm = torch.norm(q, p=2, dim=1, keepdim=True)
v = torch.div(q, norm)
# Calculate Householder Matrices
vvT = torch.bmm(v.unsqueeze(2), v.unsqueeze(1))
amat = self._eye - 2 * vvT
# Reshaping: first dimension is batch_size * num_flows
amat = amat.view(-1, self.num_householder, self.z_size, self.z_size)
tmp = amat[:, 0]
for k in range(1, self.num_householder):
tmp = torch.bmm(amat[:, k], tmp)
amat = tmp.view(-1, self.num_flows, self.z_size, self.z_size)
amat = amat.transpose(0, 1)
return amat
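    # Note (added): each factor above is a Householder reflection H = I - 2 v v^T with
    # ||v|| = 1, which is orthogonal by construction; the bmm loop composes
    # num_householder reflections per flow into a single orthogonal matrix.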
def forward(self, x):
self.log_det_j = 0
        batch_size = x.size(0)
        out, z_mu, z_var = self.encode(x)
# Amortized r1, r2, q, b for all flows
full_d = self.amor_d(out)
diag1 = self.amor_diag1(out)
diag2 = self.amor_diag2(out)
        full_d = full_d.view(batch_size, self.z_size, self.z_size, self.num_flows)
        diag1 = diag1.view(batch_size, self.z_size, self.num_flows)
        diag2 = diag2.view(batch_size, self.z_size, self.num_flows)
r1 = full_d * self.triu_mask
r2 = full_d.transpose(2, 1) * self.triu_mask
r1[:, self.diag_idx, self.diag_idx, :] = diag1
r2[:, self.diag_idx, self.diag_idx, :] = diag2
q = self.amor_q(out)
b = self.amor_b(out)
# Resize flow parameters to divide over K flows
        b = b.view(batch_size, 1, self.z_size, self.num_flows)
# Orthogonalize all q matrices
q_ortho = self.batch_construct_orthogonal(q)
# Sample z_0
z = [self.reparameterize(z_mu, z_var)]
# Normalizing flows
for k in range(self.num_flows):
flow_k = getattr(self, "flow_" + str(k))
q_k = q_ortho[k]
z_k, log_det_jacobian = flow_k(
z[k], r1[:, :, :, k], r2[:, :, :, k], q_k, b[:, :, :, k], sum_ldj=True
)
z.append(z_k)
self.log_det_j += log_det_jacobian
x_decoded = self.decode(z[-1])
return x_decoded, z_mu, z_var, self.log_det_j, z[0], z[-1]
class TriangularSylvester_ConvVAE(ConvVAE):
"""
Variational auto-encoder with triangular sylvester flows in the decoder.
"""
def __init__(self, in_shape, z_dim, *args, **kwargs):
super().__init__(in_shape, z_dim, *args, **kwargs)
# Initialize log-det-jacobian to zero
self.log_det_j = 0
# Flow parameters
flow = flows.TriangularSylvester
self.num_flows = 4 # args.num_flows
        # Indices that reverse z, i.e. Q = P (a flip permutation), used on every other flow
flip_idx = torch.arange(self.z_size - 1, -1, -1).long()
self.register_buffer("flip_idx", flip_idx)
# Masks needed for triangular r1 and r2.
triu_mask = torch.triu(torch.ones(self.z_size, self.z_size), diagonal=1)
triu_mask = triu_mask.unsqueeze(0).unsqueeze(3)
diag_idx = torch.arange(0, self.z_size).long()
self.register_buffer("triu_mask", Variable(triu_mask))
self.triu_mask.requires_grad = False
self.register_buffer("diag_idx", diag_idx)
# Amortized flow parameters
        # Diagonal elements of r1 and r2 must satisfy diag(r1) * diag(r2) > -1
        # for the flow to be invertible
self.diag_activation = nn.Tanh()
self.amor_d = nn.Linear(
self.q_z_mid_dim, self.num_flows * self.z_size * self.z_size
)
self.amor_diag1 = nn.Sequential(
nn.Linear(self.q_z_mid_dim, self.num_flows * self.z_size),
self.diag_activation,
)
self.amor_diag2 = nn.Sequential(
nn.Linear(self.q_z_mid_dim, self.num_flows * self.z_size),
self.diag_activation,
)
self.amor_b = nn.Linear(self.q_z_mid_dim, self.num_flows * self.z_size)
# Normalizing flow layers
for k in range(self.num_flows):
flow_k = flow(self.z_size)
self.add_module("flow_" + str(k), flow_k)
def forward(self, x):
self.log_det_j = 0
out, z_mu, z_var = self.encode(x)
batch_size = x.size(0)
# Amortized r1, r2, b for all flows
full_d = self.amor_d(out)
diag1 = self.amor_diag1(out)
diag2 = self.amor_diag2(out)
        full_d = full_d.view(batch_size, self.z_size, self.z_size, self.num_flows)
        diag1 = diag1.view(batch_size, self.z_size, self.num_flows)
        diag2 = diag2.view(batch_size, self.z_size, self.num_flows)
r1 = full_d * self.triu_mask
r2 = full_d.transpose(2, 1) * self.triu_mask
r1[:, self.diag_idx, self.diag_idx, :] = diag1
r2[:, self.diag_idx, self.diag_idx, :] = diag2
b = self.amor_b(out)
# Resize flow parameters to divide over K flows
        b = b.view(batch_size, 1, self.z_size, self.num_flows)
# Sample z_0
z = [self.reparameterize(z_mu, z_var)]
# Normalizing flows
for k in range(self.num_flows):
flow_k = getattr(self, "flow_" + str(k))
if k % 2 == 1:
                # Reverse the ordering of z on every other flow
permute_z = self.flip_idx
else:
permute_z = None
z_k, log_det_jacobian = flow_k(
z[k],
r1[:, :, :, k],
r2[:, :, :, k],
b[:, :, :, k],
permute_z,
sum_ldj=True,
)
z.append(z_k)
self.log_det_j += log_det_jacobian
x_decoded = self.decode(z[-1])
return x_decoded, z_mu, z_var, self.log_det_j, z[0], z[-1]
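# Note (added): alternating between the identity ordering and the flip permutation on
# successive flows lets a stack of purely triangular Sylvester transformations mix all
# latent coordinates, mirroring the alternating orderings of van den Berg et al. (2018).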
class IAF_ConvVAE(ConvVAE):
"""
Variational auto-encoder with inverse autoregressive flows in the decoder.
"""
def __init__(self, in_shape, z_dim, *args, **kwargs):
super().__init__(in_shape, z_dim, *args, **kwargs)
# Initialize log-det-jacobian to zero
self.log_det_j = 0
self.h_size = 60 # args.made_h_size
self.h_context = nn.Linear(self.q_z_mid_dim, self.h_size)
# Flow parameters
self.num_flows = 4
self.flow = flows.IAF(
z_size=self.z_size,
num_flows=self.num_flows,
num_hidden=1,
h_size=self.h_size,
conv2d=False,
)
def encode(self, x):
# Conv
out = self.q_z_conv(x)
self.conv_op_shape = out.shape
# Flatten
out = self.flatten(out)
# Dense
out = self.q_z_lin(out)
# Latent
mean = self.q_z_mean(out)
logvar = self.q_z_logvar(out)
# context from previous layer
h_context = self.h_context(out)
return mean, logvar, h_context
def forward(self, x):
        # mean and log-variance of z, plus the flow context
z_mu, z_var, h_context = self.encode(x)
# sample z
z_0 = self.reparameterize(z_mu, z_var)
# iaf flows
z_k, self.log_det_j = self.flow(z_0, h_context)
# decode
x_decoded = self.decode(z_k)
return x_decoded, z_mu, z_var, self.log_det_j, z_0, z_k
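# Note (added): an inverse autoregressive flow step computes z_k = sigma_k * z_{k-1}
# + mu_k, with (mu_k, sigma_k) produced by an autoregressive network conditioned on
# h_context, so log|det J| = sum_i log sigma_{k,i} (Kingma et al., 2016); the exact
# parameterization used here is defined in flows.IAF.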
class ConvFlow_ConvVAE(ConvVAE):
"""
Variational auto-encoder with convolutional flows in the decoder.
"""
def __init__(self, in_shape, z_dim, *args, **kwargs):
super().__init__(in_shape, z_dim, *args, **kwargs)
# Initialize log-det-jacobian to zero
self.log_det_j = 0
self.num_flows = 4 # args.num_flows # 6 for chan1
self.kernel_size = 7 # args.convFlow_kernel_size
self.test_mode = False
flow_k = flows.CNN_Flow
# Normalizing flow layers
self.flow = flow_k(
dim=self.z_size,
cnn_layers=self.num_flows,
kernel_size=self.kernel_size,
test_mode=self.test_mode,
)
def forward(self, x):
        # mean and log-variance of z
out, z_mu, z_var = self.encode(x)
# sample z
z_0 = self.reparameterize(z_mu, z_var)
        # Normalizing flows; store the log-det-Jacobian so the value returned
        # below is not stuck at its initial value of zero
        z_k, self.log_det_j = self.flow(z_0)
# decode
x_decoded = self.decode(z_k)
return x_decoded, z_mu, z_var, self.log_det_j, z_0, z_k
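# Note (added): flows.CNN_Flow applies num_flows convolutional coupling layers of
# kernel size kernel_size across the z_size latent vector; see flows.CNN_Flow for the
# exact transformation and its log-det-Jacobian.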
class NSFAR_ConvVAE(ConvVAE):
"""
Variational auto-encoder with auto-regressive neural spline flows in the decoder.
"""
def __init__(self, in_shape, z_dim, *args, **kwargs):
super().__init__(in_shape, z_dim, *args, **kwargs)
self.log_det_j = 0
self.dim = self.z_size
self.num_flows = 4 # args.num_flows
flow = flows.NSF_AR
# Normalizing flow layers
for k in range(self.num_flows):
flow_k = flow(dim=self.dim)
self.add_module("flow_" + str(k), flow_k)
    def forward(self, x):
        # Reset the log-det-Jacobian, which is accumulated below, on every call
        self.log_det_j = 0
        # mean and log-variance of z
        out, z_mu, z_var = self.encode(x)
# sample z
z = [self.reparameterize(z_mu, z_var)]
# Normalizing flows
for k in range(self.num_flows):
flow_k = getattr(self, "flow_" + str(k))
z_k, log_det_jacobian = flow_k(z[k])
z.append(z_k)
self.log_det_j += log_det_jacobian
# decode
x_decoded = self.decode(z[-1])
return x_decoded, z_mu, z_var, self.log_det_j, z[0], z[-1]
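# Usage sketch (added; not part of the original module, shapes are illustrative):
#
#   model = NSFAR_ConvVAE(in_shape=(16, 1, 28, 4), z_dim=10)
#   recon, z_mu, z_var, ldj, z0, zk = model(torch.rand(16, 1, 28, 4))
#
# Each flows.NSF_AR step warps z coordinate-by-coordinate with a monotonic
# rational-quadratic spline whose parameters depend autoregressively on the
# preceding coordinates (Durkan et al., 2019).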