import { URL } from "url"
/** Hostnames a registry URL may point at; matched against `URL.hostname`. */
const ALLOWED_REGISTRY_HOSTS = [
  "ui.shadcn.com",
  "localhost",
  "host.docker.internal",
  "127.0.0.1",
] as const

/** Schemes accepted for registry URLs, in `URL.protocol` form (trailing colon included). */
const ALLOWED_PROTOCOLS = ["http:", "https:"] as const

/** External registry hosts that may only be reached over HTTPS. */
const HTTPS_REQUIRED_HOSTS = ["ui.shadcn.com"] as const

/** Upper bound on URL length in characters, to cap work spent on hostile input. */
const MAX_URL_LENGTH = 2 * 1024

/** Timeout applied to outbound HTTP requests, in milliseconds (30 seconds). */
export const HTTP_TIMEOUT = 30 * 1000
/**
 * Validates a registry URL against the module's security policy.
 *
 * Checks run in order and the first failure throws:
 * 1. non-empty string,
 * 2. length within MAX_URL_LENGTH,
 * 3. no NUL/CR/LF,
 * 4. parseable by `URL`,
 * 5. protocol in ALLOWED_PROTOCOLS,
 * 6. hostname in ALLOWED_REGISTRY_HOSTS,
 * 7. HTTPS for hosts in HTTPS_REQUIRED_HOSTS,
 * 8. an explicit port (if any) in a small allowlist,
 * 9. path outside a set of forbidden prefixes,
 * 10. query parameters bounded in length and free of markup characters.
 *
 * @param url - Candidate registry URL.
 * @returns The URL as re-serialized by the `URL` parser.
 * @throws Error describing the first policy violation found.
 */
export function validateRegistryUrl(url: string): string {
  if (!url || typeof url !== "string") {
    throw new Error("Registry URL must be a non-empty string")
  }

  // Check URL length before doing any parsing work
  if (url.length > MAX_URL_LENGTH) {
    throw new Error(
      `Registry URL too long: ${url.length} characters (max: ${MAX_URL_LENGTH})`
    )
  }

  // Reject NUL and line breaks outright
  if (url.includes("\0") || url.includes("\n") || url.includes("\r")) {
    throw new Error("Registry URL contains invalid characters")
  }

  let parsedUrl: URL
  try {
    parsedUrl = new URL(url)
  } catch {
    throw new Error(`Invalid registry URL format: ${url}`)
  }

  // Validate protocol
  if (!isInAllowlist(ALLOWED_PROTOCOLS, parsedUrl.protocol)) {
    throw new Error(
      `Protocol not allowed: ${parsedUrl.protocol}. Allowed: ${ALLOWED_PROTOCOLS.join(", ")}`
    )
  }

  // Validate hostname
  if (!isInAllowlist(ALLOWED_REGISTRY_HOSTS, parsedUrl.hostname)) {
    throw new Error(
      `Registry host not allowed: ${parsedUrl.hostname}. Allowed: ${ALLOWED_REGISTRY_HOSTS.join(", ")}`
    )
  }

  // Enforce HTTPS for external hosts
  if (
    isInAllowlist(HTTPS_REQUIRED_HOSTS, parsedUrl.hostname) &&
    parsedUrl.protocol !== "https:"
  ) {
    throw new Error(
      `HTTPS required for external registry: ${parsedUrl.hostname}`
    )
  }

  // `URL.port` is "" when no port is given or it equals the scheme's default,
  // so only explicit non-default ports are checked here.
  if (parsedUrl.port) {
    const port = parseInt(parsedUrl.port, 10)
    if (isNaN(port) || port < 1 || port > 65535) {
      throw new Error(`Invalid port number: ${parsedUrl.port}`)
    }
    // Restrict to common web ports for security
    const allowedPorts = [80, 443, 3000, 3333, 8080, 8443]
    if (!allowedPorts.includes(port)) {
      throw new Error(
        `Port not allowed: ${port}. Allowed ports: ${allowedPorts.join(", ")}`
      )
    }
  }

  // Validate path - prevent access to sensitive endpoints. The prefix match
  // runs on a canonical form of the path so that variants such as "/ADMIN",
  // "//admin" or "/%2eenv", which many servers would normalize to a forbidden
  // resource, cannot slip past a literal comparison.
  const forbiddenPaths = ["/admin", "/.env", "/config", "/system", "/private"]
  const canonicalPath = canonicalizePathForPolicyCheck(parsedUrl.pathname)
  for (const forbidden of forbiddenPaths) {
    if (canonicalPath.startsWith(forbidden)) {
      throw new Error(`Access to path not allowed: ${parsedUrl.pathname}`)
    }
  }

  // Bound and sanity-check query parameters (values are already percent-decoded)
  if (parsedUrl.search) {
    const markupChars = /[<>"'&]/
    const queryParams = new URLSearchParams(parsedUrl.search)
    for (const [key, value] of Array.from(queryParams.entries())) {
      if (key.length > 100 || value.length > 1000) {
        throw new Error("Query parameter too long")
      }
      if (markupChars.test(key) || markupChars.test(value)) {
        throw new Error("Invalid characters in query parameters")
      }
    }
  }

  return parsedUrl.toString()
}

/** Typed membership test for the readonly `as const` allowlists (avoids `as any`). */
function isInAllowlist(allowlist: readonly string[], value: string): boolean {
  return allowlist.includes(value)
}

/**
 * Produces the canonical pathname used only for the forbidden-path check.
 * The pathname is percent-decoded, backslashes are treated as separators,
 * empty and "." segments are dropped, ".." segments are resolved, and the
 * result is lowercased. The result always starts with "/".
 *
 * @throws Error if the pathname contains malformed percent-encoding.
 */
function canonicalizePathForPolicyCheck(pathname: string): string {
  let decoded: string
  try {
    decoded = decodeURIComponent(pathname)
  } catch {
    throw new Error(
      `Registry URL path contains malformed percent-encoding: ${pathname}`
    )
  }

  const segments: string[] = []
  for (const segment of decoded.replace(/\\/g, "/").split("/")) {
    if (segment === "" || segment === ".") {
      continue
    }
    if (segment === "..") {
      segments.pop()
      continue
    }
    segments.push(segment)
  }
  return `/${segments.join("/")}`.toLowerCase()
}
/**
 * Classifies and validates a component reference.
 *
 * - Anything `isUrl` accepts is treated as a remote URL and must pass
 *   `validateRegistryUrl`; the normalized URL is returned.
 * - Otherwise, a string ending in ".json" is treated as a local file path. It
 *   must be relative: no "..", and no leading "/" or "\". A leading "\" covers
 *   Windows root-relative paths and UNC shares such as "\\server\share".
 * - Anything else is treated as a registry component name and must pass
 *   `validateRegistryComponentName`.
 *
 * @param url - Component reference supplied by the caller.
 * @returns The detected reference type and the validated value.
 * @throws Error if the reference is empty, too long, contains NUL/CR/LF, or
 *   fails the validation for its detected type.
 */
export function validateComponentUrl(url: string): {
  type: "local" | "url" | "registry"
  value: string
} {
  if (!url || typeof url !== "string") {
    throw new Error("Component URL must be a non-empty string")
  }

  // Prevent excessively long inputs
  if (url.length > MAX_URL_LENGTH) {
    throw new Error(`Component URL too long: ${url.length} characters`)
  }

  // Check for null bytes and line breaks
  if (url.includes("\0") || url.includes("\n") || url.includes("\r")) {
    throw new Error("Component URL contains invalid characters")
  }

  if (isUrl(url)) {
    // Remote URL - apply the full registry URL policy
    return { type: "url", value: validateRegistryUrl(url) }
  }

  if (isLocalFile(url)) {
    // Only relative paths without traversal are allowed. Backslash-rooted
    // paths (Windows "\dir\x.json", UNC "\\host\share\x.json") do not parse
    // as URLs, so they must be rejected here alongside POSIX "/x.json".
    if (url.includes("..") || url.startsWith("/") || url.startsWith("\\")) {
      throw new Error("Local file path contains dangerous patterns")
    }
    // Defensive: isLocalFile already requires the suffix
    if (!url.endsWith(".json")) {
      throw new Error("Local files must be JSON files")
    }
    return { type: "local", value: url }
  }

  // Registry component name - validate name format
  validateRegistryComponentName(url)
  return { type: "registry", value: url }
}
/**
 * Asserts that a registry component name is well-formed.
 *
 * Only ASCII letters, digits, "-", "_", "/", "." and "@" are accepted. Names
 * that contain ".." or "//", start with ".", or end with "/" are rejected.
 *
 * @throws Error when the name fails either check.
 */
function validateRegistryComponentName(name: string): void {
  const allowedCharacters = /^[a-zA-Z0-9\-_/.@]+$/
  if (!allowedCharacters.test(name)) {
    throw new Error(`Invalid registry component name: ${name}`)
  }

  const hasDangerousPattern = [
    /\.\./, // parent-directory traversal
    /\/\//, // empty path segment
    /^\./, // leading dot
    /\/$/, // trailing separator
  ].some((pattern) => pattern.test(name))

  if (hasDangerousPattern) {
    throw new Error(`Dangerous pattern in component name: ${name}`)
  }
}
/**
 * Reports whether `str` parses as an absolute URL with the `URL` constructor.
 * Any scheme is accepted (for example "mailto:"), not only http(s).
 */
function isUrl(str: string): boolean {
  let parsable = true
  try {
    void new URL(str)
  } catch {
    parsable = false
  }
  return parsable
}
/**
 * Reports whether `str` should be treated as a local JSON file path: it must
 * end in ".json" and must not be parseable as a URL.
 */
function isLocalFile(str: string): boolean {
  if (!str.endsWith(".json")) {
    return false
  }
  return !isUrl(str)
}
/**
 * Builds node-fetch compatible request options with default headers, redirect
 * handling and a request timeout.
 *
 * Caller options are spread first. `headers`, `redirect` and `timeout` are
 * then set here, so callers can add or override headers but cannot override
 * `redirect` or `timeout`.
 *
 * @param additionalOptions - Extra fetch options (may be null/undefined).
 *   `headers` may be a plain object, an array of `[name, value]` pairs, or any
 *   iterable of pairs such as a `Headers` or `Map` instance. A caller header
 *   replaces a default header whose name matches case-insensitively.
 * @returns A new options object; the input is not mutated.
 */
export function createSecureFetchOptions(additionalOptions: any = {}): any {
  const options = additionalOptions ?? {}

  const headers: Record<string, unknown> = {
    "User-Agent": "shadcn-mcp-server/1.0.0",
    Accept: "application/json",
    "Accept-Language": "en-US,en;q=0.9",
    // NOTE(review): not a security control by itself; kept in case server
    // logic inspects it.
    "X-Requested-With": "XMLHttpRequest",
  }

  // Header names are case-insensitive, so drop any default that the caller
  // re-specifies under a different casing before applying the caller's value.
  for (const [name, value] of headerEntries(options.headers)) {
    const lowered = name.toLowerCase()
    for (const existing of Object.keys(headers)) {
      if (existing.toLowerCase() === lowered) {
        delete headers[existing]
      }
    }
    headers[name] = value
  }

  return {
    ...options,
    headers,
    // NOTE(review): redirects are followed without re-running
    // validateRegistryUrl on the target, so an allowed host could redirect to
    // a non-allowlisted one. Consider "manual" with per-hop validation.
    redirect: "follow",
    timeout: HTTP_TIMEOUT, // node-fetch v2 `timeout` option
    // Note: some options like 'mode', 'credentials', 'cache', 'referrerPolicy' are not supported in node-fetch
  }
}

/**
 * Normalizes a fetch `headers` value into `[name, value]` pairs. Plain objects
 * and iterables of pairs (arrays, `Headers`, `Map`) are supported; spreading
 * an iterable object would otherwise silently yield no entries.
 */
function headerEntries(headers: unknown): Array<[string, unknown]> {
  if (headers == null) {
    return []
  }
  if (typeof headers === "object" && Symbol.iterator in headers) {
    return Array.from(
      headers as Iterable<[unknown, unknown]>,
      ([name, value]): [string, unknown] => [String(name), value]
    )
  }
  return Object.entries(headers as object)
}
/**
 * Validates an HTTP response (node-fetch compatible) before its body is read.
 *
 * @param response - Fetch-style response exposing `ok`, `status`,
 *   `statusText` and `headers.get()`.
 * @param expectedContentType - Substring the `content-type` header must
 *   contain, compared case-insensitively because media types are
 *   case-insensitive. A missing `content-type` header is accepted.
 * @param maxContentLength - Largest accepted declared `content-length` in
 *   bytes (default 50 MiB). A missing header is accepted.
 * @throws Error on a non-OK status, an unexpected content type, or an
 *   oversized declared body.
 */
export function validateHttpResponse(
  response: any,
  expectedContentType: string = "application/json",
  maxContentLength: number = 50 * 1024 * 1024
): void {
  if (!response.ok) {
    throw new Error(
      `HTTP error: ${response.status} ${response.statusText || "Unknown error"}`
    )
  }

  const contentType = response.headers.get("content-type")
  if (
    contentType &&
    !contentType.toLowerCase().includes(expectedContentType.toLowerCase())
  ) {
    throw new Error(`Unexpected content type: ${contentType}`)
  }

  // Reject oversized bodies up front based on the declared length (DoS guard)
  const contentLength = response.headers.get("content-length")
  if (contentLength) {
    const length = parseInt(contentLength, 10)
    if (length > maxContentLength) {
      throw new Error(
        `Response too large: ${length} bytes (max: ${maxContentLength})`
      )
    }
  }
}
/**
 * Reads a response body and parses it as JSON, enforcing a size limit
 * (node-fetch compatible).
 *
 * If the response declares a `content-length` larger than `maxSize`, it is
 * rejected before the body is read. After reading, the body's UTF-8 byte
 * length (not its UTF-16 code-unit count) is checked against `maxSize`.
 *
 * @param response - Fetch-style response exposing `text()` and, optionally,
 *   `headers.get()`.
 * @param maxSize - Maximum accepted body size in bytes (default 10 MiB).
 * @returns The parsed JSON value.
 * @throws Error if the body is too large, contains NUL characters, or is not
 *   valid JSON.
 */
export async function safeJsonParse(
  response: any,
  maxSize: number = 10 * 1024 * 1024
): Promise<any> {
  // Fail fast on a declared oversize body instead of buffering it first
  const declaredLength = response?.headers?.get?.("content-length")
  if (declaredLength) {
    const length = parseInt(declaredLength, 10)
    if (length > maxSize) {
      throw new Error(
        `JSON response too large: ${length} bytes (max: ${maxSize})`
      )
    }
  }

  const text: string = await response.text()

  // The declared length may be absent or wrong, so measure the actual body
  const byteLength = utf8ByteLength(text)
  if (byteLength > maxSize) {
    throw new Error(
      `JSON response too large: ${byteLength} bytes (max: ${maxSize})`
    )
  }

  if (text.includes("\0")) {
    throw new Error("JSON contains null bytes")
  }

  try {
    return JSON.parse(text)
  } catch (error) {
    throw new Error(
      `Invalid JSON response: ${error instanceof Error ? error.message : String(error)}`
    )
  }
}

/**
 * Counts the bytes `text` occupies when UTF-8 encoded, without allocating.
 * Lone surrogates are counted as U+FFFD (3 bytes), matching `TextEncoder`.
 */
function utf8ByteLength(text: string): number {
  let bytes = 0
  for (let i = 0; i < text.length; i++) {
    const code = text.charCodeAt(i)
    if (code < 0x80) {
      bytes += 1
    } else if (code < 0x800) {
      bytes += 2
    } else if (code >= 0xd800 && code <= 0xdbff && i + 1 < text.length) {
      const next = text.charCodeAt(i + 1)
      if (next >= 0xdc00 && next <= 0xdfff) {
        // Valid surrogate pair: one 4-byte code point
        bytes += 4
        i++
      } else {
        bytes += 3
      }
    } else {
      bytes += 3
    }
  }
  return bytes
}