
Concepts and Taxonomy of XAI | Explainable AI


This post introduces the concept of explainable AI (XAI) and how XAI methods are categorized.

It covers methods for understanding and interpreting the inner workings of deep learning models, along with several metrics for evaluating the reliability of model interpretations.

It also summarizes sanity checks for validating model interpretations and the vulnerabilities they expose.

     

Supervised (deep) Learning

  • Supervised (deep) learning has made huge progress, but deep learning models are extremely complex
    • End-to-end learning turns the model into a black box
    • Problems arise when such models are applied to critical decisions

     

What is Explainability & Interpretability?

  • Interpretability is the degree to which a human can understand the cause of a decision
  • Interpretability is the degree to which a human can consistently predict the model’s results.
  • An explanation is the answer to a why-question.

     

Taxonomy of XAI Methods

  • Local vs. Global
    • Local : describes an individual prediction
    • Global : describes entire model behavior
  • White-box vs. Black-box
    • White-box : the explainer can access the internals of the model
    • Black-box : the explainer can access only the model’s output
  • Intrinsic vs. Post-hoc
    • Intrinsic : restricts the model complexity before training
    • Post-hoc : applies after the ML model is trained
  • Model specific vs. Model agnostic
    • Model-specific : the method is restricted to specific model classes
    • Model-agnostic : the method can be used with any model

Examples

  • Linear model, simple decision tree

➔ Global, white-box, intrinsic, model-specific

  • Grad-CAM

➔ Local, white-box, post-hoc, model-agnostic

Simple Gradient Method

  • Simply uses the gradient as the explanation (importance score)

    • Interpretation of f at \(x_0\) (for the i-th input/feature/pixel):
    \[R_i=\left(\nabla{f(x)}\big|_{x=x_0}\right)_i\]
    • Shows how sensitive a function value is to each input
  • Example : the gradient map is visualized for the highest-scoring class

  • Strength : easy to compute (via back-propagation)

  • Weakness : becomes noisy (due to the shattered gradients problem)
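
A minimal PyTorch sketch of the simple gradient explanation; `model`, the `(1, C, H, W)` input shape, and the channel-wise max reduction are illustrative assumptions:

```python
import torch

def gradient_saliency(model, x, target_class):
    """Simple gradient method: R_i = (∇f(x)|_{x0})_i for the chosen class logit."""
    model.eval()
    x = x.clone().requires_grad_(True)            # x: (1, C, H, W) input image
    score = model(x)[0, target_class]             # class logit f_c(x)
    score.backward()                              # back-propagate down to the input
    return x.grad.detach().abs().max(dim=1)[0]    # (1, H, W) pixel-wise sensitivity map
```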

SmoothGrad

  • A simple method to address the noisy gradients

    • Add some noise to the input and average
    \[\nabla_{\text{smooth}}f(x)=\mathbb{E}_{\epsilon\sim\mathcal{N}(0,\sigma^{2}I)}[\nabla{f(x+\epsilon)}]\]
    • Averaging the gradients of slightly perturbed inputs smooths the interpretation
    • Typical heuristics
      • Expectation is approximated with Monte-Carlo (around 50 runs)
      • σ is set to 10–20% of \(x_{max}-x_{min}\)
  • Strength

    • Clearer interpretation via simple averaging
    • Applicable to most sensitivity (saliency) maps
  • Weakness

    • Computationally expensive
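
A sketch of SmoothGrad under the same assumptions as the gradient sketch above; `n_samples` and `sigma_frac` follow the heuristics listed in the bullets:

```python
import torch

def smoothgrad(model, x, target_class, n_samples=50, sigma_frac=0.15):
    """SmoothGrad: average the input gradients over noisy copies of x."""
    model.eval()
    sigma = sigma_frac * (x.max() - x.min())      # heuristic: 10-20% of x_max - x_min
    total = torch.zeros_like(x)
    for _ in range(n_samples):                    # Monte-Carlo estimate (~50 runs)
        noisy = (x + sigma * torch.randn_like(x)).requires_grad_(True)
        model(noisy)[0, target_class].backward()
        total += noisy.grad
    return (total / n_samples).abs().max(dim=1)[0]
```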

Class Activation Map (CAM)

  • Method
    • Global average pooling (GAP) must be applied right before the softmax (FC) layer
    • The class-weighted sum of the last conv feature maps is upsampled to match the size of the input image
  • Alternative view of CAM
    • For a GAP-FC model, the logit of class c is represented by:
\[Y^c=\sum_k{w_k^c}\frac{1}{Z}\sum_{ij}{A^k_{ij}}\]
  • Result
    • CAM can localize objects in image
    • Segment the regions whose values are above 20% of the CAM’s maximum and take the bounding box of that region
  • Strength
    • It clearly shows what objects the model is looking at
  • Weakness
    • Model-specific: it can be applied only to models with a restricted architecture (GAP followed by the final FC layer)
    • It can only be obtained at the last convolutional layer, which makes the interpretation resolution coarse
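
A sketch of computing the CAM itself, assuming the last conv feature maps and the GAP-FC weights have already been extracted (the tensor shapes and names are assumptions):

```python
import torch
import torch.nn.functional as F

def class_activation_map(features, fc_weight, target_class, out_size):
    """CAM for a GAP-FC model.

    features : (K, h, w) activations A^k of the last conv layer
    fc_weight: (num_classes, K) weight matrix of the FC layer after GAP
    out_size : (H, W) of the input image, used for upsampling
    """
    w_c = fc_weight[target_class]                            # w_k^c
    cam = torch.einsum('k,khw->hw', w_c, features)           # sum_k w_k^c * A^k_ij
    cam = F.interpolate(cam[None, None], size=out_size,      # upsample to input resolution
                        mode='bilinear', align_corners=False)[0, 0]
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize for visualization
```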

Gradient-weighted Class Activation Mapping (Grad-CAM)

  • Method
    • To calculate the channel-wise weighted sum, Grad-CAM substitutes the FC weights with the spatially average-pooled gradients of the class score
  • Strength
    • Model-agnostic: can be applied to a much wider variety of models and output types
  • Weakness
    • The average-pooled gradient is sometimes not an accurate measure of channel importance
  • Result
    • Debugging the training with Grad-CAM
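
A sketch of Grad-CAM with forward/backward hooks; `target_layer` (typically the last convolutional block) and the hook-based plumbing are assumptions of this illustration:

```python
import torch
import torch.nn.functional as F

def grad_cam(model, target_layer, x, target_class):
    """Grad-CAM: channel weights are the average-pooled gradients of the class logit."""
    acts, grads = {}, {}
    h1 = target_layer.register_forward_hook(lambda m, i, o: acts.update(a=o))
    h2 = target_layer.register_full_backward_hook(lambda m, gi, go: grads.update(g=go[0]))
    try:
        model.zero_grad()
        model(x)[0, target_class].backward()
    finally:
        h1.remove(); h2.remove()
    alpha = grads['g'].mean(dim=(2, 3), keepdim=True)            # GAP of the gradients
    cam = F.relu((alpha * acts['a']).sum(dim=1, keepdim=True))   # weighted sum + ReLU
    cam = F.interpolate(cam, size=x.shape[-2:], mode='bilinear', align_corners=False)
    return cam[0, 0].detach()
```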

Local Interpretable Model-agnostic Explanations (LIME)

  • Can explain the predictions of any classifier by approximating it locally with an interpretable model
    • Model-agnostic, black-box
    • General overview of the interpretations
  • Perturbs the superpixels and fits a local interpretable model around the given example
  • Explaining an image classification prediction made by Google’s inception neural network
  • Strength
    • Black-box interpretation
  • Weakness
    • Computationally expensive
    • Hard to apply to certain kinds of models
    • Unreliable when the underlying model is still highly non-linear, even locally
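
A rough sketch of the LIME procedure for images, assuming a black-box `predict_fn` that maps a batch of images to class probabilities; it leans on scikit-image superpixels and a ridge surrogate, and the kernel width and zero fill value are arbitrary choices:

```python
import numpy as np
from skimage.segmentation import slic
from sklearn.linear_model import Ridge

def lime_image(predict_fn, image, target_class, n_samples=1000, n_segments=50):
    """LIME sketch: perturb superpixels, then fit a weighted linear surrogate model."""
    segments = slic(image, n_segments=n_segments)          # superpixel segmentation
    seg_ids = np.unique(segments)
    # binary vectors: which superpixels are kept in each perturbed sample
    z = np.random.randint(0, 2, size=(n_samples, len(seg_ids)))
    z[0] = 1                                               # keep the original image once
    perturbed = []
    for row in z:
        img = image.copy()
        img[~np.isin(segments, seg_ids[row == 1])] = 0     # black out dropped superpixels
        perturbed.append(img)
    probs = predict_fn(np.stack(perturbed))[:, target_class]
    # locality weights: samples that keep more superpixels (closer to x) weigh more
    weights = np.exp(-(1 - z.mean(axis=1)) ** 2 / 0.25)
    surrogate = Ridge(alpha=1.0).fit(z, probs, sample_weight=weights)
    return surrogate.coef_                                 # importance of each superpixel
```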

Randomized Input Sampling for Explanation (RISE)

  • Sub-samples the input image via random masks
    • Records the model’s response to each of the masked images
  • Comparison to LIME
    • The saliency of LIME relies on superpixels, which may not capture the correct regions
  • Strength
    • Much clearer saliency maps
  • Weakness
    • High computational complexity
    • Noisy due to sampling
  • RISE sometimes provides noisy importance maps
    • This is due to the sampling (Monte Carlo) approximation, especially in the presence of objects with varying sizes
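
A sketch of RISE, assuming the same classifier interface as above; the mask grid size, the keep probability `p`, and the normalization are common choices rather than a fixed specification:

```python
import torch
import torch.nn.functional as F

def rise(model, x, target_class, n_masks=2000, grid=7, p=0.5):
    """RISE sketch: saliency = score-weighted average of random binary masks."""
    model.eval()
    _, _, H, W = x.shape
    saliency = torch.zeros(H, W, device=x.device)
    with torch.no_grad():
        for _ in range(n_masks):
            # low-resolution Bernoulli mask, upsampled to the input size (soft edges)
            m = (torch.rand(1, 1, grid, grid, device=x.device) < p).float()
            m = F.interpolate(m, size=(H, W), mode='bilinear', align_corners=False)
            score = torch.softmax(model(x * m), dim=1)[0, target_class]
            saliency += score * m[0, 0]
    return saliency / (n_masks * p)                  # normalize by the expected mask value
```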

     

Understanding Black-box Predictions via Influence Functions

  • A different approach to XAI
    • Identifies the most influential training data points for a given prediction
  • Influence function
    • Measures the effect of removing a training sample on the test loss value
  • Influence function-based explanations can show the difference between models
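
Concretely, the influence function of Koh & Liang (the work this section refers to) approximates how upweighting (or removing) a training point \(z\) changes the loss at a test point \(z_{test}\):

\[\mathcal{I}_{\text{up,loss}}(z,z_{\text{test}})=-\nabla_\theta L(z_{\text{test}},\hat{\theta})^\top H_{\hat{\theta}}^{-1}\nabla_\theta L(z,\hat{\theta}),\qquad H_{\hat{\theta}}=\frac{1}{n}\sum_{i=1}^{n}\nabla^2_\theta L(z_i,\hat{\theta})\]

Removing \(z\) changes the test loss by roughly \(-\frac{1}{n}\mathcal{I}_{\text{up,loss}}(z,z_{\text{test}})\), so the most influential training points are those with the largest magnitude of this quantity.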

     

Metrics

Human-based Visual Assessment

  • Amazon Mechanical Turk (AMT) test
    • Question: can a human predict the model’s prediction from the interpretation?
  • Weakness
    • Obtaining human assessment is very expensive

Human Annotation

  • Some metrics employ human annotations (localization and semantic segmentation) as ground truth and compare them with the interpretation

    • Pointing game
    • Weakly supervised semantic segmentation
  • Pointing game

    • Given human-annotated bounding boxes \(\{B^{(i)}\}_{i=1,\dots,N}\) and interpretations \(\{h^{(i)}\}_{i=1,\dots,N}\), the mean accuracy of the pointing game is defined by:
    \[Acc=\frac{1}{N}\sum^N_{i=1}1_{[p^{(i)}\in{B^{(i)}}]}\]
    • where \(p^{(i)}\) is the pixel \(p^{(i)}=\arg\max_p h^{(i)}_p\)
    • \(1_{[p^{(i)}\in B^{(i)}]}\) is an indicator function whose value is 1 if the pixel with the highest interpretation score is located inside the bounding box
  • Weakly supervised semantic segmentation

    • Setting : pixel-level labels are not given during training
    • This metric measures the mean IoU between the interpretation and the semantic segmentation label
  • Weakness

    • Human annotations are expensive to obtain
    • Such localization and segmentation labels are not a true ground truth for interpretation
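
A small sketch of the pointing-game accuracy above, assuming heatmaps and axis-aligned bounding boxes given as plain arrays:

```python
import numpy as np

def pointing_game_accuracy(heatmaps, boxes):
    """Pointing game: fraction of samples whose argmax saliency pixel lies in the GT box.

    heatmaps: list of (H, W) interpretation maps h^(i)
    boxes   : list of (x_min, y_min, x_max, y_max) annotated bounding boxes B^(i)
    """
    hits = 0
    for h, (x0, y0, x1, y1) in zip(heatmaps, boxes):
        y, x = np.unravel_index(np.argmax(h), h.shape)   # p^(i) = argmax_p h^(i)_p
        hits += int(x0 <= x <= x1 and y0 <= y <= y1)
    return hits / len(heatmaps)
```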

Pixel Perturbation

  • Motivation
    • If we remove an important area from the image, the logit value for the class should decrease
  • AOPC (Area over the MoRF perturbation curve)
    • AOPC measures the decrease of the logit as input patches are replaced in MoRF (most relevant first) order
\[AOPC=\frac{1}{L+1}\mathbb{E}_{x\sim{p(x)}}\Big[\sum^L_{k=0}\big(f(x^{(0)}_{MoRF})-f(x^{(k)}_{MoRF})\big)\Big]\]

where f is the logit of the true label

  • Insertion and deletion
    • Both measure the AUC of each curve
      • In the deletion curve, the x-axis is the percentage of pixels removed in MoRF order, and the y-axis is the class probability of the model
      • In the insertion curve, the x-axis is the percentage of pixels recovered in MoRF order, starting from a gray image
  • Weakness
    • Violates one of the key assumptions in ML that the training and evaluation data come from the same distribution
    • The perturbed inputs are out-of-distribution for the model of interest, which is deployed and explained on unperturbed data at test time
    • Perturbation can create new features for the model, e.g., the model may tend to predict a perturbed input as “balloon”
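
A sketch of the deletion metric, assuming a `(1, C, H, W)` input and an `(H, W)` NumPy saliency map; the insertion metric is analogous but starts from a gray image and adds pixels back in MoRF order:

```python
import numpy as np
import torch

def deletion_auc(model, x, saliency, target_class, n_steps=50):
    """Deletion metric: remove pixels in MoRF order and return the AUC of the probability curve.

    x       : (1, C, H, W) input image tensor
    saliency: (H, W) NumPy interpretation map
    """
    model.eval()
    _, c, H, W = x.shape
    order = np.argsort(-saliency.ravel())               # pixel indices, most relevant first
    chunk = int(np.ceil(H * W / n_steps))
    x_del = x.clone()
    flat = x_del.view(1, c, -1)                         # view sharing storage with x_del
    probs = []
    with torch.no_grad():
        for k in range(n_steps + 1):
            probs.append(torch.softmax(model(x_del), dim=1)[0, target_class].item())
            idx = torch.as_tensor(order[k * chunk:(k + 1) * chunk].copy(), device=x.device)
            flat[:, :, idx] = 0                         # delete the next chunk of pixels
    return float(np.trapz(probs, dx=1.0 / n_steps))     # area under the deletion curve
```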

ROAR

  • ROAR (RemOve And Retrain) removes some portion of the pixels in the training data in order of the original model’s interpretation values, and retrains a new model
  • Weakness
    • Retraining every time is computationally expensive
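
A sketch of the ROAR data step only (the subsequent retraining is a standard training loop); the per-channel mean replacement value and the array shapes are assumptions:

```python
import numpy as np

def roar_remove(images, saliencies, fraction=0.3):
    """ROAR data step: replace the top-`fraction` most important pixels of each training image.

    images     : (N, H, W, C) training images
    saliencies : (N, H, W) interpretation maps produced by the original trained model
    The returned dataset is then used to retrain a new model from scratch.
    """
    out = images.copy().astype(float)
    fill = images.mean(axis=(0, 1, 2))                   # per-channel mean replacement value
    n_remove = int(fraction * saliencies[0].size)
    for img, sal in zip(out, saliencies):
        idx = np.argsort(-sal.ravel())[:n_remove]        # highest-importance pixels first
        h, w = np.unravel_index(idx, sal.shape)
        img[h, w] = fill                                 # remove (replace) those pixels
    return out
```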

     

Sanity Checks

Model Randomization

  • Interpretation = Edge detector?
    • Some interpretation methods produce saliency maps strikingly similar to those created by an edge detector
  • Model randomization test
    • This test randomly re-initializes the parameters, either in a cascading fashion or for a single independent layer
    • Some interpretations are not sensitive to this randomization, e.g., Guided Backprop, LRP, and pattern attribution
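
A sketch of the cascading randomization test; `explain_fn` can be any attribution method (e.g., the gradient sketch above), and `layer_names` is assumed to list the model's layers from the output toward the input:

```python
import copy
import torch

def cascading_randomization(model, x, target_class, explain_fn, layer_names):
    """Model randomization sanity check: re-initialize layers from the top down and
    recompute the saliency map after each step."""
    model = copy.deepcopy(model)                         # keep the original model intact
    maps = [explain_fn(model, x, target_class)]          # map for the trained model
    named = dict(model.named_modules())
    for name in layer_names:                             # ordered from output toward input
        for p in named[name].parameters():
            torch.nn.init.normal_(p, std=0.02)           # cascading re-initialization
        maps.append(explain_fn(model, x, target_class))
    return maps  # a faithful method should change noticeably as more layers are randomized
```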

Adversarial Attack

  • Geometry is to blame
    • Proposed the adversarial attack on interpretation:
      \(\mathbb{L}=||h(x_{adv})-h^t||^2+\gamma||g(x_{adv})-g(x)||^2\)
    • Proposed a smoothing method to undo the attack
      • Using a softplus activation with high beta can undo the interpretation attack
    • Provided a theoretical bound for such attacks
  • Results
    • In the figure, the manipulated image is attacked with the target interpretation \(h^t\)
    • For both gradient and LRP, the manipulated interpretation for the network with ReLU activations is similar to the target interpretation, but the one with softplus is not manipulated
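
A rough sketch of minimizing the attack objective above for a gradient-based interpretation \(h(x)=|\nabla_x f_c(x)|\); the optimizer, step count, and the `(1, H, W)` shape of `target_map` are assumptions, and in practice the second-order gradients are computed on a softplus-smoothed copy of a ReLU network:

```python
import torch

def attack_interpretation(model, x, target_map, target_class, gamma=1.0, steps=200, lr=1e-3):
    """Optimize x_adv so its saliency matches a target map h^t while the output stays close."""
    model.eval()
    with torch.no_grad():
        g_x = model(x)                                   # original output g(x), kept fixed
    x_adv = x.clone().requires_grad_(True)
    opt = torch.optim.Adam([x_adv], lr=lr)
    for _ in range(steps):
        out = model(x_adv)
        # gradient-based saliency h(x_adv); create_graph=True enables second-order gradients
        # (for ReLU networks the Hessian vanishes, hence the softplus-smoothed copy in practice)
        grad = torch.autograd.grad(out[0, target_class], x_adv, create_graph=True)[0]
        h = grad.abs().sum(dim=1)                        # (1, H, W) saliency map
        loss = ((h - target_map) ** 2).sum() + gamma * ((out - g_x) ** 2).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return x_adv.detach()
```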

Adversarial Model Manipulation

  • Adversarial model manipulation
    • Two models can produce totally different interpretations while having similar accuracy
  • Attack on the model, not the input
    • Negligible model accuracy drop
    • Fooling generalizes across validation set
    • Fooling transfers to different interpretations
    • AOPC analysis confirms true foolings

     

Reference

This post is based on material from the LG Aimers program. (It does not cover the full content.)

  1. LG Aimers AI Essential Course, Module 4: Explainable AI, Prof. Taesup Moon, Seoul National University