"""
Collection of flow strategies
"""
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import math
import sys
from .layers import (
    FCNN,
    MaskedConv2d,
    MaskedLinear,
    CNN_Flow_Layer,
    Dilation_Block,
    unconstrained_RQS,
)

class Planar(nn.Module):
    """
    Planar normalizing flow [Rezende & Mohamed, 2015].
    """

    def __init__(self):
        super(Planar, self).__init__()
        self.h = nn.Tanh()
        self.softplus = nn.Softplus()
    def der_h(self, x):
        """Derivative of tanh."""
        return 1 - self.h(x) ** 2
    def forward(self, zk, u, w, b):
        zk = zk.unsqueeze(2)
        # reparameterize u as u_hat so that w^T u_hat >= -1, which makes the flow invertible
        uw = torch.bmm(w, u)
        m_uw = -1.0 + self.softplus(uw)
        w_norm_sq = torch.sum(w**2, dim=2, keepdim=True)
        u_hat = u + ((m_uw - uw) * w.transpose(2, 1) / w_norm_sq)
        # compute flow with u_hat: z = zk + u_hat * tanh(w^T zk + b)
        wzb = torch.bmm(w, zk) + b
        z = zk + u_hat * self.h(wzb)
        z = z.squeeze(2)
        # compute log|det J| = log|1 + psi(zk)^T u_hat| with psi(zk) = tanh'(w^T zk + b) * w
        psi = w * self.der_h(wzb)
        log_det_jacobian = torch.log(torch.abs(1 + torch.bmm(psi, u_hat)))
        log_det_jacobian = log_det_jacobian.squeeze(2).squeeze(1)
        return z, log_det_jacobian
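
# Example usage for Planar (a sketch; the batch size and z_size are illustrative,
# and u, w, b would normally be amortized by an encoder rather than sampled randomly):
#   flow = Planar()
#   batch, z_size = 32, 10
#   zk = torch.randn(batch, z_size)       # current latent sample
#   u = torch.randn(batch, z_size, 1)
#   w = torch.randn(batch, 1, z_size)
#   b = torch.randn(batch, 1, 1)
#   z_new, log_det_j = flow(zk, u, w, b)  # z_new: (batch, z_size), log_det_j: (batch,)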

class Sylvester(nn.Module):
    """
    Sylvester normalizing flow [van den Berg et al., 2018].
    """

    def __init__(self, num_ortho_vecs):
        super(Sylvester, self).__init__()
        self.num_ortho_vecs = num_ortho_vecs
        self.h = nn.Tanh()
        triu_mask = torch.triu(
            torch.ones(num_ortho_vecs, num_ortho_vecs), diagonal=1
        ).unsqueeze(0)
        diag_idx = torch.arange(0, num_ortho_vecs).long()
        self.register_buffer("triu_mask", Variable(triu_mask))
        self.triu_mask.requires_grad = False
        self.register_buffer("diag_idx", diag_idx)
    def der_h(self, x):
        return self.der_tanh(x)
    def der_tanh(self, x):
        return 1 - self.h(x) ** 2

    def _forward(self, zk, r1, r2, q_ortho, b, sum_ldj=True):
        """
        Computes z' = zk + Q R1 h(R2 Q^T zk + b) and the corresponding log|det J|.
        All flow parameters are amortized (one set per batch element).
        """
        zk = zk.unsqueeze(1)
        # save diagonals of R1 and R2 for log_det_j
        diag_r1 = r1[:, self.diag_idx, self.diag_idx]
        diag_r2 = r2[:, self.diag_idx, self.diag_idx]
        r1_hat = r1
        r2_hat = r2
        qr2 = torch.bmm(q_ortho, r2_hat.transpose(2, 1))
        qr1 = torch.bmm(q_ortho, r1_hat)
        r2qzb = torch.bmm(zk, qr2) + b
        z = torch.bmm(self.h(r2qzb), qr1.transpose(2, 1)) + zk
        z = z.squeeze(1)
        # compute log|det J|; return log_det_j with shape (batch_size,) instead of (batch_size, 1)
        diag_j = diag_r1 * diag_r2
        diag_j = self.der_h(r2qzb).squeeze(1) * diag_j
        diag_j += 1.0
        log_diag_j = diag_j.abs().log()
        if sum_ldj:
            log_det_j = log_diag_j.sum(-1)
        else:
            log_det_j = log_diag_j
        return z, log_det_j
    def forward(self, zk, r1, r2, q_ortho, b, sum_ldj=True):
        return self._forward(zk, r1, r2, q_ortho, b, sum_ldj)
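
# Example usage for Sylvester (a sketch; in practice r1/r2 are upper triangular with
# constrained diagonals and q_ortho has orthonormal columns, all amortized by an encoder):
#   m, batch, z_size = 8, 32, 10
#   flow = Sylvester(num_ortho_vecs=m)
#   zk = torch.randn(batch, z_size)
#   r1 = torch.randn(batch, m, m).triu()
#   r2 = torch.randn(batch, m, m).triu()
#   q_ortho = torch.randn(batch, z_size, m)
#   b = torch.randn(batch, 1, m)
#   z_new, log_det_j = flow(zk, r1, r2, q_ortho, b)  # z_new: (batch, z_size), log_det_j: (batch,)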

class TriangularSylvester(nn.Module):
    """
    Sylvester normalizing flow with Q=P or Q=I.
    """

    def __init__(self, z_size):
        super(TriangularSylvester, self).__init__()
        self.z_size = z_size
        self.h = nn.Tanh()
        diag_idx = torch.arange(0, z_size).long()
        self.register_buffer("diag_idx", diag_idx)
    def der_h(self, x):
        return self.der_tanh(x)
    def der_tanh(self, x):
        return 1 - self.h(x) ** 2

    def _forward(self, zk, r1, r2, b, permute_z=None, sum_ldj=True):
        """
        Computes z' = zk + R1 h(R2 zk + b) and the corresponding log|det J|, optionally
        applying a fixed self-inverse permutation of zk before and after (the role of Q).
        All flow parameters are amortized.
        """
        zk = zk.unsqueeze(1)
        # save diagonals of R1 and R2 for log_det_j
        diag_r1 = r1[:, self.diag_idx, self.diag_idx]
        diag_r2 = r2[:, self.diag_idx, self.diag_idx]
        if permute_z is not None:
            # permute order of z
            z_per = zk[:, :, permute_z]
        else:
            z_per = zk
        r2qzb = torch.bmm(z_per, r2.transpose(2, 1)) + b
        z = torch.bmm(self.h(r2qzb), r1.transpose(2, 1))
        if permute_z is not None:
            # permute order of z back again
            z = z[:, :, permute_z]
        z += zk
        z = z.squeeze(1)
        # compute log|det J|; return log_det_j with shape (batch_size,) instead of (batch_size, 1)
        diag_j = diag_r1 * diag_r2
        diag_j = self.der_h(r2qzb).squeeze(1) * diag_j
        diag_j += 1.0
        log_diag_j = diag_j.abs().log()
        if sum_ldj:
            log_det_j = log_diag_j.sum(-1)
        else:
            log_det_j = log_diag_j
        return z, log_det_j
    def forward(self, zk, r1, r2, b, permute_z=None, sum_ldj=True):
        return self._forward(zk, r1, r2, b, permute_z, sum_ldj)
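
# Example usage for TriangularSylvester (a sketch; r1/r2 are upper triangular with
# constrained diagonals in practice, and permute_z is typically a flip, which is its own inverse):
#   batch, z_size = 32, 10
#   flow = TriangularSylvester(z_size=z_size)
#   zk = torch.randn(batch, z_size)
#   r1 = torch.randn(batch, z_size, z_size).triu()
#   r2 = torch.randn(batch, z_size, z_size).triu()
#   b = torch.randn(batch, 1, z_size)
#   permute_z = torch.arange(z_size - 1, -1, -1)
#   z_new, log_det_j = flow(zk, r1, r2, b, permute_z)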

class IAF(nn.Module):
    """
    Inverse autoregressive flow [Kingma et al., 2016].
    """

    def __init__(
        self,
        z_size,
        num_flows=2,
        num_hidden=0,
        h_size=50,
        forget_bias=1.0,
        conv2d=False,
    ):
        super(IAF, self).__init__()
        self.z_size = z_size
        self.num_flows = num_flows
        self.num_hidden = num_hidden
        self.h_size = h_size
        self.conv2d = conv2d
        if not conv2d:
            ar_layer = MaskedLinear
        else:
            ar_layer = MaskedConv2d
        self.activation = torch.nn.ELU
        # self.activation = torch.nn.ReLU
        self.forget_bias = forget_bias
        self.flows = []
        self.param_list = []
        # For reordering z after each flow
        flip_idx = torch.arange(self.z_size - 1, -1, -1).long()
        self.register_buffer("flip_idx", flip_idx)
        for k in range(num_flows):
            arch_z = [ar_layer(z_size, h_size), self.activation()]
            self.param_list += list(arch_z[0].parameters())
            z_feats = torch.nn.Sequential(*arch_z)
            arch_zh = []
            for j in range(num_hidden):
                arch_zh += [ar_layer(h_size, h_size), self.activation()]
                self.param_list += list(arch_zh[-2].parameters())
            zh_feats = torch.nn.Sequential(*arch_zh)
            linear_mean = ar_layer(h_size, z_size, diagonal_zeros=True)
            linear_std = ar_layer(h_size, z_size, diagonal_zeros=True)
            self.param_list += list(linear_mean.parameters())
            self.param_list += list(linear_std.parameters())
            if torch.cuda.is_available():
                z_feats = z_feats.cuda()
                zh_feats = zh_feats.cuda()
                linear_mean = linear_mean.cuda()
                linear_std = linear_std.cuda()
            self.flows.append((z_feats, zh_feats, linear_mean, linear_std))
        self.param_list = torch.nn.ParameterList(self.param_list)
    def forward(self, z, h_context):
        logdets = 0.0
        for i, flow in enumerate(self.flows):
            if (i + 1) % 2 == 0 and not self.conv2d:
                # reverse ordering to help mixing
                z = z[:, self.flip_idx]
            h = flow[0](z)
            h = h + h_context
            h = flow[1](h)
            mean = flow[2](h)
            gate = torch.sigmoid(flow[3](h) + self.forget_bias)
            z = gate * z + (1 - gate) * mean
            logdets += torch.sum(gate.log().view(gate.size(0), -1), 1)
        return z, logdets
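
# Example usage for IAF (a sketch; h_context is the amortized context from an encoder
# and must have h_size features so it can be added to the masked-layer output):
#   iaf = IAF(z_size=10, num_flows=2, num_hidden=1, h_size=50)
#   z = torch.randn(32, 10)
#   h_context = torch.randn(32, 50)
#   z_new, log_det_j = iaf(z, h_context)  # z_new: (32, 10), log_det_j: (32,)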

class CNN_Flow(nn.Module):
    """
    CNN-based normalizing flow built from stacked dilation blocks.
    """

    def __init__(self, dim, cnn_layers, kernel_size, test_mode=0, use_revert=True):
        super(CNN_Flow, self).__init__()
        # prepare reversion (flip) matrix used to reorder dimensions between blocks
        self.usecuda = False
        self.use_revert = use_revert
        self.R = Variable(
            torch.from_numpy(np.flip(np.eye(dim), axis=1).copy()).float(),
            requires_grad=False,
        )
        if self.usecuda:
            self.R = self.R.cuda()
        self.layers = nn.ModuleList()
        for i in range(cnn_layers):
            block = Dilation_Block(dim, kernel_size, test_mode)
            self.layers.append(block)
    def forward(self, x):
        logdetSum = 0
        output = x
        for i in range(len(self.layers)):
            output, logdet = self.layers[i](output)
            # revert (flip) the dimension ordering of the output after each block
            if self.use_revert:
                output = output.mm(self.R)
            logdetSum += logdet
        return output, logdetSum
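
# Example usage for CNN_Flow (a sketch; the input is assumed to be a flat (batch, dim)
# tensor, and cnn_layers / kernel_size are illustrative values passed through to Dilation_Block):
#   flow = CNN_Flow(dim=16, cnn_layers=2, kernel_size=5)
#   x = torch.randn(32, 16)
#   z, log_det = flow(x)  # z: (32, 16), log_det summed over all blocks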

class NSF_AR(nn.Module):
    """
    Neural spline flow, auto-regressive.
    [Durkan et al. 2019]
    """

    def __init__(self, dim=15, K=64, B=3, hidden_dim=8, base_network=FCNN):
        super().__init__()
        self.dim = dim
        self.K = K
        self.B = B
        self.layers = nn.ModuleList()
        # 3K - 1 parameters per dimension: K bin widths, K bin heights, K - 1 knot derivatives
        self.init_param = nn.Parameter(torch.Tensor(3 * K - 1))
        for i in range(1, dim):
            self.layers += [base_network(i, 3 * K - 1, hidden_dim)]
        self.reset_parameters()
    def reset_parameters(self):
        init.uniform_(self.init_param, -1 / 2, 1 / 2)
    def forward(self, x):
        z = torch.zeros_like(x)
        logdets = 0
        for i in range(self.dim):
            if i == 0:
                init_param = self.init_param.expand(x.shape[0], 3 * self.K - 1)
                W, H, D = torch.split(init_param, self.K, dim=1)
            else:
                out = self.layers[i - 1](x[:, :i])
                W, H, D = torch.split(out, self.K, dim=1)
            W, H = torch.softmax(W, dim=1), torch.softmax(H, dim=1)
            W, H = 2 * self.B * W, 2 * self.B * H
            D = F.softplus(D)
            z[:, i], ld = unconstrained_RQS(
                x[:, i], W, H, D, inverse=False, tail_bound=self.B
            )
            logdets += ld
        return z, logdets
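
# Example usage for NSF_AR (a sketch; inputs far outside [-B, B] fall in the identity tails
# of the unconstrained rational-quadratic spline):
#   nsf = NSF_AR(dim=15, K=64, B=3, hidden_dim=8)
#   x = torch.randn(32, 15)
#   z, log_det = nsf(x)  # z: (32, 15), log_det: (32,)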