To learn a new task, we humans need not always start afresh; we can apply previously learned knowledge. In the same way, “Transfer Learning” (TL) allows a machine learning model to port knowledge acquired while training on one task to a new task. TL is mostly used to obtain well-performing machine learning models in settings where high-quality labeled data is scarce.
How do we determine whether a model trained in a source domain (with abundant data) can be adapted to a target domain (with scarce data or no labeled data)? This is termed the domain- or task-relatedness challenge. Knowledge about the source and target domains can help determine domain relatedness. There are also mathematical measures, such as maximum mean discrepancy (MMD), which estimate domain relatedness by mapping the feature distributions into a reproducing kernel Hilbert space with a non-linear map and comparing their means.
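To make MMD concrete, here is a minimal NumPy sketch (the `mmd_rbf` helper and the RBF kernel choice are illustrative assumptions for this post, not a production implementation): it computes a biased estimate of the squared MMD between two samples and shows that it grows when the two domains differ.

```python
import numpy as np

def mmd_rbf(X, Y, gamma=1.0):
    """Biased estimate of squared MMD between samples X and Y, RBF kernel."""
    def k(A, B):
        # Pairwise squared distances, then the RBF kernel exp(-gamma * d^2).
        d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
        return np.exp(-gamma * d2)
    return k(X, X).mean() + k(Y, Y).mean() - 2 * k(X, Y).mean()

rng = np.random.default_rng(0)

# Two samples from the same distribution: MMD should be near zero.
same = mmd_rbf(rng.normal(size=(200, 2)), rng.normal(size=(200, 2)))

# Source vs. a mean-shifted "target" domain: MMD should be clearly larger.
shifted = mmd_rbf(rng.normal(size=(200, 2)),
                  rng.normal(2.0, 1.0, size=(200, 2)))
print(same, shifted)
```

A small MMD suggests the domains are related enough for transfer; a large MMD warns that knowledge from the source may not carry over directly.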
TL has traditionally been applied to many classical ML models and is now widely used with deep learning models. Following are a few popular types of TL:
1. Feature-Based Transfer: Different layers in a deep neural network capture different sets of features. The intermediate layers of the source-domain model can be used as a feature extractor, and an ML classifier can be trained on the small target-domain dataset using features extracted by the source model. For example, consider an image classification task of categorizing an image into three classes based on lighting exposure: over-exposed, under-exposed, or normal. Suppose we have very few annotated images. Taking a state-of-the-art CNN pretrained on ImageNet (e.g., a ResNet or VGG) and using it as a feature extractor, we can train a simple ML model over these features.
Now, which layers of the network make the best feature extractor? Since the target task is to classify images by brightness/contrast, the lower layers of the CNN, which learn color, texture, and brightness filters, may be a good choice. Also, globally pooling each channel's activation map can reduce the feature dimension.
There are two subtypes of feature-based transfer:
- Symmetric (finds a common latent feature-space representation for both the source and target domains: domain adaptation)
- Asymmetric (transforms the source features to match the target domain more closely: domain confusion; or makes the target-domain features match the source features so that the source classifier can be used as-is: adversarial domain adaptation)
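To make the feature-extractor idea concrete, here is a minimal sketch. Assumptions: a frozen random linear map stands in for a lower layer of a pretrained CNN (in practice you would use real pretrained activations), the exposure classes are simulated by shifting pixel intensities, and scikit-learn provides the downstream classifier.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Stand-in for a frozen pretrained feature extractor: a fixed random linear
# map plus ReLU. In practice, use activations from a lower convolutional
# layer of a CNN pretrained on ImageNet (e.g., a ResNet).
W = rng.normal(size=(3 * 32 * 32, 64)) / np.sqrt(3 * 32 * 32)

def extract_features(images):
    """images: (n, 3, 32, 32) array -> (n, 64) feature vectors."""
    flat = images.reshape(len(images), -1)
    return np.maximum(flat @ W, 0.0)  # ReLU activations of the frozen layer

# Tiny labeled target set: under-exposed (0), normal (1), over-exposed (2),
# simulated by shifting pixel intensities down or up.
n = 60
labels = rng.integers(0, 3, size=n)
base = rng.uniform(0.3, 0.7, size=(n, 3, 32, 32))
images = np.clip(base + 0.3 * (labels - 1)[:, None, None, None], 0.0, 1.0)

# Train a simple classifier on the frozen features.
X = extract_features(images)
clf = LogisticRegression(max_iter=1000).fit(X, labels)
print("train accuracy:", clf.score(X, labels))
```

Only the small classifier is trained here; the extractor's parameters stay fixed, which is what makes this approach workable with very few labeled target images.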
2. Instance-Based Transfer: Instances from the source domain are reweighted in an attempt to correct for marginal distribution differences. These reweighted instances are then directly used in the target domain for training along with the limited available data from the target domain. The idea is to decrease the weights of the instances that have negative effects on the target learner.
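A minimal sketch of instance reweighting under covariate shift, assuming scikit-learn. The discriminator-based density-ratio trick used here is one common way to estimate the weights (not the only one), and all data is synthetic for illustration.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(1)

# Source domain: abundant labeled 1-D data. Target domain: same labeling
# rule but a shifted input distribution, with only a few labeled points.
Xs = rng.normal(0.0, 1.0, size=(500, 1))
ys = (Xs[:, 0] > 0).astype(int)
Xt = rng.normal(1.0, 1.0, size=(40, 1))
yt = (Xt[:, 0] > 0).astype(int)

# Estimate importance weights w(x) ~ p_target(x) / p_source(x) via a
# domain discriminator trained to tell source (0) from target (1) points.
X_dom = np.vstack([Xs, Xt])
d = np.r_[np.zeros(len(Xs)), np.ones(len(Xt))]
disc = LogisticRegression().fit(X_dom, d)
p = disc.predict_proba(Xs)[:, 1]
w = (p / (1 - p)) * (len(Xs) / len(Xt))  # weights for the source instances

# Train the target classifier on reweighted source data plus target data.
clf = LogisticRegression().fit(
    np.vstack([Xs, Xt]),
    np.r_[ys, yt],
    sample_weight=np.r_[w, np.ones(len(Xt))],
)
print("target accuracy:", clf.score(Xt, yt))
```

Source instances that look unlike the target domain receive small weights, which is the "decrease the weights of harmful instances" idea in practice.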
3. Parameter Transfer: A parametric model is learned in the source domain, and the transferred knowledge is encoded in its parameters. If we have limited labeled target data and multiple labeled source domains, we can train separate classifiers on the source domains and combine the reweighted learners, as in ensemble learning, to form an improved target learner. A neural network is a parametric model, and using a model trained in the source domain as a starting point, then fine-tuning its final layers for the target task, is a form of parameter transfer.
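The fine-tuning variant can be sketched as follows (a toy example with scikit-learn; the source and target tasks, network size, and the choice to retrain only the final layer are illustrative assumptions):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(2)

# Source task: plenty of labeled data.
Xs = rng.normal(size=(1000, 10))
ys = (Xs[:, 0] + Xs[:, 1] > 0).astype(int)

# Train a small source network; its weights are the transferred parameters.
src = MLPClassifier(hidden_layer_sizes=(16,), max_iter=500, random_state=0)
src.fit(Xs, ys)

# Related target task (shifted decision threshold) with very little data.
Xt = rng.normal(size=(30, 10))
yt = (Xt[:, 0] + Xt[:, 1] > 0.5).astype(int)

# Parameter transfer: freeze the source network's hidden layer and retrain
# only the final layer on the target data -- equivalent to fitting a
# logistic regression on the frozen hidden activations.
W1, b1 = src.coefs_[0], src.intercepts_[0]
H = np.maximum(Xt @ W1 + b1, 0.0)  # frozen ReLU hidden-layer features
head = LogisticRegression(max_iter=1000).fit(H, yt)
print("target accuracy:", head.score(H, yt))
```

Because only the small final layer is learned from scratch, the target task needs far fewer labeled examples than training the whole network would.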
We will discuss each of these transfer learning variants in the tutorial session, along with hands-on code, and cover some of the latest applications of TL:
- Privacy-Preserving ML: Train ML models without exposing private data in its original form.
- Multimodal Deep Learning: Learn features over multiple modalities (audio, video, text, structured data).
About the author/ODSC West speaker:
Tamoghna Ghosh is an AI Solution Architect in the Client Computing Group at Intel, working on building next-generation AI solutions for edge computing. Prior to this role, he was a data scientist at Intel, working on domains such as supply-chain inventory optimization, anomaly detection and failure prediction for IT infrastructure across Intel, and building advanced search tools for bug sightings, to name a few.