Task Description
- Partner prediction: given a graph representing the nodes and constraints of a sketch, predict which node the current node should be connected to in order to form a new constraint.
- Constraint label prediction: given the result of partner prediction and the target partner of the current constraint, predict the type of the current constraint.
Model Description
The model consists of three parts:
- input representation: embedding the features from the primitives and constraints
- message passing: transforming these features using the graph structure
- readout: outputs probabilities for the specific tasks according to the transformed features
1. Input representation
- Input: the primitive type (`feature` in the DenseSparsePreEmbedding code below)
- The node type is embedded with a fixed embedding table (self.fixed_embedding)
- The sparse parameters are embedded and then averaged (NumericalFeatureEncoding -> NumericalFeaturesEmbedding)
- The two are concatenated and projected through a dense (linear) layer (self.dense_merge)
DenseSparsePreEmbedding(
    target.TargetType, {
        k.name: torch.nn.Sequential(
            NumericalFeatureEncoding(fd.values(), embedding_dim),
            NumericalFeaturesEmbedding(embedding_dim))
        for k, fd in feature_dims.items()},
    len(target.NODE_TYPES), embedding_dim)
class DenseSparsePreEmbedding(torch.nn.Module):
    def __init__(self, target_type, feature_embeddings, fixed_embedding_cardinality,
                 fixed_embedding_dim, sparse_embedding_dim=None, embedding_dim=None):
        super(DenseSparsePreEmbedding, self).__init__()
        ...
        self.feature_embeddings = torch.nn.ModuleDict(feature_embeddings)
        self.fixed_embedding = torch.nn.Embedding(fixed_embedding_cardinality, fixed_embedding_dim)
        self.dense_merge = sg_nn.ConcatenateLinear(fixed_embedding_dim, sparse_embedding_dim, embedding_dim)
class NumericalFeatureEncoding(torch.nn.Module):
    def __init__(self, feature_dims, embedding_dim):
        super(NumericalFeatureEncoding, self).__init__()
        ...
        self.register_buffer(
            'feature_offsets',
            torch.cumsum(torch.tensor([0] + self.feature_dims[:-1], dtype=torch.int64), dim=0))
        self.embeddings = torch.nn.Embedding(
            sum(feature_dims), embedding_dim, sparse=False)

    def forward(self, features):
        return self.embeddings(features + self.feature_offsets)
class NumericalFeaturesEmbedding(torch.nn.Module):
    def __init__(self, embedding_dim):
        super(NumericalFeaturesEmbedding, self).__init__()
        self.embedding_dim = embedding_dim

    def forward(self, embeddings):
        return embeddings.mean(axis=-2)
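The offset trick above packs every categorical feature into one shared embedding table: each feature gets its own contiguous slice, selected by adding a cumulative offset to the raw feature value. A minimal self-contained sketch (toy `feature_dims` and shapes of my choosing, not the library's values):

```python
import torch

# Each feature owns a slice of one shared table: feature 0 -> rows 0..3,
# feature 1 -> rows 4..6, feature 2 -> rows 7..11.
feature_dims = [4, 3, 5]
embedding_dim = 8
offsets = torch.cumsum(torch.tensor([0] + feature_dims[:-1], dtype=torch.int64), dim=0)
table = torch.nn.Embedding(sum(feature_dims), embedding_dim)

features = torch.tensor([[2, 1, 4]])   # one entity with three feature values
encoded = table(features + offsets)    # (1, 3, 8): one vector per feature
pooled = encoded.mean(dim=-2)          # (1, 8): averaged, as NumericalFeaturesEmbedding does
```

Here `offsets` comes out as `[0, 4, 7]`, so the lookups hit rows 2, 5, and 11 of the shared table.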
2. Message passing
The node embeddings produced above are first transformed by a GRU.
node_pre_embedding_transform = torch.nn.GRU(
    input_size=embedding_dim, hidden_size=embedding_dim, num_layers=3, bidirectional=True)
node_pre_embedding_transformed, state = global_entity_embedding(node_pre_embedding_packed)
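As a shape check for the 3-layer bidirectional GRU (toy `seq_len`/`batch` values of my choosing): `state` stacks `num_layers * num_directions` hidden states, and the global-embedding code later reshapes it to pull out the last layer's forward and backward states.

```python
import torch

embedding_dim, seq_len, batch = 8, 5, 2
gru = torch.nn.GRU(input_size=embedding_dim, hidden_size=embedding_dim,
                   num_layers=3, bidirectional=True)
x = torch.randn(seq_len, batch, embedding_dim)
out, state = gru(x)                    # state: (3 * 2, batch, embedding_dim)

# Same reshape as the model's global_embedding computation:
# pick the last layer's (forward, backward) pair and flatten per batch element.
last_layer = state.view(3, 2, batch, embedding_dim)[-1]                     # (2, batch, dim)
global_emb = torch.flatten(torch.transpose(last_layer, 0, 1), start_dim=1)  # (batch, 2 * dim)
```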
They are then transformed by the message passing network:
- $c(u, v)$ = the previously computed constraint embedding
- $m_u^{(0)}$ = the previously computed primitive embedding
- $f_e$ = a linear layer
- $f_n$ = a GRU cell
In the code, an edge embedding is also added as input to the message function.
edge_embedding = torch.nn.Embedding(len(target.EDGE_TYPES), embedding_dim)
message_passing = sg_nn.MessagePassingNetwork(
    depth, torch.nn.GRUCell(embedding_dim, embedding_dim),
    sg_nn.ConcatenateLinear(embedding_dim, embedding_dim, embedding_dim))

edge_pre_embedding = self.edge_embedding(graph.edge_features)
node_post_embedding = self.message_passing(
    node_pre_embedding_graph, graph.incidence, (edge_pre_embedding,))
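One step of the update described above can be sketched as follows. This is my own toy version, not the library's MessagePassingNetwork: messages are computed by $f_e$ (a linear layer) from the source-node state and the edge embedding, summed per destination node, and fed to $f_n$ (a GRU cell) to update each node's state; the aggregation choice is an assumption.

```python
import torch

embedding_dim = 8
num_nodes, num_edges = 4, 3
f_e = torch.nn.Linear(2 * embedding_dim, embedding_dim)  # message function
f_n = torch.nn.GRUCell(embedding_dim, embedding_dim)     # node update function

node_state = torch.randn(num_nodes, embedding_dim)
edge_embedding = torch.randn(num_edges, embedding_dim)
src = torch.tensor([0, 1, 2])  # edge sources
dst = torch.tensor([1, 2, 3])  # edge destinations

# f_e over (source state, edge embedding), summed into each destination node.
messages = f_e(torch.cat((node_state[src], edge_embedding), dim=-1))
aggregated = torch.zeros_like(node_state).index_add_(0, dst, messages)
node_state = f_n(aggregated, node_state)  # (num_nodes, embedding_dim)
```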
The graph's global representation is computed by concatenating the final GRU hidden state (`state`) from above with a readout of the `node_post_embedding` from message passing, then passing the result through one more linear layer. In the code, each node embedding goes through both a linear layer + Sigmoid (a gate) and a plain linear layer, and the two are multiplied together; average pooling over each graph's nodes is then applied before the concatenation.
class GraphPostEmbedding(torch.nn.Module):
    def __init__(self, hidden_size, graph_embedding_size=None):
        super(GraphPostEmbedding, self).__init__()
        if graph_embedding_size is None:
            graph_embedding_size = 2 * hidden_size
        self.node_gating_net = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 1),
            torch.nn.Sigmoid())
        self.node_to_graph_net = torch.nn.Linear(hidden_size, graph_embedding_size)

    def forward(self, node_embedding, graph):
        scopes = graph_model.scopes_from_offsets(graph.node_offsets)
        transformed_embedding = self.node_gating_net(node_embedding) * self.node_to_graph_net(node_embedding)
        graph_embedding = sg_nn.functional.segment_avg_pool1d(
            transformed_embedding, scopes) * graph.node_counts.unsqueeze(-1)
        return graph_embedding
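The gated readout can be illustrated for a single graph (toy sizes of my choosing). Note that the library multiplies the per-graph average by the node count, which amounts to a gated sum; the toy below keeps just the mean for simplicity.

```python
import torch

hidden_size, graph_embedding_size = 8, 16
# Scalar gate per node (Linear -> Sigmoid) times a projection of the node.
gate = torch.nn.Sequential(torch.nn.Linear(hidden_size, 1), torch.nn.Sigmoid())
project = torch.nn.Linear(hidden_size, graph_embedding_size)

node_embedding = torch.randn(5, hidden_size)           # 5 nodes in one graph
gated = gate(node_embedding) * project(node_embedding) # (5, 16)
graph_embedding = gated.mean(dim=0)                    # (16,) pooled over nodes
```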
graph_post_embedding = message_passing.GraphPostEmbedding(embedding_dim)
merge_global_embedding = sg_nn.ConcatenateLinear(2 * embedding_dim, 2 * embedding_dim, embedding_dim)

global_embedding = torch.flatten(
    torch.transpose(state.view(3, 2, -1, self.embedding_dim)[-1], 0, 1),
    start_dim=1)
graph_post_embedding = self.graph_post_embedding(node_post_embedding, graph)
merged_global_embedding = self.merge_global_embedding(global_embedding, graph_post_embedding)
3. Readout
Readout is the stage that predicts the partner.
- The node_post_embedding computed earlier and the merged_global_embedding are concatenated; in the code below, the last node of each graph (the current node) provides the target_embedding.
- Passing through a two-layer fully-connected network yields the logits.
class EdgePartnerNetwork(torch.nn.Module):
    def __init__(self, readout_net):
        super(EdgePartnerNetwork, self).__init__()
        self.readout_net = readout_net

    def forward(self, node_embedding, graph_embedding, graph):
        target_idx = graph.node_offsets[1:] - 1
        target_embeddings = (node_embedding
                             .index_select(0, target_idx)
                             .repeat_interleave(graph.node_counts, 0))
        graph_embedding = graph_embedding.repeat_interleave(graph.node_counts, 0)
        edge_partner_input = torch.cat((node_embedding, target_embeddings, graph_embedding), dim=-1)
        logits = self.readout_net(edge_partner_input).squeeze(-1)
        return logits
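To see what the index_select/repeat_interleave combination is doing, here is a toy trace for a batch of two graphs with 3 and 2 nodes (values chosen so the result is readable): the last node of each graph is the current target, and its embedding is broadcast to every node of its own graph so each node can be scored as a candidate partner.

```python
import torch

node_offsets = torch.tensor([0, 3, 5])  # graph boundaries in the node dimension
node_counts = torch.tensor([3, 2])
node_embedding = torch.arange(5).unsqueeze(-1).float()  # 5 nodes, dim=1, values 0..4

target_idx = node_offsets[1:] - 1       # tensor([2, 4]): last node of each graph
target_embeddings = (node_embedding
                     .index_select(0, target_idx)
                     .repeat_interleave(node_counts, 0))
# rows: [2., 2., 2., 4., 4.] -> each node is paired with its graph's target
```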
edge_partner_network = EdgePartnerNetwork(
    torch.nn.Sequential(
        torch.nn.Linear(3 * embedding_dim, embedding_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(embedding_dim, 1)))

edge_partner_logits = self.edge_partner_network(
    node_post_embedding, merged_global_embedding, graph)
- The node_post_embedding of the current node, the embedding of the partner node, and the global representation are concatenated.
- Passing through a three-layer fully-connected network yields the logits for predicting the edge type.
edge_label = torch.nn.Sequential(
    torch.nn.Linear(3 * embedding_dim, embedding_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(embedding_dim, embedding_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(embedding_dim, len(target.EDGE_TYPES_PREDICTED)))

node_current_post_embedding_label = node_post_embedding.index_select(
    0, graph.node_offsets[1:][partner_index.index] - 1)
node_partner_post_embedding_label = node_post_embedding.index_select(0, partner_index.values)
merged_global_embedding_label = global_embedding.index_select(0, partner_index.index)
edge_label_input = torch.cat(
    (node_current_post_embedding_label, node_partner_post_embedding_label,
     merged_global_embedding_label), dim=-1)
edge_label(edge_label_input)
Finally, the logits for predicting the partner and the logits for predicting the edge type are returned.
return {
    'edge_partner_logits': edge_partner_logits,
    'edge_label_logits': edge_label_logits,
}