# Source code for sgl.operators.message_op.over_smooth_distance_op

import torch
import torch.nn.functional as F

from sgl.operators.base_op import MessageOp

class OverSmoothDistanceWeightedOp(MessageOp):
    """Message operator that mixes per-hop features with similarity weights.

    For every row, each hop's feature vector is scored by its cosine
    similarity to the corresponding row of the first entry of the feature
    list. The scores are softmax-normalised across hops and used as the
    coefficients of a weighted sum of the hop features.
    """

    def __init__(self):
        super().__init__()
        self._aggr_type = 'over_smooth_dis_weighted'

    def _combine(self, feat_list):
        """Return the similarity-weighted sum of the tensors in ``feat_list``.

        Args:
            feat_list: Non-empty sequence of tensors, one per hop, all with
                the same shape. The code reduces over dimension 1 and treats
                dimension 0 as the node axis, so 2-D ``(num_nodes, feat_dim)``
                inputs are assumed -- TODO confirm against callers.

        Returns:
            A tensor shaped like ``feat_list[0]`` whose row ``i`` is
            ``sum_j weight[i, j] * feat_list[j][i]``, where ``weight[i, :]``
            is the softmax over hops of the cosine similarity between
            ``feat_list[0][i]`` and ``feat_list[j][i]``.

        Raises:
            IndexError: If ``feat_list`` is empty.
        """
        features = feat_list[0]
        # A small constant is added to every norm so all-zero rows do not
        # cause a division by zero.
        norm_fea = torch.norm(features, 2, 1).add(1e-10)

        weight_list = []
        for fea in feat_list:
            norm_cur = torch.norm(fea, 2, 1).add(1e-10)
            # Row-wise dot product divided by both norms: cosine similarity.
            tmp = torch.div((features * fea).sum(1), norm_cur)
            tmp = torch.div(tmp, norm_fea)
            weight_list.append(tmp.unsqueeze(-1))
        # Shape (num_nodes, hops); each row sums to one.
        weight = F.softmax(torch.cat(weight_list, dim=1), dim=1)

        # Accumulate hop by hop with a broadcasted column of weights instead
        # of looping over individual nodes. The summation order over hops is
        # unchanged, and an input with zero nodes yields an empty tensor
        # rather than failing on an empty concatenation.
        output = torch.zeros_like(features)
        for hop, fea in enumerate(feat_list):
            output = output + weight[:, hop].unsqueeze(-1) * fea
        return output