arturml/pytorch-wgan-gp


Pytorch WGAN-GP

This is a PyTorch implementation of Improved Training of Wasserstein GANs (WGAN-GP). Most of the code was inspired by a repository by EmilienDupont.
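The main idea in WGAN-GP is a gradient penalty that replaces the weight clipping of the original WGAN. The following is a minimal sketch, not this repository's code. It samples random points x_hat on straight lines between real and generated samples and penalizes the critic whenever the L2 norm of its input gradient at x_hat differs from 1. The critic loss is then E[D(fake)] - E[D(real)] + lambda * E[(||grad D(x_hat)||_2 - 1)^2]. The names gradient_penalty, critic_loss and gp_weight are hypothetical. The value gp_weight = 10 is the paper's default; this README does not say what the repository uses.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """Return E[(||grad_x critic(x_hat)||_2 - 1)^2] over random interpolates x_hat."""
    batch_size = real.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dimensions.
    alpha = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    # Points on straight lines between real and generated samples.
    x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True keeps the penalty differentiable w.r.t. the critic's weights.
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    # Per-sample L2 norm of the gradient, computed over all non-batch dimensions.
    grad_norm = torch.linalg.vector_norm(grads.reshape(batch_size, -1), dim=1)
    return ((grad_norm - 1) ** 2).mean()


def critic_loss(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor,
                gp_weight: float = 10.0) -> torch.Tensor:
    # Negative Wasserstein estimate (minimized by the critic) plus the weighted penalty.
    wasserstein = critic(fake.detach()).mean() - critic(real).mean()
    return wasserstein + gp_weight * gradient_penalty(critic, real, fake)
```

The generator is then trained to minimize -E[D(fake)]. Only the critic update includes the penalty term.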

Training

To train on the MNIST dataset, run

python main.py --dataset mnist --epochs 200

For the FashionMNIST dataset, run

python main.py --dataset fashion --epochs 200

You can also set up a generator and discriminator pair yourself and use the WGANGP class directly:

wgan = WGANGP(generator, discriminator,
              g_optimizer, d_optimizer,
              latent_shape, dataset_name)
wgan.train(data_loader, n_epochs)
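The following sketches how the arguments for the WGANGP call above could be built for 28x28 grayscale images such as MNIST. The Generator and Discriminator classes and LATENT_DIM are hypothetical illustrations, not the repository's own models. The Adam settings (lr=1e-4, betas=(0.0, 0.9)) are the defaults from the WGAN-GP paper's algorithm and are assumptions here. The README does not show where WGANGP is imported from, so that call is left as a comment.

```python
import torch
import torch.nn as nn

# Hypothetical latent size; the repository's own models may use a different one.
LATENT_DIM = 100


class Generator(nn.Module):
    """Illustrative MLP generator: (N, LATENT_DIM) noise -> (N, 1, 28, 28) images in [-1, 1]."""

    def __init__(self, latent_dim: int = LATENT_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z).view(-1, 1, 28, 28)


class Discriminator(nn.Module):
    """Illustrative MLP critic: (N, 1, 28, 28) images -> (N, 1) unbounded scores."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),  # no sigmoid: a Wasserstein critic outputs raw scores
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


generator = Generator()
discriminator = Discriminator()
# Paper defaults (assumed here): lr=1e-4, beta1=0, beta2=0.9.
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.0, 0.9))
latent_shape = (LATENT_DIM,)  # the generator consumes flat noise vectors
# Import path of WGANGP not given in this README; "mnist" mirrors the CLI's dataset name.
# wgan = WGANGP(generator, discriminator, g_optimizer, d_optimizer, latent_shape, "mnist")
# wgan.train(data_loader, n_epochs)
```

The critic deliberately has no sigmoid and no batch normalization. The gradient penalty is defined per sample, and batch normalization would mix samples within a batch.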

The argument latent_shape is the shape of the input that the generator's forward function accepts.
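To illustrate what latent_shape means in practice, here is a sketch that assumes noise is drawn as a standard normal tensor with the batch dimension prepended. The helper sample_latent and both example generators are hypothetical. A generator that takes flat vectors would use latent_shape = (100,), while a DCGAN-style generator built from transposed convolutions would use (100, 1, 1).

```python
import torch
import torch.nn as nn


def sample_latent(batch_size: int, latent_shape: tuple) -> torch.Tensor:
    # Hypothetical helper: prepend the batch dimension and draw N(0, 1) noise.
    return torch.randn(batch_size, *latent_shape)


# Flat-vector generator: expects latent_shape = (100,).
mlp_generator = nn.Sequential(nn.Linear(100, 28 * 28), nn.Tanh())

# Transposed-convolution generator: expects latent_shape = (100, 1, 1).
conv_generator = nn.Sequential(
    nn.ConvTranspose2d(100, 64, kernel_size=7),          # (N, 100, 1, 1) -> (N, 64, 7, 7)
    nn.ReLU(),
    nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # -> (N, 32, 14, 14)
    nn.ReLU(),
    nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),   # -> (N, 1, 28, 28)
    nn.Tanh(),
)
```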

Training progress is logged with tensorboardX.

Results

Here is the training history for both datasets:

[Figure: MNIST losses | FashionMNIST losses]

Two GIFs of the training process:

[Animation: MNIST training | FashionMNIST training]

Interpolation in latent space

We can generate samples that transition smoothly from one class to another by interpolating between points in the latent space (done in a notebook in this repository):

[Figure: MNIST interpolation | FashionMNIST interpolation]
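The interpolation step could look roughly like the following minimal sketch. It assumes plain linear interpolation between two latent points, which may not be exactly what the notebook does. Spherical interpolation (slerp) is a common alternative for Gaussian latents. The function interpolate_latents is hypothetical. Each row of its output can be passed through a trained generator to produce one frame of the transition.

```python
import torch


def interpolate_latents(z_start: torch.Tensor, z_end: torch.Tensor, steps: int) -> torch.Tensor:
    """Return `steps` latent points evenly spaced on the line from z_start to z_end.

    z_start and z_end have shape latent_shape (no batch dimension); the result has
    shape (steps, *latent_shape) and can be fed to the generator as one batch.
    """
    # t runs from 0 to 1; reshape it so it broadcasts over every latent dimension.
    t = torch.linspace(0.0, 1.0, steps).view(steps, *([1] * z_start.dim()))
    return (1 - t) * z_start.unsqueeze(0) + t * z_end.unsqueeze(0)


# Usage with a trained generator (not defined here):
# with torch.no_grad():
#     frames = generator(interpolate_latents(z_a, z_b, steps=10))
```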

The trained model weights are in the saved_models folder.
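The following is a sketch of loading such weights. It assumes each file in saved_models holds a state_dict saved with torch.save. If the files contain whole pickled modules instead, torch.load returns the module itself. The helper load_weights and the file name in the demo are hypothetical, because this README does not list the files. The demo round-trips a stand-in nn.Linear through a temporary file.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_weights(model: nn.Module, path: str) -> nn.Module:
    """Load a state_dict saved with torch.save into `model` and switch it to eval mode."""
    # map_location="cpu" lets weights saved on a GPU machine load on a CPU-only one.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.eval()


if __name__ == "__main__":
    # Stand-in model; replace it with the real generator class and a file from saved_models.
    original = nn.Linear(3, 2)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "generator.pt")  # hypothetical file name
        torch.save(original.state_dict(), path)
        restored = load_weights(nn.Linear(3, 2), path)
    print(torch.equal(original.weight, restored.weight))  # True
```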
