Skip to content
Snippets Groups Projects

Partie (II) Une application pratique des réseaux : transfert de style entre images

Ce TP vise à implémenter avec PyTorch le transfert de style d'une image à une autre en suivant un papier de Gatys etal présenté à CVPR 2016 : Image Style Transfer Using Convolutional Neural Networks. Ce n'est pas foncièrement un papier de deep learning et dispose de nombreux atouts pour un TP en image : utilisation d'un réseau pré-entrainé comme un outil, utilisation du framework de DL/PyTorch pour l'optimisation, code compact et résultats visuels et "rigolo".

Le code vide peut se trouver ici.

Le programme commence par 3 fonctions pour charger et convertir une image :

  • //load_image// pour redimensionner et normaliser avec la moyenne/écart type de VGG19;
  • //im_convert// de conversion d'un Tensor en une image Numpy ;
  • //imshow// pour visualiser une image sortant de //im_convert//.
    from PIL import Image
    import matplotlib.pyplot as plt
    import numpy as np

    import torch
    import torch.optim as optim
    from torchvision import transforms, models



    def load_image(img_path, max_size=400, shape=None):
        ''' Load in and transform an image, making sure the image is <= 400 pixels in the x-y dims.'''
        
        image = Image.open(img_path).convert('RGB')
        
        # large images will slow down processing
        if max(image.size) > max_size:
            size = max_size
        else:
            size = max(image.size)
        
        if shape is not None:
            size = shape
            
        in_transform = transforms.Compose([
                            transforms.Resize(size),
                            transforms.ToTensor(),
                            transforms.Normalize((0.485, 0.456, 0.406), 
                                                 (0.229, 0.224, 0.225))])

        # discard the transparent, alpha channel (that's the :3) and add the batch dimension
        image = in_transform(image)[:3,:,:].unsqueeze(0)
        
        return image



    # helper function for un-normalizing an image and converting it from a Tensor image to a NumPy image for display
    def im_convert(tensor):
        image = tensor.to("cpu").clone().detach()
        image = image.numpy().squeeze()
        image = image.transpose(1,2,0)
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
        image = image.clip(0, 1)
        return image


    def imshow(img):              # Pour afficher une image
        plt.figure(1)
        plt.imshow(img)
        plt.show()


    if __name__ == '__main__':
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        #device = torch.device("cpu")
        print(device)

        ########################## DISPLAY IMAGE#########################################################""
        content = load_image('images/mer.jpg').to(device)
        style = load_image('images/peinture1.jpg', shape=content.shape[-2:]).to(device)

        imshow(im_convert(content))
        imshow(im_convert(style))

Nous allons réutiliser un réseau VGG déjà entrainé. VGG est un réseau qui combine les convolutions afin d'être efficace pour de la reconnaissance d'images (ImageNet Challenge). Quand nous allons optimiser le transfert de style, nous ne voulons plus optimiser les couches du réseau VGG. Ceci se réalise en passant à False le besoin en gradient des paramètres. Vous pouvez donc charger le réseaux avec PyTorch comme ceci, neutraliser les couches et afficher toutes les couches comme ceci :

        vgg = models.vgg19(pretrained=True).features

        # freeze all VGG parameters since we're only optimizing the target image
        for param in vgg.parameters():
            param.requires_grad_(False)

        features = list(vgg)[:23]
        for i,layer in enumerate(features):
            print(i,"   ",layer)

Pour récupérer les caractéristiques intermédiaires d'une image qui passe dans un réseau VGG vous pouvez le faire comme ceci :

### Run an image forward through a model and get the features for a set of layers. 'model' is supposed to be vgg19
def get_features(image, model, layers=None):  
    if layers is None:
        layers = {'0': 'conv0',
                  '5': 'conv5', 
                  '10': 'conv10', 
                  '19': 'conv19',   ## content representation
                  }
        
    features = {}
    x = image
    # model._modules is a dictionary holding each module in the model
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            
    return features

Nous allons maintenant créer l'image cible qui va être une copie de l'image de contenu et dont les pixels seront à optimiser :

        target = content.clone().requires_grad_(True).to(device)

Vous devez écrire la fonction gram_matrix qui calcule la matrice de Gram à partir d'un tensor. Vous pouvez regarder la documentation de la fonction torch.mm qui multiplie deux matrices, et la fonction torch.transpose. La fonction torch.Tensor.view permet de changer la "vue" pour par exemple passer d'un tenseur 2D à un tenseur 1D, ou d'un 3D vers un 2D, etc.

    def gram_matrix(tensor):
       # tensor: Nfeatures x H x W ==> M = Nfeatures x Npixels with Npixel=HxW
       ...
       return gram

Écrivez le calcul de coût pour le contenu. Vous pouvez utiliser [[|torch.mean]] avec les features extraits de la couche 'conv19' qui d'après l'article correspondent globalement au contenu. Attention, les noms de couches ne correspondent pas à l'article.

Écrivez le calcul du coût pour le style. Il va se calculer de la même manière mais vous allez itérer sur les features des autres couches. A tester un peu par essai/erreur (ou regardez l'article).

Le coût total (celui qui sera optimisé) se calcule en faisant la moyenne pondérée entre le coût de style et le coût de contenu. A tester un peu par essai/erreur (ou regardez l'article).

La partie optimisation va donc ressembler à ceci.

        optimizer = optim.Adam([target], lr=0.003)
        for i in range(50):
        
            # get the features from your target image
        
            # the content loss
        
            # the style loss
            
            # calculate the *total* loss
        
            # update your target image
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

Pour aller plus loin sur le transfert de style entre images

Un blog qui décrit bien les évolutions de la recherche après l'approche de Gatys. Donne des explications également autour des approches de normalisation, AdaIN.