iqa.py (2685B)
1 #!/usr/bin/env python3 2 # Copyright (c) the JPEG XL Project Authors. All rights reserved. 3 # 4 # Use of this source code is governed by a BSD-style 5 # license that can be found in the LICENSE file. 6 7 import os 8 import sys 9 import pathlib 10 import torch 11 from torchvision import transforms 12 import numpy as np 13 14 path = pathlib.Path(__file__).parent.absolute( 15 ) / '..' / '..' / '..' / 'third_party' / 'IQA-optimization' 16 sys.path.append(str(path)) 17 18 from IQA_pytorch import SSIM, MS_SSIM, CW_SSIM, GMSD, LPIPSvgg, DISTS, NLPD, FSIM, VSI, VIFs, VIF, MAD 19 20 21 # only really works with the output from JXL, but we don't need more than that. 22 def read_pfm(fname): 23 with open(fname, 'rb') as f: 24 header_width_height = [] 25 while len(header_width_height) < 3: 26 header_width_height += f.readline().rstrip().split() 27 header, width, height = header_width_height 28 assert header == b'PF' or header == b'Pf' 29 width, height = int(width), int(height) 30 scale = float(f.readline().rstrip()) 31 fmt = '<f' if scale < 0 else '>f' 32 data = np.fromfile(f, fmt) 33 if header == b'PF': 34 out = np.reshape(data, (height, width, 3))[::-1, :, :] 35 else: 36 out = np.reshape(data, (height, width))[::-1, :] 37 return out.astype(np.float) 38 39 40 D_dict = { 41 'cwssim': CW_SSIM, 42 'dists': DISTS, 43 'fsim': FSIM, 44 'gmsd': GMSD, 45 'lpips': LPIPSvgg, 46 'mad': MAD, 47 'msssim': MS_SSIM, 48 'nlpd': NLPD, 49 'ssim': SSIM, 50 'vif': VIF, 51 'vsi': VSI, 52 } 53 54 algo = os.path.basename(sys.argv[1]).split('.')[0] 55 algo, color = algo.split('-') 56 57 channels = 3 58 59 if color == 'y': 60 channels = 1 61 62 63 def Load(path): 64 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 65 transform = transforms.Compose([ 66 transforms.ToTensor(), 67 ]) 68 img = read_pfm(path) 69 if len(img.shape) == 3 and channels == 1: # rgb -> Y 70 assert img.shape[2] == 3 71 tmp = np.zeros((img.shape[0], img.shape[1], 1), dtype=float) 72 tmp[:, :, 0] = (0.2126 * img[:, :, 0] + 0.7152 * img[:, :, 1] + 73 0.0722 * img[:, :, 2]) 74 img = tmp 75 if len(img.shape) == 2 and channels == 3: # Y -> rgb 76 gray = img 77 img = np.zeros((img.shape[0], img.shape[1], 3), dtype=float) 78 img[:, :, 0] = img[:, :, 1] = img[:, :, 2] = gray 79 if len(img.shape) == 3: 80 img = np.transpose(img, axes=(2, 0, 1)).copy() 81 return torch.FloatTensor(img).unsqueeze(0).to(device) 82 83 84 ref_img = Load(sys.argv[2]) 85 enc_img = Load(sys.argv[3]) 86 D = D_dict[algo](channels=channels) 87 score = D(ref_img, enc_img, as_loss=False) 88 89 with open(sys.argv[4], 'w') as f: 90 print(score.item(), file=f)