1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
|
# Parts of the warp code are taken from ARFlow (https://github.com/lliuz/ARFlow)
import torch
import torch.nn as nn
import inspect
import numpy as np
import cv2
import Imath
import array
import OpenEXR
import matplotlib.pyplot as plt
def mesh_grid(B, H, W):
    """Build a batch of integer pixel-coordinate grids.

    Returns a (B, 2, H, W) tensor (int64, from ``torch.arange``) where channel 0
    holds each pixel's column index (x) and channel 1 holds its row index (y).
    """
    cols = torch.arange(0, W).view(1, W).repeat(B, H, 1)  # (B, H, W): x index
    rows = torch.arange(0, H).view(H, 1).repeat(B, 1, W)  # (B, H, W): y index
    return torch.stack((cols, rows), dim=1)  # (B, 2, H, W)
def norm_grid(v_grid):
    """Rescale pixel coordinates to the [-1, 1] range used by ``grid_sample``.

    ``v_grid`` is Bx2xHxW with x in channel 0 and y in channel 1. Coordinate 0
    maps to -1 and coordinate W-1 (resp. H-1) maps to +1, matching
    ``grid_sample(..., align_corners=True)``. Assumes H > 1 and W > 1; a
    size-1 axis makes the divisor zero.

    Returns a BxHxWx2 view (the grid layout ``grid_sample`` expects).
    """
    _, _, H, W = v_grid.size()
    normed = torch.zeros_like(v_grid)
    # pair each coordinate channel with the extent of the axis it indexes
    for channel, extent in ((0, W), (1, H)):
        normed[:, channel, :, :] = 2.0 * v_grid[:, channel, :, :] / (extent - 1) - 1.0
    return normed.permute(0, 2, 3, 1)
"""
:param data: unnormalized coordinates Bx2xHxW
:return: Bx1xHxW
"""
B, _, H, W = data.size()
# x = data[:, 0, :, :].view(B, -1).clamp(0, W - 1) # BxN (N=H*W)
# y = data[:, 1, :, :].view(B, -1).clamp(0, H - 1)
x = data[:, 0, :, :].view(B, -1) # BxN (N=H*W)
y = data[:, 1, :, :].view(B, -1)
# invalid = (x < 0) | (x > W - 1) | (y < 0) | (y > H - 1) # BxN
# invalid = invalid.repeat([1, 4])
x1 = torch.floor(x)
x_floor = x1.clamp(0, W - 1)
y1 = torch.floor(y)
y_floor = y1.clamp(0, H - 1)
x0 = x1 + 1
x_ceil = x0.clamp(0, W - 1)
y0 = y1 + 1
y_ceil = y0.clamp(0, H - 1)
x_ceil_out = x0 != x_ceil
y_ceil_out = y0 != y_ceil
x_floor_out = x1 != x_floor
y_floor_out = y1 != y_floor
invalid = torch.cat([x_ceil_out | y_ceil_out,
x_ceil_out | y_floor_out,
x_floor_out | y_ceil_out,
x_floor_out | y_floor_out], dim=1)
# encode coordinates, since the scatter function can only index along one axis
corresponding_map = torch.zeros(B, H * W).type_as(data)
indices = torch.cat([x_ceil + y_ceil * W,
x_ceil + y_floor * W,
x_floor + y_ceil * W,
x_floor + y_floor * W], 1).long() # BxN (N=4*H*W)
values = torch.cat([(1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_ceil)),
(1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_floor)),
(1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_ceil)),
(1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_floor))],
1)
# values = torch.ones_like(values)
values[invalid] = 0
corresponding_map.scatter_add_(1, indices, values)
# decode coordinates
corresponding_map = corresponding_map.view(B, H, W)
return corresponding_map.unsqueeze(1)
def flow_warp(x, flow12, pad='border', mode='bilinear'):
    """Backward-warp ``x`` with the optical flow ``flow12``.

    Output pixel (row i, column j) is sampled from ``x`` at pixel position
    (j + u, i + v), where (u, v) = ``flow12[:, :, i, j]``.

    :param x: BxCxHxW tensor to sample from
    :param flow12: flow in pixels, channel 0 = x and channel 1 = y; presumably
        Bx2xHxW (it is added to a Bx2xHxW pixel grid)
    :param pad: ``padding_mode`` passed to ``grid_sample``
    :param mode: interpolation ``mode`` passed to ``grid_sample``
    :return: the sampled tensor returned by ``grid_sample``
    """
    batch, _, height, width = x.size()
    pixel_grid = mesh_grid(batch, height, width).type_as(x)  # B2HW
    sample_grid = norm_grid(pixel_grid + flow12)  # BHW2

    sampler = nn.functional.grid_sample
    sample_kwargs = {'mode': mode, 'padding_mode': pad}
    # Older PyTorch releases have no `align_corners` argument; when it exists,
    # request corner alignment to match norm_grid's (size - 1) scaling.
    if 'align_corners' in inspect.getfullargspec(torch.nn.functional.grid_sample).args:
        sample_kwargs['align_corners'] = True
    return sampler(x, sample_grid, **sample_kwargs)
def warp_vis(cur, nxt, flow, pad='border', mode='bilinear'):
    """Display ``cur``, ``nxt`` and ``nxt`` warped by ``flow`` side by side.

    cur, nxt, flow: ndarray, H * W * C. ``flow`` holds per-pixel x/y
    displacements and is passed to ``flow_warp``. Images are cast to uint8
    (truncating) before display.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def to_tensor(arr):
        # HxWxC ndarray -> 1xCxHxW float64 tensor on `device`
        return torch.from_numpy(arr).double()[None].permute(0, 3, 1, 2).to(device)

    def to_image(tensor):
        # 1xCxHxW tensor -> HxWxC uint8 ndarray on the CPU
        return tensor[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    cur_t = to_tensor(cur)
    nxt_t = to_tensor(nxt)
    flow_t = to_tensor(flow)
    warped = to_image(flow_warp(nxt_t, flow_t, pad, mode))

    panels = (('cur', to_image(cur_t)),
              ('nxt', to_image(nxt_t)),
              ('nxt_warp', warped))
    _, axes = plt.subplots(1, 3)
    for ax, (title, image) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
    plt.show()
def exr2flow(exr_path, w, h):
    """Read the R, G, B, A float channels of an EXR file into an (h, w, 4) array.

    Channels R and B are copied unchanged; G and A are negated. This is
    presumably to flip a y-up vertical component into image-row (y-down)
    order — TODO confirm against the renderer that produced the file.

    :param exr_path: path of the EXR file to read
    :param w: image width; must match the file's data window
    :param h: image height; must match the file's data window
    :return: float64 ndarray of shape (h, w, 4)
    """
    exr = OpenEXR.InputFile(exr_path)
    try:
        pixel_type = Imath.PixelType(Imath.PixelType.FLOAT)
        # Each channel is returned as raw 32-bit floats; view them directly
        # instead of round-tripping through Python lists.
        red, green, blue, alpha = [np.frombuffer(exr.channel(name, pixel_type), dtype=np.float32)
                                   for name in ("R", "G", "B", "A")]
    finally:
        # release the file handle even if a channel read fails
        exr.close()

    flow = np.zeros((h, w, 4), np.float64)
    flow[:, :, 0] = red.reshape(h, -1)
    flow[:, :, 1] = -green.reshape(h, -1)
    flow[:, :, 2] = blue.reshape(h, -1)
    flow[:, :, 3] = -alpha.reshape(h, -1)
    return flow
if __name__ == '__main__':
    # Load the three consecutive frames. cv2.imread returns None instead of
    # raising on failure, and yields BGR images while warp_vis displays with
    # matplotlib (RGB), so check the result and convert the channel order.
    frames = []
    for frame_path in ('0001.png', '0002.png', '0003.png'):
        img = cv2.imread(frame_path)
        if img is None:
            raise FileNotFoundError('could not read image: ' + frame_path)
        frames.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    im1, im2, im3 = frames

    h, w = im1.shape[:2]
    flow = exr2flow('Flow0002.exr', w, h)
    warp_vis(im2, im1, flow[:, :, :2])  # R,G
    warp_vis(im2, im3, -flow[:, :, 2:])  # B,A
|