Skip to content
Snippets Groups Projects
TP_Style.md 7.59 KiB
Newer Older
  • Learn to ignore specific revisions
  • Alexandre MEYER's avatar
    Alexandre MEYER committed
      
    
    
    # Partie (II) Une application pratique des réseaux : transfert de style entre images
    
      
    Ce TP vise à implémenter avec PyTorch le transfert de style d'une image à une autre en suivant un papier de Gatys etal présenté à CVPR 2016 :
    [Image Style Transfer Using Convolutional Neural Networks](https://www.cv-foundation.org/openaccess/content_cvpr_2016/html/Gatys_Image_Style_Transfer_CVPR_2016_paper.html). Ce n'est pas foncièrement un papier de deep learning et dispose de
    nombreux atouts pour un TP en image : utilisation d'un réseau pré-entrainé comme un outil, utilisation du framework de DL/PyTorch pour l'optimisation, code compact et résultats visuels et "rigolo".
    
      
    [Le code vide peut se trouver ici](https://github.com/ucacaxm/DeepLearning_Vision_SimpleExamples/blob/master/src/style_transfer/StyleTransfer_empty.py).
    
    *   [une image de contenu](https://raw.githubusercontent.com/ucacaxm/DeepLearning_Vision_SimpleExamples/master/src/style_transfer/images/montagne_small.jpg)
    *   [une image de style](https://raw.githubusercontent.com/ucacaxm/DeepLearning_Vision_SimpleExamples/master/src/style_transfer/images/peinture1_small.jpg)
    
    Le programme commence par 3 fonctions pour charger et convertir une image :
    * //load_image// pour redimensionner et normaliser avec la moyenne/écart type de VGG19; 
    * //im_convert// de conversion d'un Tensor en une image Numpy ;
    
    * //imshow// pour visualiser une image sortant de //im_convert//.
    
    Alexandre MEYER's avatar
    Alexandre MEYER committed
    
    ```
        from PIL import Image
        import matplotlib.pyplot as plt
        import numpy as np
    
        import torch
        import torch.optim as optim
        from torchvision import transforms, models
    
    
    
        def load_image(img_path, max_size=400, shape=None):
            ''' Load in and transform an image, making sure the image is <= 400 pixels in the x-y dims.'''
            
            image = Image.open(img_path).convert('RGB')
            
            # large images will slow down processing
            if max(image.size) > max_size:
                size = max_size
            else:
                size = max(image.size)
            
            if shape is not None:
                size = shape
                
            in_transform = transforms.Compose([
                                transforms.Resize(size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.485, 0.456, 0.406), 
                                                     (0.229, 0.224, 0.225))])
    
            # discard the transparent, alpha channel (that's the :3) and add the batch dimension
            image = in_transform(image)[:3,:,:].unsqueeze(0)
            
            return image
    
    
    
        # helper function for un-normalizing an image and converting it from a Tensor image to a NumPy image for display
        def im_convert(tensor):
            image = tensor.to("cpu").clone().detach()
            image = image.numpy().squeeze()
            image = image.transpose(1,2,0)
            image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
            image = image.clip(0, 1)
            return image
    
    
        def imshow(img):              # Pour afficher une image
            plt.figure(1)
            plt.imshow(img)
            plt.show()
    
    
        if __name__ == '__main__':
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            #device = torch.device("cpu")
            print(device)
    
            ########################## DISPLAY IMAGE#########################################################""
            content = load_image('images/mer.jpg').to(device)
            style = load_image('images/peinture1.jpg', shape=content.shape[-2:]).to(device)
    
            imshow(im_convert(content))
            imshow(im_convert(style))
    ```
      
      
    Nous allons réutiliser un réseau VGG déjà entrainé. [VGG est un réseau qui combine les convolutions afin d'être efficace pour de la reconnaissance d'images](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) (ImageNet Challenge). Quand nous allons optimiser le transfert de style, nous ne voulons plus optimiser les couches du réseau VGG. Ceci se réalise en passant à False le besoin en gradient des paramètres. Vous pouvez donc charger le réseaux avec PyTorch comme ceci, neutraliser les couches et afficher toutes les couches comme ceci :  
    ```
            vgg = models.vgg19(pretrained=True).features
    
            # freeze all VGG parameters since we're only optimizing the target image
            for param in vgg.parameters():
                param.requires_grad_(False)
    
            features = list(vgg)[:23]
            for i,layer in enumerate(features):
                print(i,"   ",layer)
    ```
      
      
    
    Pour récupérer les caractéristiques intermédiaires d'une image qui passe dans un réseau VGG vous pouvez le faire comme ceci :
    ```
    ### Run an image forward through a model and get the features for a set of layers. 'model' is supposed to be vgg19
    def get_features(image, model, layers=None):  
        if layers is None:
            layers = {'0': 'conv0',
                      '5': 'conv5', 
                      '10': 'conv10', 
                      '19': 'conv19',   ## content representation
                      }
            
        features = {}
        x = image
        # model._modules is a dictionary holding each module in the model
        for name, layer in model._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
                
        return features
    ```
      
      
    Nous allons maintenant créer l'image cible qui va être une copie de l'image de contenu et dont les pixels seront à optimiser :
    ```
            target = content.clone().requires_grad_(True).to(device)
    ```
      
      
    
    Vous devez écrire la fonction *gram_matrix* qui calcule la [matrice de Gram](https://en.wikipedia.org/wiki/Gramian_matrix) à partir d'un
    tensor. Vous pouvez regarder la documentation de la fonction [torch.mm](https://pytorch.org/docs/stable/torch.html#torch.mm) qui
    multiplie deux matrices, et la fonction [torch.transpose](https://pytorch.org/docs/stable/torch.html#torch.transpose).
    La fonction [torch.Tensor.view](https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view) permet de changer la "vue" pour par exemple passer d'un tenseur 2D à un tenseur 1D, ou d'un 3D vers un 2D, etc.
    ```
        def gram_matrix(tensor):
           # tensor: Nfeatures x H x W ==> M = Nfeatures x Npixels with Npixel=HxW
           ...
           return gram
    ```
      
      
    Écrivez le calcul de coût pour le contenu. Vous pouvez utiliser \[\[\|torch.mean\]\] avec les features extraits de la couche 'conv19'
    qui d'après l'article correspondent globalement au contenu. Attention, les noms de couches ne correspondent pas à l'article.
    
      
    Écrivez le calcul du coût pour le style. Il va se calculer de la même manière mais vous allez itérer sur les features des autres couches. A tester un peu par essai/erreur (ou regardez l'article).
    
      
    Le coût total (celui qui sera optimisé) se calcule en faisant la moyenne pondérée entre le coût de style et le coût de contenu. A tester un peu par essai/erreur (ou regardez l'article).
    
      
    La partie optimisation va donc ressembler à ceci.
    ```
            optimizer = optim.Adam([target], lr=0.003)
            for i in range(50):
            
                # get the features from your target image
            
                # the content loss
            
                # the style loss
                
                # calculate the *total* loss
            
                # update your target image
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
    ```
      
      
    
    ## Pour aller plus loin sur le transfert de style entre images
    
    [Un blog qui décrit bien les évolutions de la recherche après l'approche de Gatys](https://dudeperf3ct.github.io/style/transfer/2018/12/23/Magic-of-Style-Transfer/). Donne des explications également autour des approches de normalisation, AdaIN.