| author | ZHENG Yanqin <[email protected]> | 2023-05-25 07:37:53 +0000 |
|---|---|---|
| committer | ZHENG Yanqin <[email protected]> | 2023-05-25 07:37:53 +0000 |
| commit | e9896bd62bb29da00ec00a121374167ad91bfe47 (patch) | |
| tree | d94845574c8ef7473d0204d28b4efd4038035463 /method/AnomalyTransformer.py | |
| parent | fad9aa875c84b38cbb5a6010e104922b1eea7291 (diff) | |
| parent | 4c5734c624705449c6b21c4b2bc5554e7259fdba (diff) | |
readme
See merge request zyq/time_series_anomaly_detection!1
Diffstat (limited to 'method/AnomalyTransformer.py')
| -rw-r--r-- | method/AnomalyTransformer.py | 305 |
1 file changed, 305 insertions, 0 deletions
```diff
diff --git a/method/AnomalyTransformer.py b/method/AnomalyTransformer.py
new file mode 100644
index 0000000..9dba21a
--- /dev/null
+++ b/method/AnomalyTransformer.py
@@ -0,0 +1,305 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import math
+from math import sqrt
+import torch.utils.data as tud
+
+
+class PositionalEmbedding(nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super(PositionalEmbedding, self).__init__()
+        # Compute the positional encodings once in log space.
+        pe = torch.zeros(max_len, d_model).float()
+        pe.require_grad = False
+
+        position = torch.arange(0, max_len).float().unsqueeze(1)
+        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+
+        pe = pe.unsqueeze(0)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        return self.pe[:, :x.size(1)]
+
+
+class TokenEmbedding(nn.Module):
+    def __init__(self, c_in, d_model):
+        super(TokenEmbedding, self).__init__()
+        padding = 1 if torch.__version__ >= '1.5.0' else 2
+        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
+                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
+
+    def forward(self, x):
+        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
+        return x
+
+
+class DataEmbedding(nn.Module):
+    def __init__(self, c_in, d_model, dropout=0.0):
+        super(DataEmbedding, self).__init__()
+
+        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
+        self.position_embedding = PositionalEmbedding(d_model=d_model)
+
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(self, x):
+        x = self.value_embedding(x) + self.position_embedding(x)
+        return self.dropout(x)
+
+
+class TriangularCausalMask():
+    def __init__(self, B, L, device="cpu"):
+        mask_shape = [B, 1, L, L]
+        with torch.no_grad():
+            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
+
+    @property
+    def mask(self):
+        return self._mask
+
+
+class AnomalyAttention(nn.Module):
+    def __init__(self, win_size, mask_flag=True, scale=None, attention_dropout=0.0, output_attention=False):
+        super(AnomalyAttention, self).__init__()
+        self.scale = scale
+        self.mask_flag = mask_flag
+        self.output_attention = output_attention
+        self.dropout = nn.Dropout(attention_dropout)
+        window_size = win_size
+        self.distances = torch.zeros((window_size, window_size)).cuda()
+        for i in range(window_size):
+            for j in range(window_size):
+                self.distances[i][j] = abs(i - j)
+
+    def forward(self, queries, keys, values, sigma, attn_mask):
+        B, L, H, E = queries.shape
+        _, S, _, D = values.shape
+        scale = self.scale or 1. / sqrt(E)
+
+        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+        if self.mask_flag:
+            if attn_mask is None:
+                attn_mask = TriangularCausalMask(B, L, device=queries.device)
+            scores.masked_fill_(attn_mask.mask, -np.inf)
+        attn = scale * scores
+
+        sigma = sigma.transpose(1, 2)  # B L H -> B H L
+        window_size = attn.shape[-1]
+        sigma = torch.sigmoid(sigma * 5) + 1e-5
+        sigma = torch.pow(3, sigma) - 1
+        sigma = sigma.unsqueeze(-1).repeat(1, 1, 1, window_size)  # B H L L
+        prior = self.distances.unsqueeze(0).unsqueeze(0).repeat(sigma.shape[0], sigma.shape[1], 1, 1).cuda()
+        prior = 1.0 / (math.sqrt(2 * math.pi) * sigma) * torch.exp(-prior ** 2 / 2 / (sigma ** 2))
+
+        series = self.dropout(torch.softmax(attn, dim=-1))
+        V = torch.einsum("bhls,bshd->blhd", series, values)
+
+        if self.output_attention:
+            return (V.contiguous(), series, prior, sigma)
+        else:
+            return (V.contiguous(), None)
+
+
+class AttentionLayer(nn.Module):
+    def __init__(self, attention, d_model, n_heads, d_keys=None,
+                 d_values=None):
+        super(AttentionLayer, self).__init__()
+
+        d_keys = d_keys or (d_model // n_heads)
+        d_values = d_values or (d_model // n_heads)
+        self.norm = nn.LayerNorm(d_model)
+        self.inner_attention = attention
+        self.query_projection = nn.Linear(d_model,
+                                          d_keys * n_heads)
+        self.key_projection = nn.Linear(d_model,
+                                        d_keys * n_heads)
+        self.value_projection = nn.Linear(d_model,
+                                          d_values * n_heads)
+        self.sigma_projection = nn.Linear(d_model,
+                                          n_heads)
+        self.out_projection = nn.Linear(d_values * n_heads, d_model)
+
+        self.n_heads = n_heads
+
+    def forward(self, queries, keys, values, attn_mask):
+        B, L, _ = queries.shape
+        _, S, _ = keys.shape
+        H = self.n_heads
+        x = queries
+        queries = self.query_projection(queries).view(B, L, H, -1)
+        keys = self.key_projection(keys).view(B, S, H, -1)
+        values = self.value_projection(values).view(B, S, H, -1)
+        sigma = self.sigma_projection(x).view(B, L, H)
+
+        out, series, prior, sigma = self.inner_attention(
+            queries,
+            keys,
+            values,
+            sigma,
+            attn_mask
+        )
+        out = out.view(B, L, -1)
+
+        return self.out_projection(out), series, prior, sigma
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
+        super(EncoderLayer, self).__init__()
+        d_ff = d_ff or 4 * d_model
+        self.attention = attention
+        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(self, x, attn_mask=None):
+        new_x, attn, mask, sigma = self.attention(
+            x, x, x,
+            attn_mask=attn_mask
+        )
+        x = x + self.dropout(new_x)
+        y = x = self.norm1(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+        y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+        return self.norm2(x + y), attn, mask, sigma
+
+
+class Encoder(nn.Module):
+    def __init__(self, attn_layers, norm_layer=None):
+        super(Encoder, self).__init__()
+        self.attn_layers = nn.ModuleList(attn_layers)
+        self.norm = norm_layer
+
+    def forward(self, x, attn_mask=None):
+        # x [B, L, D]
+        series_list = []
+        prior_list = []
+        sigma_list = []
+        for attn_layer in self.attn_layers:
+            x, series, prior, sigma = attn_layer(x, attn_mask=attn_mask)
+            series_list.append(series)
+            prior_list.append(prior)
+            sigma_list.append(sigma)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x, series_list, prior_list, sigma_list
+
+
+class Model(nn.Module):
+    def __init__(self, customs: {}, dataloader: tud.DataLoader):
+        super(Model, self).__init__()
+        win_size = int(customs["input_size"])
+        enc_in = c_out = dataloader.dataset.train_inputs.shape[-1]
+        d_model = 512
+        n_heads = 8
+        e_layers = 3
+        d_ff = 512
+        dropout = 0.0
+        activation = 'gelu'
+        output_attention = True
+        self.k = 3
+        self.win_size = win_size
+
+        self.name = "AnomalyTransformer"
+        # Encoding
+        self.embedding = DataEmbedding(enc_in, d_model, dropout)
+
+        # Encoder
+        self.encoder = Encoder(
+            [
+                EncoderLayer(
+                    AttentionLayer(
+                        AnomalyAttention(win_size, False, attention_dropout=dropout, output_attention=output_attention),
+                        d_model, n_heads),
+                    d_model,
+                    d_ff,
+                    dropout=dropout,
+                    activation=activation
+                ) for l in range(e_layers)
+            ],
+            norm_layer=torch.nn.LayerNorm(d_model)
+        )
+
+        self.projection = nn.Linear(d_model, c_out, bias=True)
+
+    def forward(self, x):
+        enc_out = self.embedding(x)
+        enc_out, series, prior, sigmas = self.encoder(enc_out)
+        enc_out = self.projection(enc_out)
+        return enc_out, series, prior, sigmas
+
+    @staticmethod
+    def my_kl_loss(p, q):
+        res = p * (torch.log(p + 0.0001) - torch.log(q + 0.0001))
+        return torch.mean(torch.sum(res, dim=-1), dim=1)
+
+    def loss(self, x, y_true, epoch: int = None, device: str = "cpu"):
+        output, series, prior, _ = self.forward(x)
+        series_loss = 0.0
+        prior_loss = 0.0
+        for u in range(len(prior)):
+            series_loss += (torch.mean(self.my_kl_loss(series[u], (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.win_size)).detach()))
+                            + torch.mean(self.my_kl_loss((prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.win_size)).detach(), series[u])))
+            prior_loss += (torch.mean(self.my_kl_loss((prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.win_size)), series[u].detach()))
+                           + torch.mean(self.my_kl_loss(series[u].detach(), (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.win_size)))))
+        series_loss = series_loss / len(prior)
+        prior_loss = prior_loss / len(prior)
+        rec_loss = nn.MSELoss()(output, x)
+
+        loss1 = rec_loss - self.k * series_loss
+        loss2 = rec_loss + self.k * prior_loss
+
+        # Minimax strategy
+        loss1.backward(retain_graph=True)
+        loss2.backward()
+
+        return loss1.item()
+
+    def detection(self, x, y_true, epoch: int = None, device: str = "cpu"):
+        temperature = 50
+        output, series, prior, _ = self.forward(x)
+
+        loss = torch.mean(nn.MSELoss()(x, output), dim=-1)
+
+        series_loss = 0.0
+        prior_loss = 0.0
+        for u in range(len(prior)):
+            if u == 0:
+                series_loss = self.my_kl_loss(series[u], (
+                        prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1,
+                                                                                               self.win_size)).detach()) * temperature
+                prior_loss = self.my_kl_loss(
+                    (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1,
+                                                                                            self.win_size)),
+                    series[u].detach()) * temperature
+            else:
+                series_loss += self.my_kl_loss(series[u], (
+                        prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1,
+                                                                                               self.win_size)).detach()) * temperature
+                prior_loss += self.my_kl_loss(
+                    (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1,
+                                                                                            self.win_size)),
+                    series[u].detach()) * temperature
+        metric = torch.softmax((-series_loss - prior_loss), dim=-1)
+
+        cri = metric * loss
+        cri = cri.mean(dim=-1)
+        return cri, None
```
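
For reviewers who want to exercise the new module, below is a minimal, hypothetical driver sketch. The `WindowDataset` class, its shapes, and the optimizer settings are illustrative assumptions and are not part of this merge request; the only interface the added code actually relies on is a `customs` dict carrying `"input_size"`, a dataloader whose dataset exposes `train_inputs` shaped `(num_windows, win_size, num_features)`, and a CUDA device (since `AnomalyAttention` calls `.cuda()` directly).

```python
import torch
import torch.utils.data as tud

from method.AnomalyTransformer import Model


class WindowDataset(tud.Dataset):
    """Toy dataset exposing `train_inputs` shaped (num_windows, win_size, num_features)."""

    def __init__(self, num_windows=64, win_size=100, num_features=25):
        self.train_inputs = torch.randn(num_windows, win_size, num_features)

    def __len__(self):
        return self.train_inputs.shape[0]

    def __getitem__(self, idx):
        return self.train_inputs[idx]


loader = tud.DataLoader(WindowDataset(), batch_size=8, shuffle=True)
# A CUDA device is required: AnomalyAttention hard-codes .cuda() on its distance matrix.
model = Model(customs={"input_size": 100}, dataloader=loader).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for batch in loader:
    batch = batch.cuda()
    optimizer.zero_grad()
    # Model.loss() runs forward() and both backward passes (minimax strategy) itself,
    # returning loss1 as a Python float.
    train_loss = model.loss(batch, y_true=None)
    optimizer.step()

# Anomaly scoring: detection() returns one score per window in the batch.
with torch.no_grad():
    scores, _ = model.detection(batch, y_true=None)
```

Note that `Model.loss()` already performs both backward passes of the minimax strategy internally, so the caller only zeroes gradients and steps the optimizer.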
