libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

iqa.py (2685B)


      1 #!/usr/bin/env python3
      2 # Copyright (c) the JPEG XL Project Authors. All rights reserved.
      3 #
      4 # Use of this source code is governed by a BSD-style
      5 # license that can be found in the LICENSE file.
      6 
      7 import os
      8 import sys
      9 import pathlib
     10 import torch
     11 from torchvision import transforms
     12 import numpy as np
     13 
     14 path = pathlib.Path(__file__).parent.absolute(
     15 ) / '..' / '..' / '..' / 'third_party' / 'IQA-optimization'
     16 sys.path.append(str(path))
     17 
     18 from IQA_pytorch import SSIM, MS_SSIM, CW_SSIM, GMSD, LPIPSvgg, DISTS, NLPD, FSIM, VSI, VIFs, VIF, MAD
     19 
     20 
     21 # only really works with the output from JXL, but we don't need more than that.
     22 def read_pfm(fname):
     23     with open(fname, 'rb') as f:
     24         header_width_height = []
     25         while len(header_width_height) < 3:
     26             header_width_height += f.readline().rstrip().split()
     27         header, width, height = header_width_height
     28         assert header == b'PF' or header == b'Pf'
     29         width, height = int(width), int(height)
     30         scale = float(f.readline().rstrip())
     31         fmt = '<f' if scale < 0 else '>f'
     32         data = np.fromfile(f, fmt)
     33         if header == b'PF':
     34             out = np.reshape(data, (height, width, 3))[::-1, :, :]
     35         else:
     36             out = np.reshape(data, (height, width))[::-1, :]
     37         return out.astype(np.float)
     38 
     39 
     40 D_dict = {
     41     'cwssim': CW_SSIM,
     42     'dists': DISTS,
     43     'fsim': FSIM,
     44     'gmsd': GMSD,
     45     'lpips': LPIPSvgg,
     46     'mad': MAD,
     47     'msssim': MS_SSIM,
     48     'nlpd': NLPD,
     49     'ssim': SSIM,
     50     'vif': VIF,
     51     'vsi': VSI,
     52 }
     53 
     54 algo = os.path.basename(sys.argv[1]).split('.')[0]
     55 algo, color = algo.split('-')
     56 
     57 channels = 3
     58 
     59 if color == 'y':
     60     channels = 1
     61 
     62 
     63 def Load(path):
     64     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     65     transform = transforms.Compose([
     66         transforms.ToTensor(),
     67     ])
     68     img = read_pfm(path)
     69     if len(img.shape) == 3 and channels == 1:  # rgb -> Y
     70         assert img.shape[2] == 3
     71         tmp = np.zeros((img.shape[0], img.shape[1], 1), dtype=float)
     72         tmp[:, :, 0] = (0.2126 * img[:, :, 0] + 0.7152 * img[:, :, 1] +
     73                         0.0722 * img[:, :, 2])
     74         img = tmp
     75     if len(img.shape) == 2 and channels == 3:  # Y -> rgb
     76         gray = img
     77         img = np.zeros((img.shape[0], img.shape[1], 3), dtype=float)
     78         img[:, :, 0] = img[:, :, 1] = img[:, :, 2] = gray
     79     if len(img.shape) == 3:
     80         img = np.transpose(img, axes=(2, 0, 1)).copy()
     81     return torch.FloatTensor(img).unsqueeze(0).to(device)
     82 
     83 
     84 ref_img = Load(sys.argv[2])
     85 enc_img = Load(sys.argv[3])
     86 D = D_dict[algo](channels=channels)
     87 score = D(ref_img, enc_img, as_loss=False)
     88 
     89 with open(sys.argv[4], 'w') as f:
     90     print(score.item(), file=f)