1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
class NodeModelBase(nn.Module):
    """
    Base class for models that update node features from node features and
    edge weights.

    Note: the non-linear aggregation uses the 'add' mode.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        in_edgedim (int, optional): dimension of the input edge features
        deg_norm (str, optional): method used to compute the node degree
            normalization constants. Choose from ['none', 'sm', 'rw'].
        edge_gate (str, optional): method of applying the edge gating
            mechanism. Choose from ['none', 'proj', 'free'].
            Note: when set to 'free', `num_edges` must also be provided as a
            keyword argument (but then it can only work with a fixed edge
            graph).
        aggr (str, optional): message passing aggregation method, one of
            ['add', 'mean', 'max']. Defaults to 'add'.
        **kwargs: could include `num_edges`, etc.

    Input:
        - x (torch.Tensor): node feature matrix (N, C_in)
        - edge_index (torch.LongTensor): edge indices in COO format, (2, E)
        - edge_attr (torch.Tensor, optional): edge feature matrix (E, D_in)

    Output:
        - xo (torch.Tensor): updated node features (N, C_out)

    where
        N: number of input nodes
        E: number of edges
        C_in/C_out: dimension of the input/output node features
        D_in: dimension of the input edge features
    """

    def __init__(self, in_channels, out_channels, in_edgedim=None, deg_norm='none', edge_gate='none', aggr='add',
                 *args, **kwargs):
        assert deg_norm in ['none', 'sm', 'rw']
        assert edge_gate in ['none', 'proj', 'free']
        assert aggr in ['add', 'mean', 'max']

        super(NodeModelBase, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_edgedim = in_edgedim
        self.deg_norm = deg_norm
        self.aggr = aggr

        if edge_gate == 'proj':
            self.edge_gate = EdgeGateProj(out_channels, in_edgedim=in_edgedim, bias=True)
        elif edge_gate == 'free':
            # A 'free' gate is sized by the number of edges, so it must be given.
            assert 'num_edges' in kwargs
            self.edge_gate = EdgeGateFree(kwargs['num_edges'])
        else:
            # Register the name so that `self.edge_gate` resolves to None
            # instead of raising AttributeError.
            self.register_parameter('edge_gate', None)

    @staticmethod
    def degnorm_const(edge_index=None, num_nodes=None, deg=None, edge_weight=None, method='sm', device=None):
        """
        Calculate the degree normalization constants for a graph.

        Degrees are accumulated over the first row of `edge_index`.
        'sm' uses symmetric normalization (deg^-1/2 on both edge endpoints);
        better suited to undirected graphs.
        'rw' uses random walk normalization (deg^-1, i.e. a mean); better
        suited to directed graphs.

        Procedure:
            - If `edge_weight` is not None, `edge_index` and `num_nodes` must
              both be provided, and the degrees of all nodes are computed
              from the edge weights (any `deg` passed in is overwritten).
            - If `edge_weight` is None, check whether `deg` is given:
                - if `deg` is given, `num_nodes` is ignored, and `edge_index`
                  is only required for method 'sm' (it is needed to index
                  the per-edge endpoints);
                - if `deg` is not given, `edge_index` and `num_nodes` must be
                  provided and the degrees are computed with unit weights.

        Input:
            - edge_index (torch.Tensor): graph connectivity in COO format,
              (2, E), long
            - num_nodes (int): number of nodes
            - deg (torch.Tensor): node degrees, (N,), float
            - edge_weight (torch.Tensor): edge weights, (E,), float
            - method (str): degree normalization method, choose from
              ['sm', 'rw']
            - device (str or torch.device): device on which the unit edge
              weights are created; defaults to the device of `edge_index`

        Output:
            - norm (torch.Tensor): normalization constants based on the node
              degrees and edge weights.
              If `method` == 'sm', size (E,);
              if `method` == 'rw' and `edge_weight` != None, size (E,);
              if `method` == 'rw' and `edge_weight` == None, size (N,).

        where
            N: number of nodes
            E: number of edges
        """
        assert method in ['sm', 'rw']

        if device is None and edge_index is not None:
            device = edge_index.device

        if edge_weight is not None:
            assert edge_index is not None, 'edge_index must be provided when edge_weight is not None'
            assert num_nodes is not None, 'num_nodes must be provided when edge_weight is not None'

            edge_weight = edge_weight.view(-1)
            assert edge_weight.size(0) == edge_index.size(1)
            calculate_deg = True
            edge_weight_equal = False
        else:
            if deg is None:
                assert edge_index is not None, 'edge_index must be provided when edge_weight is None ' \
                                               'but deg not provided'
                assert num_nodes is not None, 'num_nodes must be provided when edge_weight is None ' \
                                              'but deg not provided'
                edge_weight = torch.ones((edge_index.size(1),), device=device)
                calculate_deg = True
            else:
                calculate_deg = False
            edge_weight_equal = True

        # The edge list is only needed to accumulate degrees or to build
        # per-edge constants. The 'rw' + precomputed-deg + unweighted case
        # returns a per-node tensor and must work without `edge_index`.
        needs_edges = calculate_deg or method == 'sm' or not edge_weight_equal
        row = col = None
        if needs_edges:
            assert edge_index is not None, "edge_index must be provided when method is 'sm'"
            row, col = edge_index

        if calculate_deg:
            deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)

        # 'sm' -> deg^-1/2, 'rw' -> deg^-1
        exponent = -0.5 if method == 'sm' else -1
        deg_inv = deg.pow(exponent)
        # Zero-degree nodes would yield inf; they contribute nothing instead.
        deg_inv[deg_inv == float('inf')] = 0

        if method == 'sm':
            if edge_weight_equal:
                norm = deg_inv[row] * deg_inv[col]
            else:
                norm = deg_inv[row] * edge_weight * deg_inv[col]
        else:
            if edge_weight_equal:
                # Per-node constants, size (N,).
                norm = deg_inv
            else:
                norm = deg_inv[row] * edge_weight

        return norm

    def forward(self, x, edge_index, edge_attr=None, deg=None, edge_weight=None, *args, **kwargs):
        """Identity placeholder: return the node features `x` unchanged."""
        return x

    def num_parameters(self):
        """
        Return the total number of elements over all parameters.

        The count is computed on the first call and cached in `self.num_para`;
        later calls return the cached value.
        """
        if not hasattr(self, 'num_para'):
            self.num_para = sum(p.nelement() for p in self.parameters())
        return self.num_para

    def __repr__(self):
        """Return a one-line summary of the configuration and parameter count."""
        return '{} (in_channels: {}, out_channels: {}, in_edgedim: {}, deg_norm: {}, edge_gate: {}, ' \
               'aggr: {} | number of parameters: {})'.format(
                   self.__class__.__name__, self.in_channels, self.out_channels, self.in_edgedim,
                   self.deg_norm, self.edge_gate.__class__.__name__, self.aggr, self.num_parameters())
|