# visualization.py (12 kB)
"""
VitalDB MCP Server - Visualization Handlers
시각화 기능 구현
"""
import asyncio
import numpy as np
import logging
from utils import load_case_cached, create_plot_image
logger = logging.getLogger("vitaldb-viz")
def _window_indices(start_time, duration, interval, n_samples):
    """Map a time window in seconds to a ``[start_idx, end_idx)`` sample range.

    Indices are rounded to the nearest sample. Plain ``int()`` truncation turns
    e.g. ``0.3 / 0.1 == 2.999...`` into 2 and starts the window one sample early.
    Both indices are clamped to ``[0, n_samples]`` with
    ``start_idx <= end_idx``. A negative ``start_time`` therefore cannot wrap
    around to the end of the array, and the slice length always matches
    ``arange(start_idx, end_idx)``.
    """
    start_idx = int(round(start_time / interval))
    end_idx = int(round((start_time + duration) / interval))
    start_idx = min(max(start_idx, 0), n_samples)
    end_idx = min(max(end_idx, start_idx), n_samples)
    return start_idx, end_idx


async def handle_plot_multiple_cases(arguments):
    """Plot one track for several cases, overlaid on one axes or stacked.

    Args:
        arguments: Tool arguments.
            Required keys:
            - ``case_ids``: a sized sequence of case ids.
            - ``track_name``.
            Optional keys:
            - ``interval``: passed to ``load_case_cached`` and used as
              seconds per sample; default 0.01.
            - ``start_time``: window start in seconds; default 0.
            - ``duration``: window length in seconds; default 60.
            - ``overlay``: default True.

    Returns:
        ``[TextContent, ImageContent]``: a short summary and the PNG image
        (base64) produced by ``create_plot_image``.

    A case that fails to load is logged. In overlay mode it is skipped; in
    stacked mode its panel shows a "Failed to load" message. The whole plot is
    not aborted.
    """
    case_ids = arguments["case_ids"]
    track_name = arguments["track_name"]
    interval = arguments.get("interval", 0.01)
    start_time = arguments.get("start_time", 0)
    duration = arguments.get("duration", 60)
    overlay = arguments.get("overlay", True)
    logger.info(f"Plotting {len(case_ids)} cases for {track_name}")

    def _plot():
        import os
        # Request a headless backend for this worker thread. This only takes
        # effect if pyplot has not been imported yet in the process.
        os.environ.setdefault("MPLBACKEND", "Agg")
        import matplotlib.pyplot as plt

        def _load_segment(case_id):
            # Return (time axis in seconds, values) for the requested window.
            vals = load_case_cached(case_id, [track_name], interval)
            start_idx, end_idx = _window_indices(start_time, duration, interval, vals.shape[0])
            return np.arange(start_idx, end_idx) * interval, vals[start_idx:end_idx, 0]

        if overlay:
            fig, ax = plt.subplots(figsize=(15, 6))
        else:
            # At least one row, so an empty case list cannot crash subplots().
            # squeeze=False always yields a 2-D array, so a single case needs
            # no special case.
            n_rows = max(len(case_ids), 1)
            fig, axes = plt.subplots(n_rows, 1, figsize=(15, 3 * n_rows), squeeze=False)

        try:
            if overlay:
                plotted = 0
                for case_id in case_ids:
                    try:
                        time_axis, segment = _load_segment(case_id)
                        ax.plot(time_axis, segment, linewidth=0.8, alpha=0.7, label=f'Case {case_id}')
                        plotted += 1
                    except Exception as e:
                        logger.warning(f"Failed to plot case {case_id}: {e}")
                ax.set_xlabel('Time (seconds)', fontsize=12)
                ax.set_ylabel(track_name, fontsize=12)
                ax.set_title(f'{track_name} - Multiple Cases Overlay', fontsize=14)
                if plotted:
                    # Without any labelled line, legend() only emits a warning.
                    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
                ax.grid(True, alpha=0.3)
            else:
                column = axes[:, 0]
                for ax, case_id in zip(column, case_ids):
                    try:
                        time_axis, segment = _load_segment(case_id)
                        ax.plot(time_axis, segment, linewidth=0.8)
                        ax.set_ylabel(f'Case {case_id}', fontsize=10)
                        ax.grid(True, alpha=0.3)
                    except Exception as e:
                        logger.warning(f"Failed to plot case {case_id}: {e}")
                        ax.text(0.5, 0.5, f'Failed to load Case {case_id}', ha='center', va='center', transform=ax.transAxes)
                column[-1].set_xlabel('Time (seconds)', fontsize=12)
                fig.suptitle(f'{track_name} - Multiple Cases', fontsize=14)
            plt.tight_layout()
            img_base64 = create_plot_image(fig)
        finally:
            # Drop the figure from pyplot's registry so repeated calls don't
            # accumulate figures. This is harmless if it is already closed.
            plt.close(fig)
        return f"{len(case_ids)}개 케이스의 {track_name} 플롯", img_base64

    from mcp.types import TextContent, ImageContent
    # Matplotlib work is blocking; run it off the event loop.
    text, img_base64 = await asyncio.to_thread(_plot)
    return [TextContent(type="text", text=text), ImageContent(type="image", data=img_base64, mimeType="image/png")]
async def handle_plot_distribution(arguments):
    """Plot the pooled value distribution of one track across several cases.

    Args:
        arguments: Tool arguments.
            Required keys: ``case_ids`` and ``track_name``.
            Optional keys:
            - ``plot_type``: ``"all"``, ``"histogram"`` or ``"boxplot"``;
              any other value draws a violin plot. Default ``"all"``.
            - ``interval``: passed to ``load_case_cached``; default 1.0.

    Returns:
        ``[TextContent, ImageContent]``: a summary with the pooled sample
        count and the PNG image (base64) produced by ``create_plot_image``.

    NaN samples are dropped. Cases that fail to load are logged and skipped.
    If no valid samples remain, each panel shows "No valid data" instead of
    raising; ``violinplot`` raises on an empty array.
    """
    case_ids = arguments["case_ids"]
    track_name = arguments["track_name"]
    plot_type = arguments.get("plot_type", "all")
    interval = arguments.get("interval", 1.0)
    logger.info(f"Plotting distribution for {track_name}")

    def _collect():
        # Pool the non-NaN samples of every loadable case into one float64 array.
        chunks = []
        for case_id in case_ids:
            try:
                vals = load_case_cached(case_id, [track_name], interval)
                chunks.append(vals[~np.isnan(vals[:, 0]), 0])
            except Exception as e:
                logger.warning(f"Failed to load case {case_id}: {e}")
        if not chunks:
            return np.array([], dtype=np.float64)
        return np.concatenate(chunks).astype(np.float64, copy=False)

    def _plot():
        import os
        # Request a headless backend; only effective before pyplot's first import.
        os.environ.setdefault("MPLBACKEND", "Agg")
        import matplotlib.pyplot as plt

        data = _collect()

        def _draw_hist(ax):
            ax.hist(data, bins=50, edgecolor='black', alpha=0.7, color='skyblue')
            ax.set_xlabel(track_name)
            ax.set_ylabel('Frequency')

        def _draw_box(ax):
            # Vertical orientation is the default, so no orientation keyword is passed.
            ax.boxplot(data)
            ax.set_ylabel(track_name)

        def _draw_violin(ax):
            ax.violinplot([data], showmeans=True, showmedians=True)
            ax.set_ylabel(track_name)

        def _draw(ax, drawer, title):
            # Render one panel, or a placeholder when there is nothing to plot.
            if data.size:
                drawer(ax)
            else:
                ax.text(0.5, 0.5, 'No valid data', ha='center', va='center', transform=ax.transAxes)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)

        if plot_type == "all":
            fig, axes = plt.subplots(1, 3, figsize=(18, 5))
            panels = [
                (axes[0], _draw_hist, 'Histogram'),
                (axes[1], _draw_box, 'Box Plot'),
                (axes[2], _draw_violin, 'Violin Plot'),
            ]
        else:
            # Any plot_type other than "all"/"histogram"/"boxplot" draws a violin plot.
            single_panels = {
                "histogram": ((10, 6), _draw_hist, f'{track_name} Histogram'),
                "boxplot": ((8, 6), _draw_box, f'{track_name} Box Plot'),
            }
            figsize, drawer, title = single_panels.get(
                plot_type, ((8, 6), _draw_violin, f'{track_name} Violin Plot')
            )
            fig, ax = plt.subplots(figsize=figsize)
            panels = [(ax, drawer, title)]

        try:
            for ax, drawer, title in panels:
                _draw(ax, drawer, title)
            if plot_type == "all":
                fig.suptitle(f'{track_name} Distribution ({len(case_ids)} cases)', fontsize=14)
            plt.tight_layout()
            img_base64 = create_plot_image(fig)
        finally:
            # Drop the figure from pyplot's registry; harmless if already closed.
            plt.close(fig)
        return f"{track_name}의 분포 (총 {len(data):,}개 샘플)", img_base64

    from mcp.types import TextContent, ImageContent
    # Matplotlib work is blocking; run it off the event loop.
    text, img_base64 = await asyncio.to_thread(_plot)
    return [TextContent(type="text", text=text), ImageContent(type="image", data=img_base64, mimeType="image/png")]
async def handle_plot_scatter_correlation(arguments):
    """Scatter-plot two tracks against each other, pooled across cases.

    Args:
        arguments: Tool arguments.
            Required keys: ``case_ids``, ``track1`` and ``track2``.
            Optional key: ``interval`` (passed to ``load_case_cached``,
            default 1.0).

    Returns:
        ``[TextContent, ImageContent]``: a summary with Pearson r and the
        number of points, and the PNG image (base64) from
        ``create_plot_image``.

    Only samples where both tracks are finite are used; NaN or ±inf would
    poison ``corrcoef`` and ``polyfit``. Points are coloured per case with the
    ``tab10`` colormap, which cycles after 10 cases. Pearson r is NaN when
    fewer than two points remain or either track is constant. The
    least-squares trend line is drawn only when x varies. Cases that fail to
    load are logged and skipped.
    """
    case_ids = arguments["case_ids"]
    track1 = arguments["track1"]
    track2 = arguments["track2"]
    interval = arguments.get("interval", 1.0)
    logger.info(f"Plotting scatter correlation between {track1} and {track2}")

    def _plot():
        import os
        # Request a headless backend; only effective before pyplot's first import.
        os.environ.setdefault("MPLBACKEND", "Agg")
        import matplotlib.pyplot as plt

        # plt.cm.get_cmap is matplotlib.cm.get_cmap, which was removed in
        # Matplotlib 3.9. pyplot.get_cmap is the supported spelling.
        color_map = plt.get_cmap('tab10')

        xs, ys, colors = [], [], []
        for i, case_id in enumerate(case_ids):
            try:
                vals = load_case_cached(case_id, [track1, track2], interval)
                x, y = vals[:, 0], vals[:, 1]
                mask = np.isfinite(x) & np.isfinite(y)
                xs.append(x[mask])
                ys.append(y[mask])
                # One RGBA row per kept point, all in this case's colour.
                colors.append(np.tile(color_map(i % 10), (int(mask.sum()), 1)))
            except Exception as e:
                logger.warning(f"Failed to load case {case_id}: {e}")

        x_all = np.concatenate(xs).astype(np.float64, copy=False) if xs else np.array([])
        y_all = np.concatenate(ys).astype(np.float64, copy=False) if ys else np.array([])
        rgba = np.concatenate(colors) if colors else np.empty((0, 4))
        n_points = x_all.size

        x_varies = n_points > 1 and x_all.max() > x_all.min()
        y_varies = n_points > 1 and y_all.max() > y_all.min()
        # corrcoef is undefined (NaN plus a RuntimeWarning) for a constant series.
        corr = float(np.corrcoef(x_all, y_all)[0, 1]) if x_varies and y_varies else float('nan')

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            if n_points:
                ax.scatter(x_all, y_all, c=rgba, alpha=0.5, s=10)
            if x_varies:
                slope, intercept = np.polyfit(x_all, y_all, 1)
                # The fit is a straight line, so its two endpoints are enough
                # to draw it.
                x_line = np.array([x_all.min(), x_all.max()])
                ax.plot(x_line, slope * x_line + intercept, "r--", linewidth=2, label='Trend line')
                ax.legend()
            ax.set_xlabel(track1, fontsize=12)
            ax.set_ylabel(track2, fontsize=12)
            ax.set_title(f'Correlation: {track1} vs {track2}\nPearson r = {corr:.3f}', fontsize=14)
            ax.grid(True, alpha=0.3)
            plt.tight_layout()
            img_base64 = create_plot_image(fig)
        finally:
            # Drop the figure from pyplot's registry; harmless if already closed.
            plt.close(fig)
        return f"{track1}과 {track2}의 상관관계 (r={corr:.3f}, n={n_points:,})", img_base64

    from mcp.types import TextContent, ImageContent
    # Matplotlib work is blocking; run it off the event loop.
    text, img_base64 = await asyncio.to_thread(_plot)
    return [TextContent(type="text", text=text), ImageContent(type="image", data=img_base64, mimeType="image/png")]
def _heatmap_bin_means(track_data, time_bins):
    """Return the mean of the non-NaN samples in each of ``time_bins`` contiguous bins.

    ``np.array_split`` spreads all samples over the bins. The previous
    ``len // time_bins`` stride silently dropped up to ``time_bins - 1``
    trailing samples. A bin with no non-NaN samples is NaN; this includes the
    trailing empty bins when there are fewer samples than bins.
    """
    means = np.full(time_bins, np.nan)
    for i, segment in enumerate(np.array_split(np.asarray(track_data), time_bins)):
        valid = segment[~np.isnan(segment)]
        if valid.size:
            means[i] = float(np.mean(valid))
    return means


async def handle_plot_heatmap(arguments):
    """Draw a case-by-time heatmap of one track's binned mean values.

    Args:
        arguments: Tool arguments.
            Required keys: ``case_ids`` and ``track_name``.
            Optional keys:
            - ``interval``: passed to ``load_case_cached``; default 1.0.
            - ``time_bins``: number of columns; default 50, at least 1.

    Returns:
        ``[TextContent, ImageContent]``: a summary and the PNG image (base64)
        produced by ``create_plot_image``.

    Each case's own recording is split into ``time_bins`` bins. Columns are
    therefore relative positions within each case, not absolute times. Cases
    that fail to load are logged and left out. Row labels list only the cases
    actually drawn, so they stay aligned with the rows.
    """
    case_ids = arguments["case_ids"]
    track_name = arguments["track_name"]
    interval = arguments.get("interval", 1.0)
    # Zero or negative bin counts would yield an empty image.
    time_bins = max(1, int(arguments.get("time_bins", 50)))
    logger.info(f"Plotting heatmap for {track_name}")

    def _plot():
        import os
        # Request a headless backend; only effective before pyplot's first import.
        os.environ.setdefault("MPLBACKEND", "Agg")
        import matplotlib.pyplot as plt

        rows, labels = [], []
        for case_id in case_ids:
            try:
                vals = load_case_cached(case_id, [track_name], interval)
                row = _heatmap_bin_means(vals[:, 0], time_bins)
            except Exception as e:
                logger.warning(f"Failed to load case {case_id} for heatmap: {e}")
                continue
            rows.append(row)
            labels.append(f'Case {case_id}')

        fig, ax = plt.subplots(figsize=(15, max(6, len(rows) * 0.3)))
        try:
            if rows:
                im = ax.imshow(np.vstack(rows), aspect='auto', cmap='RdYlBu_r', interpolation='nearest')
                ax.set_yticks(range(len(labels)))
                ax.set_yticklabels(labels)
                fig.colorbar(im, ax=ax, label=track_name)
            else:
                # imshow cannot render a 0-row array.
                ax.text(0.5, 0.5, 'No valid data', ha='center', va='center', transform=ax.transAxes)
            ax.set_xlabel('Time Bin', fontsize=12)
            ax.set_ylabel('Case', fontsize=12)
            ax.set_title(f'{track_name} Heatmap Over Time', fontsize=14)
            plt.tight_layout()
            img_base64 = create_plot_image(fig)
        finally:
            # Drop the figure from pyplot's registry; harmless if already closed.
            plt.close(fig)
        return f"{track_name}의 시간대별 히트맵 ({len(case_ids)}개 케이스)", img_base64

    from mcp.types import TextContent, ImageContent
    # Matplotlib work is blocking; run it off the event loop.
    text, img_base64 = await asyncio.to_thread(_plot)
    return [TextContent(type="text", text=text), ImageContent(type="image", data=img_base64, mimeType="image/png")]