import plotly.graph_objects as go
import numpy as np
# Layout of the flattened position axis: it is decoded as (T, H, W) with a
# square H x W grid, i.e. time_size * grid_size * grid_size positions.
time_size = 8
grid_size = 8
# Cells whose similarity is at or below this value are not rendered.
render_threshold = 0.4

# Load positional embeddings and compute the cosine similarity matrix.
# np.load returns an NpzFile that keeps the archive open, so read the array
# inside a context manager to release the file handle deterministically.
with np.load('pos_embed_video.npz') as archive:
    pos_embed = archive['pos_embed_video'][0]  # Expected shape: (512, 128)

# Fail early with a clear message if the checkpoint does not match the grid
# constants above; otherwise the later reshape to (T, H, W) would fail.
expected_positions = time_size * grid_size * grid_size
if pos_embed.ndim != 2 or pos_embed.shape[0] != expected_positions:
    raise ValueError(
        f"Expected positional embeddings of shape ({expected_positions}, dim), "
        f"got {pos_embed.shape}"
    )

# Row-normalise so that the matrix product below yields cosine similarities.
pos_embed_norm = pos_embed / np.linalg.norm(pos_embed, axis=1, keepdims=True)
cos_sim = pos_embed_norm @ pos_embed_norm.T  # Shape: (512, 512)
def get_cube_data(ref_w, ref_h, ref_t, *, similarity=None, shape=None,
                  threshold=None):
    """Generate marker data for every cell similar enough to a reference cell.

    Args:
        ref_w: Width index of the reference position.
        ref_h: Height index of the reference position.
        ref_t: Time index of the reference position.
        similarity: Optional square similarity matrix over flattened
            positions. Defaults to the module-level ``cos_sim``.
        shape: Optional ``(T, H, W)`` grid shape. Defaults to
            ``(time_size, grid_size, grid_size)``.
        threshold: Optional rendering threshold; cells with similarity at or
            below it are skipped. Defaults to ``render_threshold``.

    Returns:
        A tuple ``(xs, ys, zs, colors, hovers)`` of equal-length lists holding
        the w, h and t coordinates, an ``rgba(...)`` colour string and a hover
        label for each rendered cell, ordered with t outermost and w
        innermost. NaN similarities are never rendered.

    Raises:
        ValueError: If the reference position lies outside the grid. Without
            this check an out-of-range index would silently alias to a
            different cell (or wrap around for negative values).
    """
    sim = cos_sim if similarity is None else similarity
    n_t, n_h, n_w = (time_size, grid_size, grid_size) if shape is None else shape
    thr = render_threshold if threshold is None else threshold

    for label, value, size in (('ref_w', ref_w, n_w),
                               ('ref_h', ref_h, n_h),
                               ('ref_t', ref_t, n_t)):
        if not 0 <= value < size:
            raise ValueError(
                f'{label}={value} is outside the valid range [0, {size})')

    # Index formula: t * (H*W) + h * W + w -> reshape gives (T, H, W) order
    ref_idx = (ref_t * n_h + ref_h) * n_w + ref_w
    volume = np.asarray(sim[ref_idx], dtype=float).reshape((n_t, n_h, n_w))

    # np.nonzero reports indices in C order, which matches nested loops over
    # t, then h, then w.
    t_idx, h_idx, w_idx = np.nonzero(volume > thr)
    values = volume[t_idx, h_idx, w_idx]

    # Map (threshold, 1.0] onto [0, 1]; the clip absorbs floating-point
    # overshoot of similarities slightly above 1.0.
    normalized = np.clip((values - thr) / (1.0 - thr), 0.0, 1.0)
    opacity = (0.3 + 0.7 * normalized).tolist()  # Range: 0.3 to 1.0
    # astype(int) truncates toward zero, like int() on a positive float.
    red = (180 - 130 * normalized).astype(int).tolist()
    green = (60 + 180 * normalized).astype(int).tolist()
    blue = (220 - 20 * normalized).astype(int).tolist()

    xs, ys, zs = w_idx.tolist(), h_idx.tolist(), t_idx.tolist()
    colors = [
        f'rgba({r}, {g}, {b}, {a:.2f})'
        for r, g, b, a in zip(red, green, blue, opacity)
    ]
    hovers = [
        f'(w={w}, h={h}, t={t})<br>Similarity: {value:.3f}'
        for w, h, t, value in zip(xs, ys, zs, values.tolist())
    ]
    return xs, ys, zs, colors, hovers
# Build one animation frame per reference position
# (time_size * grid_size * grid_size = 8 * 8 * 8 = 512 frames).
frames = []
# np.ndindex walks the index space in C order: t outermost, w innermost.
for ref_t, ref_h, ref_w in np.ndindex(time_size, grid_size, grid_size):
    xs, ys, zs, colors, hovers = get_cube_data(ref_w, ref_h, ref_t)

    # Use scatter3d with markers instead of mesh for performance
    cells_trace = go.Scatter3d(
        x=xs, y=ys, z=zs,
        mode='markers',
        marker=dict(
            size=12,
            color=colors,
            symbol='square',
        ),
        text=hovers,
        hovertemplate='%{text}<extra></extra>',
        showlegend=False,
    )

    # Reference marker
    reference_trace = go.Scatter3d(
        x=[ref_w], y=[ref_h], z=[ref_t],
        mode='markers',
        marker=dict(
            size=8, color='gold', symbol='diamond',
            line=dict(color='darkgoldenrod', width=2),
        ),
        hovertemplate=f'Reference (w={ref_w}, h={ref_h}, t={ref_t})<extra></extra>',
        showlegend=False,
    )

    frame_data = [cells_trace, reference_trace]
    # Frame names encode the reference position as "<w>_<h>_<t>".
    frames.append(go.Frame(data=frame_data, name=f'{ref_w}_{ref_h}_{ref_t}'))
# Initial view: the reference position shown before any slider is moved.
init_w, init_h, init_t = 4, 5, 3
xs, ys, zs, colors, hovers = get_cube_data(init_w, init_h, init_t)
fig = go.Figure(
    data=[
        # Trace 0: cells whose similarity to the reference exceeds the
        # rendering threshold.
        go.Scatter3d(
            x=xs, y=ys, z=zs,
            mode='markers',
            marker=dict(size=12, color=colors, symbol='square'),
            text=hovers,
            hovertemplate='%{text}<extra></extra>',
            showlegend=False
        ),
        # Trace 1: reference marker.
        go.Scatter3d(
            x=[init_w], y=[init_h], z=[init_t],
            mode='markers',
            marker=dict(size=8, color='gold', symbol='diamond',
                        line=dict(color='darkgoldenrod', width=2)),
            name='Reference',
            # Same hover text as the animation frames, so the tooltip format
            # does not change after the first slider move.
            hovertemplate=f'Reference (w={init_w}, h={init_h}, t={init_t})<extra></extra>'
        ),
        # Trace 2: colorbar reference. A dummy point with no coordinates that
        # only carries the colorbar; its colorscale endpoints are the colours
        # assigned to cells at the threshold and at similarity 1.0.
        go.Scatter3d(
            x=[None], y=[None], z=[None], mode='markers',
            marker=dict(size=0.1, color=[0],
                        colorscale=[[0, 'rgb(180, 60, 220)'], [1, 'rgb(50, 240, 200)']],
                        cmin=render_threshold, cmax=1.0,
                        colorbar=dict(title='Sim', thickness=12, len=0.5)),
            showlegend=False, hoverinfo='skip'
        )
    ],
    frames=frames
)
def _axis_spec(size):
    """Return integer ticks and a half-cell padded range for one scene axis."""
    return dict(tickvals=list(range(size)), range=[-0.5, size - 0.5])


def _axis_slider(prefix, active, x, frame_names):
    """Return a slider whose i-th step jumps immediately to frame_names[i]."""
    return dict(
        active=active,
        currentvalue={"prefix": prefix, "font": {"size": 14}},
        pad={"t": 40}, len=0.25, x=x, xanchor="left",
        steps=[
            dict(args=[[frame_name], {"frame": {"duration": 0}, "mode": "immediate"}],
                 label=str(step), method="animate")
            for step, frame_name in enumerate(frame_names)
        ],
    )


scene_layout = dict(
    xaxis_title='W (width)',
    yaxis_title='H (height)',
    zaxis_title='T (time)',
    xaxis=_axis_spec(grid_size),
    yaxis=_axis_spec(grid_size),
    zaxis=_axis_spec(time_size),
    aspectmode='cube',
    camera=dict(eye=dict(x=1.6, y=1.6, z=1.0)),
)

# Create three sliders - one for each dimension.
# NOTE: each slider varies only its own axis and keeps the other two at
# init_w / init_h / init_t, so the sliders do not combine: moving one slider
# after another resets the previously changed axis to its initial value.
axis_sliders = [
    _axis_slider("W: ", init_w, 0.05,
                 [f'{w}_{init_h}_{init_t}' for w in range(grid_size)]),
    _axis_slider("H: ", init_h, 0.38,
                 [f'{init_w}_{h}_{init_t}' for h in range(grid_size)]),
    _axis_slider("T: ", init_t, 0.71,
                 [f'{init_w}_{init_h}_{t}' for t in range(time_size)]),
]

fig.update_layout(
    title=dict(text="Positional Embedding Similarities", x=0.5),
    scene=scene_layout,
    sliders=axis_sliders,
    # Extra bottom margin leaves room for the slider row.
    margin=dict(l=0, r=0, t=50, b=200),
    legend=dict(x=0.85, y=0.95),
)
fig.show()