Autoconstrain Model

ju_young 2023. 5. 24. 18:28

Task Description

  • partner prediction: given a graph representing the sketch's primitives and constraints, predict which node the current node should connect to in order to form a new constraint
  • constraint label prediction: given the result of partner prediction and the current constraint's target partner, predict the type of the current constraint

Model Description

The model is divided into the following three parts.

  1. input representation: embedding the features from the primitives and constraints
  2. message passing: transforming these features using the graph structure
  3. readout: outputting probabilities for the specific tasks from the transformed features

1. input representation

  • input: the primitive type (`feature` in DenseSparsePreEmbedding below)
  1. an embedding of all parameters (self.fixed_embedding)
  2. the mean of the sparse-parameter embeddings (NumericalFeatureEncoding -> NumericalFeaturesEmbedding)
  3. the two are concatenated and projected through a dense (linear) layer (self.dense_merge)
DenseSparsePreEmbedding(
        target.TargetType, {
            k.name: torch.nn.Sequential(
                NumericalFeatureEncoding(fd.values(), embedding_dim),
                NumericalFeaturesEmbedding(embedding_dim)
            )
            for k, fd in feature_dims.items()
        },
        len(target.NODE_TYPES), embedding_dim)

class DenseSparsePreEmbedding(torch.nn.Module):
    def __init__(self, target_type, feature_embeddings, fixed_embedding_cardinality, fixed_embedding_dim,
                 sparse_embedding_dim=None, embedding_dim=None):
        super(DenseSparsePreEmbedding, self).__init__()
        ...

        self.feature_embeddings = torch.nn.ModuleDict(feature_embeddings)
        self.fixed_embedding = torch.nn.Embedding(fixed_embedding_cardinality, fixed_embedding_dim)
        self.dense_merge = sg_nn.ConcatenateLinear(fixed_embedding_dim, sparse_embedding_dim, embedding_dim)
        
class NumericalFeatureEncoding(torch.nn.Module):
    def __init__(self, feature_dims, embedding_dim):
        super(NumericalFeatureEncoding, self).__init__()
        ...
        self.register_buffer(
            'feature_offsets',
            torch.cumsum(torch.tensor([0] + self.feature_dims[:-1], dtype=torch.int64), dim=0))
        self.embeddings = torch.nn.Embedding(
            sum(feature_dims), embedding_dim, sparse=False)

    def forward(self, features):
        return self.embeddings(features + self.feature_offsets)


class NumericalFeaturesEmbedding(torch.nn.Module):
    def __init__(self, embedding_dim):
        super(NumericalFeaturesEmbedding, self).__init__()
        self.embedding_dim = embedding_dim

    def forward(self, embeddings):
        return embeddings.mean(axis=-2)
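
To make the offset trick concrete, here is a minimal self-contained sketch (the feature dims and index values are made up, not from the repo): each feature's local index is shifted into a disjoint row range of one shared embedding table, and the per-primitive embeddings are then averaged over the feature axis, as NumericalFeaturesEmbedding does.

```python
import torch

# Hypothetical example: two numerical features with 4 and 6 quantization bins.
feature_dims = [4, 6]
embedding_dim = 8

# Offsets [0, 4]: feature 0 uses table rows 0..3, feature 1 uses rows 4..9
offsets = torch.cumsum(torch.tensor([0] + feature_dims[:-1]), dim=0)
table = torch.nn.Embedding(sum(feature_dims), embedding_dim)

# A batch of 3 primitives, each with two local feature indices
features = torch.tensor([[1, 5], [0, 0], [3, 2]])
per_feature = table(features + offsets)    # shape (3, 2, 8)
per_primitive = per_feature.mean(axis=-2)  # shape (3, 8)
```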

 

2. Message passing

The node embeddings produced above are first transformed by a GRU.

global_entity_embedding = torch.nn.GRU(input_size=embedding_dim, hidden_size=embedding_dim, num_layers=3, bidirectional=True)
node_pre_embedding_transformed, state = global_entity_embedding(node_pre_embedding_packed)

The result is then transformed with a message-passing network.

  • $c(u, v)$ = the previously computed constraint embedding
  • $m_u^{(0)}$ = the previously computed primitive embedding
  • $f_e$ = linear layer
  • $f_n$ = GRU cell
In the code, we can see that an edge embedding is added as well.

edge_embedding = torch.nn.Embedding(len(target.EDGE_TYPES), embedding_dim)
message_passing = sg_nn.MessagePassingNetwork(
            depth, torch.nn.GRUCell(embedding_dim, embedding_dim),
            sg_nn.ConcatenateLinear(embedding_dim, embedding_dim, embedding_dim))
edge_pre_embedding = self.edge_embedding(graph.edge_features)
node_post_embedding = self.message_passing(node_pre_embedding_graph, graph.incidence, (edge_pre_embedding,))
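
A single round of this message passing can be sketched in plain PyTorch on a hypothetical toy graph (the repo's sg_nn.MessagePassingNetwork generalizes this over depth and batching): $f_e$ builds a message from each source node and its edge feature, messages are summed per target node, and $f_n$ (a GRU cell) updates each node state.

```python
import torch

embedding_dim = 8
num_nodes, num_edges = 4, 5

f_e = torch.nn.Linear(2 * embedding_dim, embedding_dim)  # message function
f_n = torch.nn.GRUCell(embedding_dim, embedding_dim)     # node update function

node_emb = torch.randn(num_nodes, embedding_dim)
edge_emb = torch.randn(num_edges, embedding_dim)
# Incidence of a made-up graph: source and target node index of each edge
src = torch.tensor([0, 1, 2, 3, 0])
dst = torch.tensor([1, 2, 3, 0, 2])

# One round: message along each edge from source node embedding + edge feature
messages = f_e(torch.cat((node_emb[src], edge_emb), dim=-1))
aggregated = torch.zeros_like(node_emb).index_add_(0, dst, messages)  # sum per target
node_emb = f_n(aggregated, node_emb)  # GRU update: input = messages, hidden = old state
```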

The graph's global representation is computed by concatenating the last hidden state (`state`) of the GRU above with a pooled transform of the `node_post_embedding` from message passing, then passing the result through one more linear layer. In the code, the node embedding goes not only through a linear layer + Sigmoid (a gate) but also through a plain linear layer; the two results are multiplied, and average pooling over each graph's nodes is applied before the concatenation.

class GraphPostEmbedding(torch.nn.Module):
    def __init__(self, hidden_size, graph_embedding_size=None):
        super(GraphPostEmbedding, self).__init__()

        if graph_embedding_size is None:
            graph_embedding_size = 2 * hidden_size

        self.node_gating_net = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 1),
            torch.nn.Sigmoid()
        )
        self.node_to_graph_net = torch.nn.Linear(hidden_size, graph_embedding_size)

    def forward(self, node_embedding, graph):
        scopes = graph_model.scopes_from_offsets(graph.node_offsets)

        transformed_embedding = self.node_gating_net(node_embedding) * self.node_to_graph_net(node_embedding)

        graph_embedding = sg_nn.functional.segment_avg_pool1d(
            transformed_embedding, scopes) * graph.node_counts.unsqueeze(-1)

        return graph_embedding
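
Note that segment_avg_pool1d followed by multiplication with node_counts amounts to a per-graph sum; for a single graph the gated pooling reduces to a plain sum over its nodes, which can be sketched with toy sizes (not from the repo):

```python
import torch

hidden_size, graph_embedding_size = 8, 16
node_emb = torch.randn(5, hidden_size)  # 5 nodes of a single graph

# Gate (linear + Sigmoid) and projection, as in GraphPostEmbedding
gate = torch.nn.Sequential(torch.nn.Linear(hidden_size, 1), torch.nn.Sigmoid())
proj = torch.nn.Linear(hidden_size, graph_embedding_size)

# Gated sum over the graph's nodes (avg pool * node count == sum)
graph_emb = (gate(node_emb) * proj(node_emb)).sum(dim=0)
```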

graph_post_embedding = message_passing.GraphPostEmbedding(embedding_dim)
merge_global_embedding = sg_nn.ConcatenateLinear(2 * embedding_dim, 2 * embedding_dim, embedding_dim)

global_embedding = torch.flatten(torch.transpose(
                state.view(3, 2, -1, self.embedding_dim)[-1], 0, 1),
                start_dim=1)
graph_post_embedding = self.graph_post_embedding(node_post_embedding, graph)

merged_global_embedding = self.merge_global_embedding(global_embedding, graph_post_embedding)
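
The `state.view(...)` reshaping above just picks the last layer's forward and backward hidden states of the 3-layer bidirectional GRU and flattens them per graph into a `2 * embedding_dim` vector; a toy-shape sketch:

```python
import torch

embedding_dim = 4
num_layers, num_dirs, batch = 3, 2, 2

# state of a bidirectional 3-layer GRU: (num_layers * num_dirs, batch, dim)
state = torch.randn(num_layers * num_dirs, batch, embedding_dim)

last = state.view(num_layers, num_dirs, batch, embedding_dim)[-1]     # (2, batch, dim)
global_emb = torch.flatten(torch.transpose(last, 0, 1), start_dim=1)  # (batch, 2 * dim)
```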

3. Readout

Readout is the stage that predicts the partner.

  • the graph_post_embedding computed above (node_post_embedding in the code below) and the merged_global_embedding are concatenated; node_post_embedding also defines the target_embedding
  • the concatenation passes through a two-layer fully-connected network to produce the partner logits
class EdgePartnerNetwork(torch.nn.Module):
    def __init__(self, readout_net):
        super(EdgePartnerNetwork, self).__init__()
        self.readout_net = readout_net

    def forward(self, node_embedding, graph_embedding, graph):
        target_idx = graph.node_offsets[1:] - 1

        target_embeddings = (node_embedding
                             .index_select(0, target_idx)
                             .repeat_interleave(graph.node_counts, 0))

        graph_embedding = (graph_embedding
                           .repeat_interleave(graph.node_counts, 0))

        edge_partner_input = torch.cat((node_embedding, target_embeddings, graph_embedding), dim=-1)
        logits = self.readout_net(edge_partner_input).squeeze(-1)
        return logits

edge_partner_network = EdgePartnerNetwork(
            torch.nn.Sequential(
                torch.nn.Linear(3 * embedding_dim, embedding_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(embedding_dim, 1)))
edge_partner_logits = self.edge_partner_network(
                node_post_embedding, merged_global_embedding, graph)
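The index_select / repeat_interleave bookkeeping in EdgePartnerNetwork can be traced on a made-up batch of two graphs: the current (last) node of each graph and the graph's global embedding are broadcast against every candidate node of that graph, yielding one logit per candidate partner.

```python
import torch

embedding_dim = 4
# Two hypothetical graphs with 3 and 2 nodes; node rows are concatenated
node_counts = torch.tensor([3, 2])
node_offsets = torch.tensor([0, 3, 5])
node_emb = torch.randn(5, embedding_dim)
graph_emb = torch.randn(2, embedding_dim)

target_idx = node_offsets[1:] - 1  # last node of each graph: [2, 4]
target_emb = node_emb.index_select(0, target_idx).repeat_interleave(node_counts, 0)
graph_emb_rep = graph_emb.repeat_interleave(node_counts, 0)

readout = torch.nn.Sequential(
    torch.nn.Linear(3 * embedding_dim, embedding_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(embedding_dim, 1))
logits = readout(torch.cat((node_emb, target_emb, graph_emb_rep), dim=-1)).squeeze(-1)
# one score per candidate partner node across both graphs
```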
  • node_post_embedding, the embedding of the corresponding partner node, and the global representation are concatenated
  • the concatenation passes through a three-layer fully-connected network to produce logits for predicting the edge type
edge_label = torch.nn.Sequential(
            torch.nn.Linear(3 * embedding_dim, embedding_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(embedding_dim, embedding_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(embedding_dim, len(target.EDGE_TYPES_PREDICTED)))

node_current_post_embedding_label = node_post_embedding.index_select(
            0, graph.node_offsets[1:][partner_index.index] - 1)
node_partner_post_embedding_label = node_post_embedding.index_select(0, partner_index.values)
merged_global_embedding_label = global_embedding.index_select(0, partner_index.index)

edge_label_input = torch.cat(
    (node_current_post_embedding_label, node_partner_post_embedding_label, merged_global_embedding_label), dim=-1)

edge_label(edge_label_input)

Finally, the logits for predicting the partner and the logits for predicting the edge type are returned.

return {
            'edge_partner_logits': edge_partner_logits,
            'edge_label_logits': edge_label_logits
        }

 

[reference]

https://github.com/PrincetonLIPS/SketchGraphs
