
Revolutionizing Image Enhancement: A Deep Dive into SRGAN


Welcome to the cutting edge of image enhancement technology! In recent years, the field of computer vision has seen remarkable advances, and one such breakthrough is the Super-Resolution Generative Adversarial Network (SRGAN). In this blog post, we'll explore how SRGAN works, why it matters for image processing, and how to build and train one yourself in PyTorch.

What is it?

SRGAN, short for Super-Resolution Generative Adversarial Network, is a deep learning model designed for single-image super-resolution: it generates a high-resolution image from a low-resolution input. The task is hard because it is ill-posed. Many different high-resolution images downsample to the same low-resolution one, so the model has to infer plausible detail that is not present in the input.

What's so special about it?

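Traditional upscaling methods such as bicubic interpolation, as well as earlier CNN-based approaches trained only with a pixel-wise loss (e.g., Dong et al., 2014), tend to produce overly smooth images. SRGAN goes further by training with an adversarial loss and a perceptual loss: rather than just increasing the pixel count, it adds realistic textures, resulting in more natural-looking high-resolution images.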

Architecture of SRGAN

This formulation's fundamental idea is to train a generative model G to fool a differentiable discriminator D that is trained to distinguish real images from super-resolved ones.
With this setup, the generator learns to produce results that are very close to authentic photos and therefore hard for D to classify. D and G are trained jointly by solving a min-max problem.
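For reference, the min-max problem can be written in the standard GAN form adapted to super-resolution. Here θ_G and θ_D are the generator and discriminator parameters, I^HR is a real high-resolution image and I^LR a low-resolution input. Note that the training loop later in this post optimizes a simpler linear variant (1 − D(x) + D(G(z)) for the discriminator) rather than this log-likelihood form.

```latex
\min_{\theta_G} \max_{\theta_D} \;
\mathbb{E}_{I^{HR} \sim p_{\text{train}}(I^{HR})}\big[\log D_{\theta_D}(I^{HR})\big]
+ \mathbb{E}_{I^{LR} \sim p_{G}(I^{LR})}\big[\log\big(1 - D_{\theta_D}(G_{\theta_G}(I^{LR}))\big)\big]
```

D maximizes this value by scoring real images near 1 and super-resolved ones near 0; G minimizes it by making D(G(I^LR)) approach 1.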

Create your own Super-Resolution GAN

Create a virtual environment with PyTorch, torchvision, Pillow, NumPy and tqdm installed, and let’s begin coding.

Let’s start by importing the required libraries.

#Import Libraries
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torch
import math
from os import listdir
import numpy as np
from torch.autograd import Variable
from torch import nn, optim
from torchvision.models.vgg import vgg16
from tqdm import tqdm
import os

Next, define the constants used throughout the code.

# Enable anomaly detection in PyTorch autograd for debugging
# (it slows training noticeably, so remove it once the loop runs cleanly)
torch.autograd.set_detect_anomaly(True)
# Constants for image processing
UPSCALE_FACTOR = 4
CROP_SIZE = 88
# ImageNet mean and standard deviation for optional normalization (not used by the training code below)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

Now let’s write the code for the dataset and the data loader.

from torchvision.transforms import InterpolationMode

# Check if the given filename has an image file extension
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])

# Calculate a valid crop size: the largest size <= crop_size divisible by the upscale factor
def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)

# Define a transformation for high-resolution training images
def train_hr_transform(crop_size):
    return Compose([
        RandomCrop(crop_size),  # Random HR patch
        ToTensor(),
    ])

# Define a transformation for low-resolution training images
def train_lr_transform(crop_size, upscale_factor):
    return Compose([
        ToPILImage(),
        Resize(crop_size // upscale_factor, interpolation=InterpolationMode.BICUBIC),  # Bicubic downsampling
        ToTensor(),
    ])

# Define a transformation for displaying images
def display_transform():
    return Compose([
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        ToTensor(),  # Convert the image to a PyTorch tensor
    ])

class TrainDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        # Get the list of image filenames in the dataset directory
        self.image_filenames = [os.path.join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    # Load an image, crop an HR patch and downsample it to the matching LR patch
    def __getitem__(self, index):
        # convert('RGB') guards against grayscale or RGBA files
        hr_image = self.hr_transform(Image.open(self.image_filenames[index]).convert('RGB'))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    # Return the number of images in the dataset
    def __len__(self):
        return len(self.image_filenames)
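Why round the crop size down? If the high-resolution crop is not divisible by the upscale factor, the low-resolution patch size gets truncated and the generator's 4× output no longer matches the HR target pixel for pixel, so the MSE losses would fail on mismatched shapes. The short sketch below repeats `calculate_valid_crop_size` from the dataset code and adds a hypothetical helper, `lr_hr_shapes`, that prints the tensor shapes of one training pair.

```python
def calculate_valid_crop_size(crop_size, upscale_factor):
    # Same rule as in the dataset code: round down to a multiple of upscale_factor
    return crop_size - (crop_size % upscale_factor)


def lr_hr_shapes(crop_size, upscale_factor, channels=3):
    # Hypothetical helper: (C, H, W) of the LR input and HR target for one training pair
    hr = calculate_valid_crop_size(crop_size, upscale_factor)
    lr = hr // upscale_factor
    return (channels, lr, lr), (channels, hr, hr)


print(calculate_valid_crop_size(90, 4))  # 88
print(lr_hr_shapes(88, 4))               # ((3, 22, 22), (3, 88, 88))
```

With the constants above (CROP_SIZE = 88, UPSCALE_FACTOR = 4), each batch from the data loader pairs 22×22 inputs with 88×88 targets.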

Now let’s load the training set.

# Creating an instance of the TrainDatasetFromFolder class
train_set = TrainDatasetFromFolder("Path to your HR Dataset", crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)

# Creating a DataLoader for training set
trainloader = DataLoader(train_set, batch_size=32, num_workers=4, shuffle=True)

Load the Generator Architecture

GENERATOR network G

The generator follows the residual-network (ResNet) design. The SRGAN paper stacks B = 16 identical residual blocks, each using two convolutional layers with small 3×3 kernels and 64 feature maps, followed by batch-normalization layers and ParametricReLU as the activation function. The implementation below uses a lighter stack of five residual blocks.
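To get a feel for what B means in practice, here is a quick parameter count for one residual block as implemented later in this post: two 3×3 convolutions with bias, two batch-norm layers, and one PReLU with a single learnable slope (PyTorch's default). The helper names are hypothetical; the arithmetic follows the standard layer shapes.

```python
def conv2d_params(in_ch, out_ch, kernel, bias=True):
    # Weights: out_ch * in_ch * k * k, plus one bias per output channel
    return out_ch * in_ch * kernel * kernel + (out_ch if bias else 0)


def residual_block_params(channels=64):
    conv = conv2d_params(channels, channels, 3)
    batch_norm = 2 * channels  # learnable scale (weight) and shift (bias); running stats are buffers
    prelu = 1                  # nn.PReLU() defaults to one shared slope
    return 2 * conv + 2 * batch_norm + prelu


per_block = residual_block_params()
print(per_block)        # 74113
print(16 * per_block)   # 1185808 (paper's B = 16)
print(5 * per_block)    # 370565 (the five blocks used here)
```

Using five blocks instead of sixteen cuts the residual trunk to roughly a third of the parameters, which makes training faster at some cost in capacity.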

class Generator(nn.Module):
    def __init__(self, scale_factor):
        super(Generator, self).__init__()

        # Number of 2x upsample blocks needed (scale_factor must be a power of 2)
        upsample_block_num = int(math.log2(scale_factor))

        # Initial convolutional block
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )

        # Residual blocks
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)

        # Convolution + batch norm after the residual blocks
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )

        # Upsample blocks followed by a final 9x9 convolution back to 3 channels
        block8 = [UpsampleBlock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        # Initial block
        block1 = self.block1(x)

        # Residual blocks
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)

        # Long skip connection from block1, then upsampling
        block8 = self.block8(block1 + block7)

        # tanh maps to [-1, 1]; shifting and halving maps the output to [0, 1]
        return (torch.tanh(block8) + 1) / 2

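Two learned sub-pixel convolution layers (Shi et al., 2016), implemented below as UpsampleBlock with nn.PixelShuffle, each double the spatial resolution, so the generator upscales the input image by 4× in total.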
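A sub-pixel convolution first produces C·r² channels at the low resolution, then a pixel shuffle rearranges each group of r² channels into an r×r neighbourhood of one output channel. The pure-Python sketch below mirrors the index mapping of `torch.nn.PixelShuffle` for a single image stored as nested lists; the function name is ours, and the real layer works on batched tensors.

```python
def pixel_shuffle(x, r):
    """Pure-Python pixel shuffle on a nested list shaped [C*r*r][H][W].

    Mirrors torch.nn.PixelShuffle for one image:
    out[c][h*r + i][w*r + j] = x[c*r*r + i*r + j][h][w]
    """
    in_ch, height, width = len(x), len(x[0]), len(x[0][0])
    if in_ch % (r * r):
        raise ValueError("channel count must be divisible by r*r")
    out_ch = in_ch // (r * r)
    out = [[[0] * (width * r) for _ in range(height * r)] for _ in range(out_ch)]
    for c in range(out_ch):
        for i in range(r):
            for j in range(r):
                src = x[c * r * r + i * r + j]  # channel feeding offset (i, j)
                for h in range(height):
                    for w in range(width):
                        out[c][h * r + i][w * r + j] = src[h][w]
    return out


# Four 1x1 channels holding 0, 1, 2, 3 become one 2x2 channel
print(pixel_shuffle([[[0]], [[1]], [[2]], [[3]]], 2))  # [[[0, 1], [2, 3]]]
```

Because the upscaling is a fixed rearrangement, all learnable work happens in the preceding convolution at low resolution, which is what makes the approach cheap.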



Load the Discriminator Architecture

DISCRIMINATOR network D

LeakyReLU activation (α = 0.2) is used throughout, and max-pooling is avoided across the whole network.
The discriminator network is trained to solve the maximization part of the min-max problem.
The network contains eight convolutional layers with 3×3 filter kernels; as in the VGG network, the number of kernels doubles step by step from 64 to 512.
Each time the number of features doubles, a strided convolution reduces the image resolution.
The resulting 512 feature maps are followed by two dense layers and a final sigmoid activation, which outputs the probability that the sample is a real image.
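It helps to trace the feature-map shapes through the eight-layer configuration described above. The sketch below assumes the 88×88 crops used in this post and padding 1 on every 3×3 convolution; the layer list and helper names are ours.

```python
# (out_channels, stride) for the eight 3x3 convolutions described above
VGG_STYLE_LAYERS = [(64, 1), (64, 2), (128, 1), (128, 2),
                    (256, 1), (256, 2), (512, 1), (512, 2)]


def conv_out_size(size, kernel=3, stride=1, padding=1):
    # Standard convolution output-size formula (floor division)
    return (size + 2 * padding - kernel) // stride + 1


def trace_discriminator(size):
    shapes = []
    for channels, stride in VGG_STYLE_LAYERS:
        size = conv_out_size(size, stride=stride)
        shapes.append((channels, size, size))
    return shapes


for shape in trace_discriminator(88):
    print(shape)
# (64, 88, 88), (64, 44, 44), (128, 44, 44), (128, 22, 22),
# (256, 22, 22), (256, 11, 11), (512, 11, 11), (512, 6, 6)
```

The final 512×6×6 maps are what the implementation below collapses with global average pooling before its dense (1×1 convolution) layers, so the discriminator also accepts other input sizes.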

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Sequential network of convolutional, batch-norm and activation layers
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),  # Input layer: 3 channels in, 64 out
            nn.LeakyReLU(0.2),  # Leaky ReLU with a negative slope of 0.2

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),  # Stride 2 halves the resolution
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 128 output channels
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),  # Strided layer of the 128-channel stage
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            # Global pooling + 1x1 convolutions play the role of the two dense layers
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        batch_size = x.size(0)
        # One probability per image in the batch
        return torch.sigmoid(self.net(x).view(batch_size))

Implement the loss functions

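The TV Loss

The total variation (TV) loss sums the squared differences between neighbouring pixels, so minimizing it discourages noisy, high-frequency artifacts in the generated image.
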
class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()

        # Initialize with the provided TV loss weight
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]

        # Calculate the number of elements in height and width dimensions
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])

        # Calculate TV loss along height and width
        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :h_x - 1, :], 2).sum()
        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :w_x - 1], 2).sum()

        # Return the total variation loss
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        # Helper method to calculate the total number of elements in tensor
        return t.size()[1] * t.size()[2] * t.size()[3]

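To see what the formula computes, here is the same calculation in plain Python for a batch of one image stored as nested lists [C][H][W]. The function name is ours; it follows `TVLoss.forward` term by term with batch_size = 1.

```python
def tv_loss(image, weight=1.0):
    """TV loss for a single image given as nested lists [C][H][W] (batch_size = 1)."""
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    # Squared differences between vertically adjacent pixels
    h_tv = sum((image[c][i + 1][j] - image[c][i][j]) ** 2
               for c in range(channels) for i in range(height - 1) for j in range(width))
    # Squared differences between horizontally adjacent pixels
    w_tv = sum((image[c][i][j + 1] - image[c][i][j]) ** 2
               for c in range(channels) for i in range(height) for j in range(width - 1))
    count_h = channels * (height - 1) * width
    count_w = channels * height * (width - 1)
    return weight * 2 * (h_tv / count_h + w_tv / count_w)


print(tv_loss([[[0, 1], [2, 3]]]))  # 10.0
print(tv_loss([[[5, 5], [5, 5]]]))  # 0.0
```

A flat image has zero TV loss, while any pixel-to-pixel change is penalized quadratically; that is why the term gets a very small weight (2e-8) in the generator loss, so it only suppresses noise rather than blurring detail.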
The Generator Loss
from torchvision.models import VGG16_Weights  # torchvision >= 0.13; older versions use vgg16(pretrained=True)

class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss, self).__init__()

        # Load a pretrained VGG16 and keep its convolutional feature extractor (layers 0-30)
        vgg = vgg16(weights=VGG16_Weights.DEFAULT)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()

        # Freeze the loss network: it is only used to compare features
        for param in loss_network.parameters():
            param.requires_grad = False

        # Save the frozen VGG model as the loss network
        self.loss_network = loss_network

        # Define Mean Squared Error (MSE) loss
        self.mse_loss = nn.MSELoss()

        # Create an instance of TVLoss for Total Variation loss
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        # Adversarial loss: mean of 1 - D(G(z)); minimizing it pushes D's score for fake images up
        adversarial_loss = torch.mean(1 - out_labels)

        # Perception loss: MSE between VGG feature maps of generated and target images
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))

        # Image loss: pixel-wise MSE between generated and target images
        image_loss = self.mse_loss(out_images, target_images)

        # Total Variation (TV) loss to encourage spatial smoothness
        tv_loss = self.tv_loss(out_images)

        # Combine all loss components with their weights
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss
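Written out, the generator objective is a weighted sum of four terms. In this notation (ours, not the original author's), G(I^LR) is the super-resolved output, I^HR the target, φ the frozen VGG16 feature extractor `loss_network` (assuming the perception loss is computed on its features), and D the discriminator; each MSE is averaged over elements, as `nn.MSELoss` does by default.

```latex
\mathcal{L}_G =
\underbrace{\operatorname{MSE}\big(G(I^{LR}),\, I^{HR}\big)}_{\text{image loss}}
+ 0.001\,\underbrace{\operatorname{mean}\big(1 - D(G(I^{LR}))\big)}_{\text{adversarial loss}}
+ 0.006\,\underbrace{\operatorname{MSE}\big(\phi(G(I^{LR})),\, \phi(I^{HR})\big)}_{\text{perception loss}}
+ 2\times10^{-8}\,\underbrace{\operatorname{TV}\big(G(I^{LR})\big)}_{\text{TV loss}}
```

The pixel-wise term keeps colours and structure faithful, while the smaller adversarial and perception terms push the output toward realistic textures.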

Now we implement the residual block used by the generator.

# Residual block: conv -> BN -> PReLU -> conv -> BN, plus an identity skip connection
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()

        # First convolutional layer with kernel_size=3 and padding=1 (Kaiming-initialized)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        nn.init.kaiming_normal_(self.conv1.weight, nonlinearity='relu')
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        # Second convolutional layer
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        # Forward pass through the residual branch
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        # Add the input back (skip connection)
        return x + residual

The Upsampling Module

class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBlock, self).__init__()

        # Convolution that multiplies the channel count by up_scale ** 2
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)

        # PixelShuffle rearranges those channels into an up_scale-times larger image
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        # Forward pass through the UpsampleBlock
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x
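Putting the pieces together, the sketch below lists the tensor shapes (C, H, W) through the generator's upsampling stage for the 22×22 low-resolution patches produced by this post's settings (88 // 4). `trace_upsampling` is a hypothetical helper that only does the shape bookkeeping; it assumes 64 feature channels and 2× UpsampleBlocks, as in the Generator above.

```python
import math


def trace_upsampling(lr_size, scale_factor, channels=64, up_scale=2):
    # Shape bookkeeping for the upsample stage; no tensors are created
    n_blocks = int(math.log2(scale_factor))
    if up_scale ** n_blocks != scale_factor:
        raise ValueError("scale_factor must be a power of 2")
    size = lr_size
    shapes = []
    for _ in range(n_blocks):
        shapes.append(("conv", channels * up_scale ** 2, size, size))  # 64 -> 256 channels
        size *= up_scale
        shapes.append(("pixel_shuffle", channels, size, size))         # 256 -> 64 channels, 2x size
    shapes.append(("final_conv", 3, size, size))                       # back to RGB
    return shapes


for step in trace_upsampling(22, 4):
    print(step)
```

The output ends at ("final_conv", 3, 88, 88), which matches the HR crop size, so the generated image and the target can be compared directly in the losses.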

Standard device selection

# Check if CUDA (GPU) is available and set device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Implementation of the function and class created above

# Instantiate the Generator and Discriminator models
netG = Generator(UPSCALE_FACTOR)
netD = Discriminator()

# Instantiate the GeneratorLoss criterion for training the Generator
generator_criterion = GeneratorLoss()

# Move models and loss criterion to the selected device (CPU or GPU)
generator_criterion = generator_criterion.to(device)
netG = netG.to(device)
netD = netD.to(device)

Set the optimizer and learning rate

# Set up Adam optimizers for Generator and Discriminator
optimizerG = optim.Adam(netG.parameters(), lr=0.001)
optimizerD = optim.Adam(netD.parameters(), lr=0.001)

Generate results

# Dictionary for tracking losses and discriminator scores
results = {
    "d_loss": [],
    "g_loss": [],
    "d_score": [],
    "g_score": []
}
print(results)

Set the number of epochs according to the size of your dataset; I generally use 150.

N_EPOCHS = 150
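A quick way to relate the epoch count to dataset size is to count optimizer steps. The sketch below assumes the batch size of 32 used above and the DataLoader default `drop_last=False`, so the last partial batch still counts as a step; the dataset sizes are hypothetical examples.

```python
import math


def training_steps(num_images, batch_size=32, epochs=150):
    # drop_last=False (the DataLoader default) keeps the final partial batch, hence ceil
    steps_per_epoch = math.ceil(num_images / batch_size)
    return steps_per_epoch, steps_per_epoch * epochs


print(training_steps(800))   # (25, 3750)
print(training_steps(1000))  # (32, 4800)
```

Each step performs one discriminator update and one generator update, so a small dataset may need more epochs to reach the same number of updates as a large one.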

Start the training

for epoch in range(1, N_EPOCHS + 1):
    train_bar = tqdm(trainloader)
    running_results = {'batch_sizes': 0, 'd_loss': 0,
                       'g_loss': 0, 'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()
    for data, target in train_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size
        real_img = target.to(device)  # High-resolution ground truth
        z = data.to(device)           # Low-resolution input

        ## Update Discriminator: push D(x) up and D(G(z)) down ##
        fake_img = netG(z)
        netD.zero_grad()
        real_out = netD(real_img).mean()
        # detach() stops the discriminator loss from backpropagating into the generator
        fake_out = netD(fake_img.detach()).mean()
        d_loss = 1 - real_out + fake_out
        d_loss.backward()
        optimizerD.step()

        ## Update Generator: image + adversarial + perception + TV loss ##
        netG.zero_grad()
        fake_out = netD(fake_img).mean()  # Re-score with the updated discriminator
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()
        optimizerG.step()

        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out.item() * batch_size

        ## Update the progress bar with running averages
        train_bar.set_description(desc="[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f" % (
            epoch, N_EPOCHS,
            running_results['d_loss'] / running_results['batch_sizes'],
            running_results['g_loss'] / running_results['batch_sizes'],
            running_results['d_score'] / running_results['batch_sizes'],
            running_results['g_score'] / running_results['batch_sizes']
        ))

    # Store the averages for this epoch
    for key in results:
        results[key].append(running_results[key] / running_results['batch_sizes'])
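To read the progress bar, it helps to evaluate the discriminator loss at its extremes and to see why each batch's value is multiplied by its size. The two helpers below are hypothetical stand-ins for the expressions in the loop: `discriminator_loss` is the same 1 − D(x) + D(G(z)), and `weighted_average` reproduces the running averages, which stay correct even when the last batch is smaller.

```python
def discriminator_loss(real_score, fake_score):
    # Same expression as d_loss in the loop: 1 - D(x) + D(G(z)), with scores in [0, 1]
    return 1 - real_score + fake_score


def weighted_average(values, batch_sizes):
    # Matches running_results: sum(value * batch_size) / total number of images
    return sum(v * b for v, b in zip(values, batch_sizes)) / sum(batch_sizes)


print(discriminator_loss(1.0, 0.0))  # 0.0 -> D separates real and fake perfectly
print(discriminator_loss(0.5, 0.5))  # 1.0 -> D cannot tell real from fake
print(weighted_average([0.5, 0.4, 0.1], [32, 32, 16]))  # approximately 0.38
```

Loss_D therefore ranges from 0 to 2; values hovering around 1 with D(x) and D(G(z)) both near 0.5 suggest the two networks are roughly balanced.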

Save the Generator and Discriminator models separately

# Assuming you have trained your models and stored them in 'netG' and 'netD'
# Modify the folder path accordingly
epoch_number = N_EPOCHS
generator_save_path = 'folder path/generator_model_epoch_{}.pth'.format(epoch_number)
discriminator_save_path = 'folder path/discriminator_model_epoch_{}.pth'.format(epoch_number)

# Save the generator and discriminator weights (state dicts)
torch.save(netG.state_dict(), generator_save_path)
torch.save(netD.state_dict(), discriminator_save_path)

Applications and Practical Uses

Super-Resolution Generative Adversarial Networks (SRGANs) have found applications in many domains because they can enhance image resolution while generating realistic details. Here are some notable applications of SRGAN:

Photography

High-Resolution Image Upscaling: Photographers can use SRGAN to upscale low-resolution images while keeping them sharp, allowing for better print quality or display on larger screens.

Video Enhancement

Upgrading Video Quality: SRGAN can be applied to enhance the quality of low-resolution video frames, improving the visual experience of video content. This is particularly useful for restoring old or degraded videos.

Medical Imaging

Enhancing Medical Imagery: In medical imaging, SRGAN can be used to improve the resolution of images obtained through various diagnostic techniques, which can aid analysis and diagnosis. Because a GAN can hallucinate plausible-looking structures, such outputs need careful clinical validation.

Satellite Imagery

Improving Satellite Image Resolution: SRGAN can enhance the resolution of satellite images, providing clearer and more detailed views of landscapes, urban areas, and environmental features.

Art and Graphics Design

Creating High-Resolution Artwork: Artists and graphic designers can use SRGAN to increase the resolution of digital artwork, ensuring that the final output is crisp and detailed.

Surveillance and Security

Enhancing Surveillance Footage: Security systems and surveillance cameras can benefit from SRGAN by improving the quality of captured images, which can help with identifying details in forensic analysis, keeping in mind that generated details are plausible reconstructions rather than recovered evidence.

Virtual Reality (VR) and Augmented Reality (AR)

Enhancing Visual Realism: SRGAN can be employed to enhance the visual quality of images used in VR and AR applications, providing users with more immersive and realistic experiences.

Remote Sensing

High-Resolution Remote Sensing: SRGAN can be applied to enhance the resolution of remote sensing data, such as images captured by drones or other aerial platforms, improving the accuracy of data analysis in fields like agriculture and environmental monitoring.

Facial Recognition

Improving Facial Detail: In facial recognition systems, SRGAN can contribute to better face recognition accuracy by generating high-resolution facial images from lower-resolution inputs.

Consumer Electronics

TV and Display Enhancement: SRGAN can be integrated into consumer electronics, such as smart TVs and displays, to upscale lower-resolution content for a more enjoyable and immersive viewing experience.

Challenges

Computational Complexity

Training SRGAN models demands substantial computational resources and time. The complexity of the model architecture, especially in the context of adversarial training, can be a hindrance for practical implementation.

Memory Requirements

Large-scale deep learning models like SRGAN require substantial memory, making it challenging for deployment on devices with limited resources, such as smartphones or embedded systems.

Training Dataset Quality

The performance of SRGAN depends on the quality and diversity of the training dataset. Incomplete or biased datasets may lead to artifacts or inaccuracies in the generated high-resolution images.

Real-Time Applications

Achieving real-time performance remains a challenge, particularly for applications such as video streaming or live video processing. Optimizing the computational efficiency of SRGAN for these scenarios is an ongoing area of research.

Generalization to Diverse Scenes

SRGAN models trained on specific types of images may struggle to generalize well to diverse scenes or unconventional inputs. Ensuring robust performance across a wide range of content types is an active research challenge.

Development and Future Outlook

Optimizing Training Algorithms

Researchers are working on developing more efficient training algorithms for SRGAN to reduce the computational burden and speed up the training process. This includes exploring novel optimization techniques and model architectures.

Memory-Efficient Models

To address memory constraints, there is ongoing research to design more memory-efficient SRGAN variants that can be deployed on devices with limited resources, expanding the practical applications of the technology.

Transfer Learning and Domain Adaptation

Investigating transfer learning and domain adaptation techniques allows SRGAN models to generalize better to new or unseen types of images. This can enhance the model's applicability in diverse real-world scenarios.

Real-Time Processing

Efforts are being made to develop lightweight SRGAN models capable of real-time image and video processing. This involves optimizing architectures and algorithms for faster inference without compromising on quality.

Addressing Artifacts and Image Quality

Ongoing research focuses on refining SRGAN models to reduce artifacts and improve the perceptual quality of generated images. This includes incorporating human perceptual factors into the training process and exploring advanced loss functions.

User-Friendly Implementations

Developing user-friendly implementations and tools for SRGAN facilitates wider adoption. This involves creating accessible software libraries, frameworks, or applications that make it easier for practitioners to apply SRGAN to their specific use cases.

Conclusion

As we delve into the era of SRGAN and witness its transformative impact on image enhancement, it's clear that we are at the forefront of a technological revolution. The marriage of deep learning and image processing has given birth to a powerful tool that not only increases resolution but elevates the visual experience to new heights. Keep an eye on SRGAN as it continues to evolve, promising a future where the line between reality and digital creation becomes increasingly blurred.

References

  • Bell, S., Zitnick, C. L., Bala, K., Girshick, R. (2016). Inside-outside net: Detecting objects in context with skip pooling and recurrent neural networks. In: IEEE Conference on Computer Vision and Pattern Recognition.
  • Dong, C., Loy, C. C., He, K., Tang, X. (2014). Learning a deep convolutional network for image super-resolution. In: European Conference on Computer Vision.
  • Shi, W., Caballero, J., Huszár, F., Totz, J., Aitken, A. P., Bishop, R., Rueckert, D., Wang, Z. (2016). Real-time single image and video super-resolution using an efficient sub-pixel convolutional neural network. In: IEEE Conference on Computer Vision and Pattern Recognition.