2 분 소요

“FedNI: Federated Graph Learning with Network Inpainting for Population-Based Disease Prediction”이란 논문에 대한 리뷰입니다.

원문은 링크에서 확인할 수 있습니다.

Background

기존엔 GCN으로 Feature extraction진행 후 Classifier 진행 본 방식에선 GCN과 Classifier 모두 Federated 방식으로 학습 진행 -Optimization 방식이 Fed인 것이고 각각이 학습하는 건 task에 맞춰 정해진 NI

Part 1: Network Inpainting

: hidden node와 edge를 찾는 것

기존의 GCN은 graph에 대한 데이터가 complete하다는 가정을 하지만 실제론 incomplete하니 incomplete data 환경을 만들고 complete하게 예측하는 것을 학습해야 한다.

  1. Hidden node를 정하기 BFS 방식으로 sub-tree를 만들고, leaf node를 hidden node를 정하게 된다.

  2. Open node로 feature extraction 진행 GCN으로 open node의 embedding 생성 : Generator
    중요한 것은 GCN에서 앞의 encoder만 사용하고 목적에 따라 following layer 정의

  3. Hidden node 수 예측 open node의 embedding으로 hidden node의 개수를 예측하는 MLP

  4. Hidden node feature 예측 Open node의 embedding으로 hidden node의 embedding을 예측하는 MLP

  5. Hidden node feature의 다른 의미의 reconstruction loss 정의 기존의 reconstruction loss의 경우 Input vs decoder 차이를 줄이지만, Hidden node의 실제 embedding과 open node의 embedding으로 예측한 것(decoder의 결과)의 차이를 줄인다

  6. 그리고 예측된 hidden embedding의 분포와 실제 hidden embedding의 분포 차이를 줄임 SN-GAN의 Loss를 이용하기 위해 discriminator를 사용하고 결과적으로 discriminator를 만들고 이를 가지고 generator loss 정의
    • 즉 Generator의 loss(얼마나 hidden feature를 예측하는가)는 GAN으로 학습된 discriminator 계산
    • 최종적으로 Reconstruction loss와 Generator loss의 weighted sum
  7. 더 잘 예측하기 위해 다른 데이터(non-image data)를 사용 Phenotype data를 예측하는 모델 생성

Part 2: Federated Learning

NI as Fed

여기서 달라지는 것은 local과 global로 나뉘게 된다.
앞서 NI를 통해 local들의 Generator와 Discriminator를 학습하고, 이걸 Federate방식으로 aggregation step이 들어가는데 단 Discriminator는 Local한 상태로 둔다(NO FED)
이건 그냥 본인의 graph와 유사한지 안유사한지 구분하는 것이지 다른 graph의 information이 필요한 게 아니라 local graph의 heterogeneity를 알아야한다. 사실 이건 feature를 정확히 만드는 것이 중요한 건데 결국 이건 local에서 embedding을 어떻게 정교하게 만드는지를 다루는 것이니 global이 필요없게 되는 것이다.
Generator의 경우 centralized환경과 비교했을 때 다른 partition과 연결될 수 있는 node를 찾아야하기 때문에 global information이 필요하게 되어 Fed가 된다. 이건 궁극적으로 ML task를 할 때 사용되는 데이터이니 global distribution을 따르게 만드는 것.

  • 이 과정을 통해 하나의 전체 graph를 예측하는 과정이 끝난다.

Classifier as Fed

이렇게 모인 것을 또 classifier를 학습할 때에도 Fed를 쓴다.

  • 이건 server가 classifier를 하는 것이 아니라 local에서 각각 classifier를 학습하기 위한 것
  • 즉, Fed는 결국 서버간에 정보공유 없이 일을 수행하되 align시키는 server-server schema
  • 우리가 하는 것은 서버 아래에 client들간 align 시키는 server-client schema

Part 3: Application for our study

여기서 다루는 건 1) “서버급” edge들이 어떻게 align된 성능을 낼 것인지와 2) ML task에 맞춘 edge modeling

우리가 적용할 수 있는 부분이라면, NI가 끝나고 graph merge하는 단계에서 하나의 통일된 전체 graph를 만들 때 align시키는 것, 그리고 서버급 edge 아래에 client간 data를 align시키는 것

  1. Generator를 FedAvg하는 것이 아니라 pseudo embedding으로 알아서 조정되게 만드는 방식 (개인적인 고민 1: 기존의 방식이라면 Reconstruction loss 뒤에 term을 추가해야 하는데, 이건 hidden node의 예측값을 align하는 것. 그렇다면 Open node부터 align되어야 할텐데 align한다면 그건 GCN을 학습할 때 해야하는 것인가? 즉 한번만 align시키면 될까?)

  2. 전체 분포를 사용하는 건 discriminator를 학습할 때도 사용되는데 이 때에도 align시키는 내용이 들어가야 하는가? 즉, (환자-병원) – (병원-환자) 급으로 진행하는 schema를 상상할 수 있지 않을까

댓글남기기