Source code for sgl.operators.message_op.simple_weighted_message_op

import torch
from torch import Tensor

from sgl.operators.base_op import MessageOp
from sgl.operators.utils import one_dim_weighted_add


class SimpleWeightedMessageOp(MessageOp):
    """Message operator that combines a list of features with fixed weights.

    The weighting scheme is selected by ``combination_type`` and takes exactly
    one extra positional argument:

    * ``"alpha"`` needs ``alpha``, a ``float`` in ``[0, 1]``. The weights are
      generated at combine time as ``w_0 = alpha`` and
      ``w_k = (1 - alpha) * w_{k-1}``.
    * ``"hand_crafted"`` needs ``weight_list``, a ``list`` or a
      ``torch.Tensor``. A list is converted to a ``torch.FloatTensor``; a
      tensor is stored as given.

    Args:
        start: Forwarded to ``MessageOp.__init__``; ``_combine`` uses
            ``self._start`` as the lower slice bound on the feature list
            (presumably set by the base class from this value -- confirm in
            ``MessageOp``).
        end: Forwarded to ``MessageOp.__init__``; ``_combine`` uses
            ``self._end`` as the upper slice bound (same caveat as ``start``).
        combination_type: Either ``"alpha"`` or ``"hand_crafted"``.
        *args: Exactly one value, ``alpha`` or ``weight_list`` as described
            above.

    Raises:
        ValueError: If ``combination_type`` is not recognised, if the number
            of extra arguments is not one, or if ``alpha`` lies outside
            ``[0, 1]``.
        TypeError: If ``alpha`` is not a ``float``, or if ``weight_list`` is
            neither a ``list`` nor a ``torch.Tensor``.
    """

    def __init__(self, start, end, combination_type, *args):
        super().__init__(start, end)
        self._aggr_type = "simple_weighted"

        if combination_type not in ["alpha", "hand_crafted"]:
            raise ValueError(
                "Invalid weighted combination type! Type must be 'alpha' or 'hand_crafted'.")
        self.__combination_type = combination_type

        if len(args) != 1:
            raise ValueError(
                "Invalid parameter numbers for the simple weighted aggregator!")

        self.__alpha, self.__weight_list = None, None
        if combination_type == "alpha":
            self.__alpha = args[0]
            # Strict float check: integer inputs such as 0 or 1 are rejected.
            if not isinstance(self.__alpha, float):
                raise TypeError("The alpha must be a float!")
            if self.__alpha > 1 or self.__alpha < 0:
                raise ValueError("The alpha must be a float in [0,1]!")
        else:
            # Only "hand_crafted" can reach this branch after the check above.
            self.__weight_list = args[0]
            if isinstance(self.__weight_list, list):
                self.__weight_list = torch.FloatTensor(self.__weight_list)
            elif not isinstance(self.__weight_list, Tensor):
                raise TypeError(
                    "The input weight list must be a list or a tensor!")

    def _combine(self, feat_list):
        """Return the weighted combination of ``feat_list[start:end]``.

        Args:
            feat_list: Sequence of features; only the ``[self._start:self._end]``
                slice is passed on to ``one_dim_weighted_add``.

        Returns:
            Whatever ``one_dim_weighted_add`` returns for the sliced features
            and the selected weights.

        Raises:
            NotImplementedError: If the stored combination type is neither
                ``"alpha"`` nor ``"hand_crafted"``.
        """
        if self.__combination_type == "alpha":
            # Build the weights in a local so that repeated calls never leave
            # generated values behind on the instance.
            # The recurrence is kept iterative (rather than a closed-form
            # power) so the floating-point results are bit-identical.
            decayed = [self.__alpha]
            for _ in range(len(feat_list) - 1):
                decayed.append((1 - self.__alpha) * decayed[-1])
            weight_list = torch.FloatTensor(decayed[self._start:self._end])
        elif self.__combination_type == "hand_crafted":
            # NOTE(review): unlike the alpha branch, the hand-crafted weights
            # are passed through unsliced; presumably the caller supplies one
            # weight per selected feature -- confirm against
            # one_dim_weighted_add.
            weight_list = self.__weight_list
        else:
            raise NotImplementedError

        return one_dim_weighted_add(
            feat_list[self._start:self._end], weight_list=weight_list)