"""
VitalDB MCP Server - Analysis Handlers
Implements the advanced analysis handlers.
"""
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.signal import find_peaks
import logging
import vitaldb
from mcp.types import TextContent, ImageContent
from utils import (
    load_case_cached,
    compute_statistics,
    evaluate_condition,
    create_plot_image,
    get_common_tracks
)
logger = logging.getLogger("vitaldb-handlers")


async def handle_filter_cases_by_statistics(arguments):
    """Filter cases by a statistical condition on a single track."""
    track_name = arguments["track_name"]
    statistic = arguments["statistic"]
    condition = arguments["condition"]
    candidate_cases = arguments.get("candidate_cases")
    max_cases = arguments.get("max_cases", 100)
    interval = arguments.get("interval", 1.0)
    logger.info(f"Filtering cases by {track_name} {statistic} {condition}")
    # Determine the candidate cases: use the caller-supplied list, or search
    # the dataset for every case that records the requested track.
    if candidate_cases is None:
        logger.info("Finding all cases with the track...")
        candidate_cases = vitaldb.find_cases([track_name])
    if len(candidate_cases) > max_cases:
        candidate_cases = candidate_cases[:max_cases]
        logger.info(f"Limited to {max_cases} cases")
    filtered_cases = []
    case_stats = {}
    for case_id in candidate_cases:
        try:
            vals = load_case_cached(case_id, [track_name], interval)
            track_data = vals[:, 0]
            stats_dict = compute_statistics(track_data)
            if "error" not in stats_dict:
                stat_value = stats_dict[statistic]
                if evaluate_condition(stat_value, condition):
                    filtered_cases.append(case_id)
                    case_stats[case_id] = stats_dict
        except Exception as e:
            logger.warning(f"Failed to load case {case_id}: {e}")
            continue
    result = {
        "filter_criteria": {
            "track": track_name,
            "statistic": statistic,
            "condition": condition
        },
        "total_checked": len(candidate_cases),
        "matched_cases": len(filtered_cases),
        "case_ids": filtered_cases,
        "statistics": case_stats
    }
    return [TextContent(
        type="text",
        text=json.dumps(result, indent=2, ensure_ascii=False)
    )]
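
# Illustrative arguments (hypothetical values: the track name assumes VitalDB's
# "<device>/<signal>" naming, and the condition string uses whatever syntax
# utils.evaluate_condition accepts in this project):
#   {"track_name": "Solar8000/ART_MBP", "statistic": "mean",
#    "condition": "< 65", "max_cases": 50}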


async def handle_batch_analyze_cases(arguments):
    """Compute summary statistics for several tracks across a batch of cases."""
    case_ids = arguments["case_ids"]
    track_names = arguments["track_names"]
    interval = arguments.get("interval", 1.0)
    logger.info(f"Batch analyzing {len(case_ids)} cases")
    results = {}
    for case_id in case_ids:
        try:
            vals = load_case_cached(case_id, track_names, interval)
            results[case_id] = {}
            for i, track_name in enumerate(track_names):
                track_data = vals[:, i]
                stats_dict = compute_statistics(track_data)
                results[case_id][track_name] = stats_dict
        except Exception as e:
            results[case_id] = {"error": str(e)}
    return [TextContent(
        type="text",
        text=json.dumps(results, indent=2, ensure_ascii=False)
    )]


async def handle_analyze_correlation(arguments):
    """Analyze the correlation between two tracks across a set of cases."""
    case_ids = arguments["case_ids"]
    track1 = arguments["track1"]
    track2 = arguments["track2"]
    interval = arguments.get("interval", 1.0)
    logger.info(f"Analyzing correlation between {track1} and {track2}")
    all_data1 = []
    all_data2 = []
    case_correlations = {}
    for case_id in case_ids:
        try:
            vals = load_case_cached(case_id, [track1, track2], interval)
            data1 = vals[:, 0]
            data2 = vals[:, 1]
            # Keep only time points where both tracks have valid (non-NaN) samples
            mask = ~(np.isnan(data1) | np.isnan(data2))
            valid1 = data1[mask]
            valid2 = data2[mask]
            if len(valid1) > 10:
                corr = np.corrcoef(valid1, valid2)[0, 1]
                case_correlations[case_id] = float(corr)
                all_data1.extend(valid1.tolist())
                all_data2.extend(valid2.tolist())
        except Exception as e:
            logger.warning(f"Failed for case {case_id}: {e}")
    # Pooled correlation over all samples from all cases
    overall_corr = np.corrcoef(all_data1, all_data2)[0, 1] if len(all_data1) > 1 else None
    result = {
        "track1": track1,
        "track2": track2,
        "overall_correlation": float(overall_corr) if overall_corr is not None else None,
        "case_correlations": case_correlations,
        "total_samples": len(all_data1)
    }
    return [TextContent(
        type="text",
        text=json.dumps(result, indent=2, ensure_ascii=False)
    )]
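
# Note: "overall_correlation" pools every valid sample from every case into a
# single Pearson r, so it can differ substantially from the per-case
# correlations (between-case differences in level can dominate the pooled
# estimate). Illustrative arguments, with assumed VitalDB track names:
#   {"case_ids": [1, 2, 3], "track1": "Solar8000/HR",
#    "track2": "Solar8000/PLETH_SPO2", "interval": 1.0}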


async def handle_compare_groups(arguments):
    """Compare a track between two groups of cases with an independent t-test."""
    group1_cases = arguments["group1_cases"]
    group2_cases = arguments["group2_cases"]
    track_name = arguments["track_name"]
    interval = arguments.get("interval", 1.0)
    logger.info(f"Comparing groups for {track_name}")
    group1_data = []
    group2_data = []
    for case_id in group1_cases:
        try:
            vals = load_case_cached(case_id, [track_name], interval)
            valid_data = vals[~np.isnan(vals[:, 0]), 0]
            group1_data.extend(valid_data.tolist())
        except Exception as e:
            logger.warning(f"Failed to load case {case_id} (group 1): {e}")
    for case_id in group2_cases:
        try:
            vals = load_case_cached(case_id, [track_name], interval)
            valid_data = vals[~np.isnan(vals[:, 0]), 0]
            group2_data.extend(valid_data.tolist())
        except Exception as e:
            logger.warning(f"Failed to load case {case_id} (group 2): {e}")
    # Per-group summary statistics
    group1_stats = compute_statistics(np.array(group1_data))
    group2_stats = compute_statistics(np.array(group2_data))
    # Independent-samples t-test
    t_stat, p_value = stats.ttest_ind(group1_data, group2_data)
    result = {
        "track_name": track_name,
        "group1": {
            "n_cases": len(group1_cases),
            "n_samples": len(group1_data),
            "statistics": group1_stats
        },
        "group2": {
            "n_cases": len(group2_cases),
            "n_samples": len(group2_data),
            "statistics": group2_stats
        },
        "statistical_test": {
            "test": "independent t-test",
            "t_statistic": float(t_stat),
            "p_value": float(p_value),
            "significant": bool(p_value < 0.05)
        }
    }
    return [TextContent(
        type="text",
        text=json.dumps(result, indent=2, ensure_ascii=False)
    )]
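
# Note: stats.ttest_ind as called above runs the classic Student's t-test,
# which assumes equal variances in the two groups. If that assumption is
# doubtful, SciPy's Welch variant can be requested instead, e.g.
#   t_stat, p_value = stats.ttest_ind(group1_data, group2_data, equal_var=False)
# (sketch only; whether to switch is a project decision).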


async def handle_detect_anomalies(arguments):
    """Detect anomalous samples in a track using z-score or IQR thresholding."""
    case_id = arguments["case_id"]
    track_name = arguments["track_name"]
    method = arguments.get("method", "zscore")
    threshold = arguments.get("threshold", 3.0)
    interval = arguments.get("interval", 0.01)
    logger.info(f"Detecting anomalies in case {case_id}, {track_name}")
    vals = load_case_cached(case_id, [track_name], interval)
    track_data = vals[:, 0]
    valid_mask = ~np.isnan(track_data)
    valid_data = track_data[valid_mask]
    if method == "zscore":
        # threshold is interpreted as the z-score cutoff
        mean = np.mean(valid_data)
        std = np.std(valid_data)
        z_scores = np.abs((valid_data - mean) / std)
        anomalies_idx = np.where(z_scores > threshold)[0]
    elif method == "iqr":
        # threshold is interpreted as the IQR multiplier (1.5 is the usual Tukey fence)
        q1 = np.percentile(valid_data, 25)
        q3 = np.percentile(valid_data, 75)
        iqr = q3 - q1
        lower_bound = q1 - threshold * iqr
        upper_bound = q3 + threshold * iqr
        anomalies_idx = np.where((valid_data < lower_bound) | (valid_data > upper_bound))[0]
    else:
        return [TextContent(
            type="text",
            text=f"Unsupported method: {method}"
        )]
    # Map positions in the NaN-filtered array back to indices in the original signal
    valid_indices = np.where(valid_mask)[0]
    anomaly_indices = valid_indices[anomalies_idx]
    anomaly_values = track_data[anomaly_indices]
    # Visualization
    fig, ax = plt.subplots(figsize=(15, 6))
    time_axis = np.arange(len(track_data)) * interval
    ax.plot(time_axis, track_data, linewidth=0.5, alpha=0.7, label='Normal')
    ax.scatter(time_axis[anomaly_indices], anomaly_values,
               color='red', s=20, zorder=5, label='Anomalies')
    ax.set_xlabel('Time (seconds)', fontsize=12)
    ax.set_ylabel(track_name, fontsize=12)
    ax.set_title(f'Anomaly Detection - Case {case_id}, {track_name}\n{len(anomaly_indices)} anomalies detected', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    img_base64 = create_plot_image(fig)
    result = {
        "case_id": case_id,
        "track_name": track_name,
        "method": method,
        "threshold": threshold,
        "total_samples": len(track_data),
        "anomalies_detected": len(anomaly_indices),
        "anomaly_ratio": f"{len(anomaly_indices) / len(valid_data) * 100:.2f}%",
        "anomaly_times": time_axis[anomaly_indices][:20].tolist()  # at most the first 20 timestamps
    }
    return [
        TextContent(
            type="text",
            text=json.dumps(result, indent=2, ensure_ascii=False)
        ),
        ImageContent(
            type="image",
            data=img_base64,
            mimeType="image/png"
        )
    ]
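
# Illustrative arguments (hypothetical: "SNUADC/ART" assumes VitalDB's waveform
# track naming, and interval=0.01 matches this handler's default sampling step):
#   {"case_id": 1, "track_name": "SNUADC/ART", "method": "iqr",
#    "threshold": 1.5, "interval": 0.01}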


async def handle_time_window_analysis(arguments):
    """Compute statistics for a single track restricted to a time window."""
    case_id = arguments["case_id"]
    track_name = arguments["track_name"]
    start_time = arguments["start_time"]
    end_time = arguments["end_time"]
    interval = arguments.get("interval", 1.0)
    logger.info(f"Time window analysis: case {case_id}, {start_time}s - {end_time}s")
    vals = load_case_cached(case_id, [track_name], interval)
    # Convert the window boundaries from seconds to sample indices
    start_idx = int(start_time / interval)
    end_idx = min(int(end_time / interval), vals.shape[0])
    window_data = vals[start_idx:end_idx, 0]
    stats_dict = compute_statistics(window_data)
    result = {
        "case_id": case_id,
        "track_name": track_name,
        "time_window": {
            "start": start_time,
            "end": end_time,
            "duration": end_time - start_time
        },
        "statistics": stats_dict
    }
    return [TextContent(
        type="text",
        text=json.dumps(result, indent=2, ensure_ascii=False)
    )]


async def handle_export_to_csv(arguments):
    """Export track data for multiple cases to a CSV file."""
    case_ids = arguments["case_ids"]
    track_names = arguments["track_names"]
    output_path = arguments["output_path"]
    interval = arguments.get("interval", 1.0)
    logger.info(f"Exporting data to {output_path}")
    # Collect one row per time step per case
    all_rows = []
    for case_id in case_ids:
        try:
            vals = load_case_cached(case_id, track_names, interval)
            for i in range(vals.shape[0]):
                row = {"case_id": case_id, "time": i * interval}
                for j, track_name in enumerate(track_names):
                    row[track_name] = vals[i, j]
                all_rows.append(row)
        except Exception as e:
            logger.warning(f"Failed to load case {case_id}: {e}")
            continue
    # Build a DataFrame and write it to disk
    df = pd.DataFrame(all_rows)
    df.to_csv(output_path, index=False)
    result = {
        "output_path": output_path,
        "total_cases": len(case_ids),
        "total_rows": len(all_rows),
        "columns": list(df.columns)
    }
    return [TextContent(
        type="text",
        text=json.dumps(result, indent=2, ensure_ascii=False)
    )]
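

# ---------------------------------------------------------------------------
# Minimal manual smoke test. This is a sketch under assumptions: it is meant to
# be run directly (not through the MCP server), the open VitalDB dataset must
# be reachable, and case 1 / "Solar8000/HR" are only illustrative identifiers.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo():
        contents = await handle_time_window_analysis({
            "case_id": 1,                   # hypothetical case id
            "track_name": "Solar8000/HR",   # hypothetical track name
            "start_time": 0,
            "end_time": 60,
        })
        print(contents[0].text)

    asyncio.run(_demo())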