vllm_provider.py•4.92 kB
#!/usr/bin/env python3
"""
vLLM Provider for LangExtract
vLLM 서버와 통신하는 커스텀 Provider 클래스
"""
import os
import time
import requests
from typing import List, Optional
from urllib3.exceptions import InsecureRequestWarning
import langextract as lx
from langextract.core import base_model as lx_base_model
from langextract.core import types as lx_types
# urllib3의 InsecureRequestWarning 경고를 비활성화
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
def strip_think_block(text: str) -> str:
"""</think>를 기준으로 문자열을 나누고, 그 뒷부분만 반환합니다."""
if "</think>" in text:
parts = text.split("</think>", 1)
return parts[1].lstrip()
return text
class VLLMProvider(lx_base_model.BaseLanguageModel):
"""vLLM 서버와 통신하고, <think> 블록을 처리하는 커스텀 Provider."""
def __init__(self, model_id: str, **kwargs):
super().__init__(model_id=model_id, **kwargs)
self.model_id = model_id
self.base_url = os.getenv("VLLM_BASE_URL", "https://qwen.smartmind.team/v1")
self.url = self.base_url.rstrip("/") + "/chat/completions"
self.api_key = kwargs.get("api_key", os.getenv("VLLM_API_KEY", os.getenv("OPENAI_API_KEY", "dummy")))
self.timeout = kwargs.get("timeout", 600)
self.temperature = kwargs.get("temperature", 0.0)
self.max_tokens = kwargs.get("max_tokens", 4096)
def infer(self, batch_prompts: List[str], **kwargs) -> List[List[lx_types.ScoredOutput]]:
"""vLLM 서버에 요청을 보내고 응답을 처리합니다."""
results = []
for prompt in batch_prompts:
try:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": self.model_id,
"messages": [{"role": "user", "content": prompt}],
"temperature": self.temperature,
"max_tokens": self.max_tokens
}
# vLLM 서버에 요청
response = requests.post(
self.url,
json=payload,
headers=headers,
timeout=self.timeout,
verify=False
)
response.raise_for_status()
# 응답 처리
raw_content = response.json()["choices"][0]["message"]["content"]
cleaned_content = strip_think_block(raw_content)
# ScoredOutput 객체 생성
scored_outputs = [lx_types.ScoredOutput(score=1.0, output=cleaned_content)]
results.append(scored_outputs)
except Exception as e:
# 오류 발생 시 빈 결과 반환
print(f"vLLM 요청 실패: {e}")
scored_outputs = [lx_types.ScoredOutput(score=0.0, output="")]
results.append(scored_outputs)
return results
class ModelProviderFactory:
"""모델 Provider를 생성하는 팩토리 클래스"""
@staticmethod
def create_provider(provider_type: str, model_id: str, **kwargs):
"""
Provider 타입에 따라 적절한 Provider를 생성합니다.
Args:
provider_type: "openai" 또는 "vllm"
model_id: 모델 ID
**kwargs: 추가 설정
Returns:
Provider 인스턴스
"""
if provider_type.lower() == "vllm":
return VLLMProvider(model_id=model_id, **kwargs)
elif provider_type.lower() == "openai":
# OpenAI는 langextract 내장 Provider 사용
return None # None이면 langextract의 기본 OpenAI Provider 사용
else:
raise ValueError(f"지원하지 않는 Provider 타입: {provider_type}")
def get_provider_config():
"""환경변수에서 Provider 설정을 읽어옵니다."""
provider_type = os.getenv("PII_PROVIDER", "openai").lower()
# Provider별 기본 모델 ID 설정
if provider_type == "openai":
default_model = "gpt-4o"
else: # vllm
default_model = "qwen3-235b-awq"
model_id = os.getenv("PII_MODEL_ID", default_model)
# Provider별 API 키 설정
if provider_type == "openai":
api_key = os.getenv("OPENAI_API_KEY")
else: # vllm
api_key = os.getenv("VLLM_API_KEY", "dummy")
return {
"provider_type": provider_type,
"model_id": model_id,
"api_key": api_key,
"base_url": os.getenv("VLLM_BASE_URL", "https://qwen.smartmind.team/v1") if provider_type == "vllm" else None
}