import os
import subprocess
import hashlib
import flyte.io # noqa: F401 - imported to register FileTransformer and DirTransformer with TypeEngine
import flyte.remote
import uuid
from grpc.aio import AioRpcError
from dataclasses import dataclass
from flyte._utils import asyncify
@dataclass
class RunResult:
    """Outcome of running a script in a subprocess, plus a hint for the follow-up action."""

    # Text captured from the subprocess's standard output.
    stdout: str
    # Text captured from the subprocess's standard error.
    stderr: str
    # Exit status reported by the subprocess.
    returncode: int
    # Free-form guidance describing what the caller should do next.
    next_step: str
async def run_task(
    name: str,
    inputs: dict,
    version: str | None = None,
) -> flyte.remote.ActionDetails:
    """Launch a deployed task by name and return the details of the resulting run's action.

    Args:
        name: Name of the deployed task to look up.
        inputs: Keyword arguments forwarded to the task when it is launched.
        version: Task version to use. When omitted, the "latest" version is requested.

    Returns:
        The details of the launched run's action.
    """
    # Only request automatic version resolution when no explicit version was given.
    auto_version = None if version is not None else "latest"
    remote_task = flyte.remote.Task.get(
        name=name,
        version=version,
        auto_version=auto_version,
    )
    # NOTE(review): flyte.run is called synchronously inside a coroutine — confirm whether
    # an awaitable variant should be used here instead.
    launched: flyte.remote.Run = flyte.run(remote_task, **inputs)
    action_details = await launched.action.details()
    return action_details
async def get_task(
    name: str,
    version: str | None = None,
) -> flyte.remote.Task:
    """Look up a deployed task by name and fetch it.

    Args:
        name: Name of the deployed task.
        version: Task version to use. When omitted, the "latest" version is requested.

    Returns:
        Whatever ``fetch()`` yields for the looked-up task.
    """
    # Only request automatic version resolution when no explicit version was given.
    auto_version = None if version is not None else "latest"
    task_handle = flyte.remote.Task.get(
        name=name,
        version=version,
        auto_version=auto_version,
    )
    return task_handle.fetch()
async def get_run_details(name: str) -> flyte.remote.ActionDetails:
    """Return the action details of the run identified by ``name``."""
    existing_run = flyte.remote.Run.get(name=name)
    action_details = await existing_run.action.details()
    return action_details
async def wait_for_run_completion(name: str) -> flyte.remote.ActionDetails:
    """Wait on the run identified by ``name``, then return its action details."""
    pending_run = flyte.remote.Run.get(name=name)
    # Block (asynchronously) until the wait call returns before reading the details.
    await pending_run.wait.aio()
    action_details = await pending_run.action.details()
    return action_details
async def get_run_io(
    name: str,
) -> tuple[flyte.remote.ActionInputs, flyte.remote.ActionOutputs]:
    """Return the ``(inputs, outputs)`` pair of the run identified by ``name``."""
    target_run: flyte.remote.Run = flyte.remote.Run.get(name=name)
    # Inputs are read first, then outputs, matching the order of the returned tuple.
    run_inputs = target_run.inputs()
    run_outputs = target_run.outputs()
    return run_inputs, run_outputs
async def list_tasks() -> list[flyte.remote.Task]:
    """Return the fetched form of every task reported by ``flyte.remote.Task.listall()``.

    Each listed task is re-resolved by name through ``get_task``, one at a time.
    """
    return [await get_task(listed.name) for listed in flyte.remote.Task.listall()]
async def list_runs(task_name: str | None = None, limit: int = 10) -> list[flyte.remote.Run]:
    """List the most recently created runs, optionally restricted to one task.

    Args:
        task_name: When given, passed to the listing call as its ``task_name`` filter.
        limit: Maximum number of runs requested from the listing call. Defaults to 10,
            which was previously a hard-coded value.

    Returns:
        The runs yielded by ``flyte.remote.Run.listall``, requested in descending
        ``created_at`` order.
    """
    return list(
        flyte.remote.Run.listall(
            task_name=task_name,
            limit=limit,
            sort_by=("created_at", "desc"),
        )
    )
async def _get_or_create_api_key_secret(value: str) -> flyte.Secret:
    """Ensure a remote secret holding ``value`` exists and return a reference to it.

    The secret name is derived from the first 8 hex characters of the SHA-256 digest
    of ``value``, so the same value always maps to the same secret name.

    Args:
        value: The API key to store.

    Returns:
        A ``flyte.Secret`` keyed by the derived name and exposed through the
        ``FLYTE_PASSTHROUGH_API_KEY`` environment variable.
    """
    digest_prefix = hashlib.sha256(value.encode()).hexdigest()[:8]
    secret_name = f"union-mcp-key-{digest_prefix}"
    try:
        # Probe for an existing secret; the result itself is not needed.
        await flyte.remote.Secret.get.aio(name=secret_name)
    except AioRpcError:
        # NOTE(review): every RPC failure (not only "not found") triggers a create
        # attempt here — confirm whether the status code should be checked first.
        await flyte.remote.Secret.create.aio(name=secret_name, value=value)
    return flyte.Secret(key=secret_name, as_env_var="FLYTE_PASSTHROUGH_API_KEY")
async def _run_deployed_script_task(task_name: str, script: str, api_key: str, tail: int) -> RunResult:
    """Launch the pre-deployed task ``task_name`` on ``script`` and return its first output.

    Shared implementation of ``build_script_image`` and ``run_script_remote``, which
    differ only in the name of the task they launch.
    """
    api_key_secret = await _get_or_create_api_key_secret(api_key)
    task = flyte.remote.Task.get(
        name=task_name,
        version=os.environ["APP_TASK_VERSION"],
    ).override(secrets=[api_key_secret])
    # NOTE(review): flyte.run and run.outputs are called synchronously inside a
    # coroutine — confirm whether their awaitable variants should be used instead.
    run = flyte.run(task, script=script, tail=tail)
    await run.wait.aio()
    return run.outputs()[0]


async def build_script_image(script: str, api_key: str, tail: int = 100) -> RunResult:
    """Build the container image for a Flyte script using a pre-deployed task.

    This function launches the `union_mcp_tasks.build_image` task on the remote Flyte
    cluster, at the version named by the `APP_TASK_VERSION` environment variable, and
    waits for the run to finish.

    Args:
        script: The Python script content to build.
        api_key: API key stored as a secret and attached to the launched task.
        tail: Forwarded to the task as its `tail` input. Defaults to 100.

    Returns:
        A RunResult containing the stdout, stderr, returncode and next_step.

    Raises:
        KeyError: If the `APP_TASK_VERSION` environment variable is not set.
    """
    return await _run_deployed_script_task("union_mcp_tasks.build_image", script, api_key, tail)


async def run_script_remote(script: str, api_key: str, tail: int = 100) -> RunResult:
    """Run a Flyte script remotely using a pre-deployed task.

    This function launches the `union_mcp_tasks.run_task` task on the remote Flyte
    cluster, at the version named by the `APP_TASK_VERSION` environment variable, and
    waits for the run to finish.

    Args:
        script: The Python script content to run.
        api_key: API key stored as a secret and attached to the launched task.
        tail: Forwarded to the task as its `tail` input. Defaults to 100.

    Returns:
        A RunResult containing the stdout, stderr, returncode and next_step.

    Raises:
        KeyError: If the `APP_TASK_VERSION` environment variable is not set.
    """
    return await _run_deployed_script_task("union_mcp_tasks.run_task", script, api_key, tail)
async def _execute_script_file(
    script: str,
    filename: str,
    extra_args: list[str],
    tail: int,
    next_step: str,
) -> RunResult:
    """Write ``script`` to ``filename``, run it with the venv interpreter and package the result.

    Args:
        script: Python source to execute.
        filename: Path of the temporary script file; it is removed afterwards.
        extra_args: Command-line arguments appended after the script path.
        tail: Number of trailing lines kept from stdout (and from stderr on success).
        next_step: Guidance text stored verbatim in the returned result.

    Returns:
        A RunResult with the trimmed output, the exit code and ``next_step``.
    """
    try:
        # Python reads source files as UTF-8 by default, so write the script in that
        # encoding instead of the platform's locale encoding.
        with open(filename, "w", encoding="utf-8") as f:
            f.write(script)
        proc = await asyncify.run_sync_with_loop(
            subprocess.run,
            ["/opt/venv/bin/python", filename, *extra_args],
            capture_output=True,
            env=os.environ,  # the child inherits this process's full environment
            text=True,
        )
        if proc.returncode != 0:
            # Keep the complete stderr on failure so the whole error message is visible.
            stderr = proc.stderr
        else:
            stderr = "\n".join(proc.stderr.splitlines()[-tail:])
        return RunResult(
            stdout="\n".join(proc.stdout.splitlines()[-tail:]),
            stderr=stderr,
            returncode=proc.returncode,
            next_step=next_step,
        )
    finally:
        # Best-effort cleanup: the file is absent if creating it failed or the script
        # removed it, and that must not mask the real result or exception.
        try:
            os.remove(filename)
        except FileNotFoundError:
            pass


async def build_script_image_(script: str, tail: int = 200) -> RunResult:
    """This is an internal function used by a Flyte task to build the container image for a script.

    The script is written to a uniquely named file in the current directory and run
    with the ``--build`` flag.

    Args:
        script: Python source of the script whose image should be built.
        tail: Number of trailing output lines to keep. Defaults to 200.

    Returns:
        A RunResult describing the build subprocess.
    """
    filename = f"__build_script_{str(uuid.uuid4())[:8]}__.py".replace("-", "_")
    return await _execute_script_file(
        script,
        filename,
        ["--build"],
        tail,
        "if the image build is successful, run the script with the run_script_remote tool. if the image build fails, check the run details for the build run and debug the issue.",
    )


async def run_script_remote_(script: str, tail: int = 200) -> RunResult:
    """This is an internal function used by a Flyte task to run a script on the remote Flyte cluster.

    The script is written to a uniquely named file in the current directory and run
    without extra arguments.

    Args:
        script: Python source of the script to run.
        tail: Number of trailing output lines to keep. Defaults to 200.

    Returns:
        A RunResult describing the script subprocess.
    """
    filename = f"__run_script_{str(uuid.uuid4())[:16]}__.py".replace("-", "_")
    return await _execute_script_file(
        script,
        filename,
        [],
        tail,
        "if the script run is successful, use the get_run_io tool to get the inputs and outputs of the run. if the script run fails, check the run details for the run and debug the issue.",
    )
async def search_flyte_examples(
    pattern: str,
    file_or_dir: str,
    top_n: int = 3,
    before_context_lines: int = 5,
    after_context_lines: int = 5,
) -> str:
    """Grep for a pattern, return top n files with most matches as markdown.

    Args:
        pattern: The grep pattern to search for. It is always passed as a pattern,
            even when it starts with "-".
        file_or_dir: The directory (searched recursively) or single file to search
            in, e.g. "flyte-sdk/examples".
        top_n: The number of top files to return. Defaults to 3.
        before_context_lines: The number of lines to show before each match. Defaults to 5.
        after_context_lines: The number of lines to show after each match. Defaults to 5.

    Returns:
        A markdown-formatted string containing the matching lines with context from
        the top files, or a short message when grep fails or nothing matches.
    """
    # Count matches per file. "-H" forces the "filename:count" form even when the
    # target is a single file; "-e" and "--" stop a pattern or path that starts with
    # "-" from being parsed as a grep option.
    count_proc = await asyncify.run_sync_with_loop(
        subprocess.run,
        ["grep", "-r", "-c", "-H", "-e", pattern, "--", file_or_dir],
        capture_output=True,
        text=True,
    )

    if count_proc.returncode not in (0, 1):  # 1 means no matches found
        return f"Error running grep: {count_proc.stderr}"

    if not count_proc.stdout.strip():
        return f"No matches found for pattern: {pattern}"

    # Each output line is "filename:count"; split on the last colon because the
    # filename itself may contain colons.
    file_counts: list[tuple[str, int]] = []
    for line in count_proc.stdout.strip().split("\n"):
        filepath, separator, count_str = line.rpartition(":")
        if not separator:
            continue
        try:
            count = int(count_str)
        except ValueError:
            continue
        if count > 0:  # only keep files that actually matched
            file_counts.append((filepath, count))

    if not file_counts:
        return f"No matches found for pattern: {pattern}"

    # Most matches first; a negative top_n is treated as 0 rather than slicing from the end.
    file_counts.sort(key=lambda item: item[1], reverse=True)
    top_files = file_counts[: max(top_n, 0)]

    markdown_parts = [f"# Top {len(top_files)} files matching pattern: `{pattern}`\n"]
    for filepath, count in top_files:
        markdown_parts.append(f"## `{filepath}` ({count} matches)\n")

        # Re-run grep on the single file to collect the matches with surrounding context.
        try:
            context_proc = await asyncify.run_sync_with_loop(
                subprocess.run,
                [
                    "grep",
                    "-n",
                    f"-B{before_context_lines}",
                    f"-A{after_context_lines}",
                    "-e",
                    pattern,
                    "--",
                    filepath,
                ],
                capture_output=True,
                text=True,
            )
        except (IOError, OSError) as e:
            markdown_parts.append(f"*Error getting context: {e}*\n")
            continue

        if context_proc.returncode == 0 and context_proc.stdout.strip():
            # Use the file extension as the fence language for syntax highlighting.
            ext = os.path.splitext(filepath)[1].lstrip(".")
            lang = ext if ext else "text"
            markdown_parts.append(f"```{lang}\n{context_proc.stdout.strip()}\n```\n")
        else:
            markdown_parts.append("*No context available for matches*\n")

    return "\n".join(markdown_parts)
def script_format() -> str:
    """Return a markdown-fenced template showing the required layout of a Flyte script.

    The template is a plain (non-f) string, so braces are emitted exactly as written:
    the ``print(f"... {uri}")`` line therefore uses single braces, otherwise a script
    copied from the template would print a literal ``{uri}`` instead of the build URL.
    Angle-bracket tokens such as ``<task-name>`` are placeholders for the reader to fill in.
    """
    return """
```python
# /// script
# dependencies = [
# "flyte>=2.0.0b49", # IMPORTANT: it makes sure the script can be run on the MCP server
# <package-name>
# ...
# ]
# ///

import flyte
# IMPORTANT: only import flyte packages and python standard library packages
# Import 3rd party dependencies inside the task functions or helper functions
# ... other imports ...

# Define the task environment
env = flyte.TaskEnvironment(
    name="<task-env-name>",
    resources=flyte.Resources(cpu=<cpu-count>, memory="<memory-size>", gpu="<gpu-name>:<gpu-count>", disk="<disk-size>"),
    image=flyte.Image.from_uv_script(__file__, name="<image-name>", python_version=(<python-major-version>, <python-minor-version>), pre=True)
)

# Define one or more tasks.
@env.task
async def <task-name>(<task-arguments>) -> <task-return-type>:
    import <package-name>
    <task-body>

# Define helper functions as needed
async def <helper-function-name>(<helper-function-arguments>) -> <helper-function-return-type>:
    import <other-package-name>
    <helper-function-body>

# more tasks
...

@env.task
async def main(<main-arguments>) -> <main-return-type>: # the main task is the entry point for the script
    <main-body>

if __name__ == "__main__":
    import argparse
    import os
    from flyte.remote import auth_metadata

    # IMPORTANT: it makes sure the script can be both built and run on the MCP server
    parser = argparse.ArgumentParser()
    parser.add_argument("--build", action="store_true")
    args = parser.parse_args()

    # IMPORTANT: it makes sure the script can be run on the MCP server
    flyte.init_passthrough(
        project=os.getenv("FLYTE_INTERNAL_EXECUTION_PROJECT"),
        domain=os.getenv("FLYTE_INTERNAL_EXECUTION_DOMAIN"),
    )
    with auth_metadata(("authorization", os.environ["FLYTE_PASSTHROUGH_API_KEY"])):
        if args.build:
            uri = flyte.build(env.image, wait=False)
            print(f"build run url: {uri}")
        else:
            run = flyte.with_runcontext(mode="remote").run(main, <main-arguments>)
            print(run.url)
```
""".strip()
def script_example() -> str:
    """Get a full example of a Flyte script.

    Returns:
        A markdown-fenced Python script. The script inside the fence is syntactically
        valid Python: it is indented, has no unreachable statements, and its
        ``print(f"... {uri}")`` line uses single braces because this template is a
        plain (non-f) string.
    """
    # NOTE(review): in the example, train_model calls pandas-style methods (``drop``,
    # item access) directly on a ``flyte.io.DataFrame`` and writes to ``file.path`` with
    # the builtin ``open`` — confirm both against the flyte.io API before relying on it.
    return """
```python
# /// script
# dependencies = [
# "flyte>=2.0.0b49", # THIS IS IMPORTANT: it makes sure the script can be run on the MCP server
# "scikit-learn==1.6.1",
# "pandas",
# "pyarrow",
# "joblib",
# ]
# ///

import flyte
import flyte.io

env = flyte.TaskEnvironment(
    name="my_example_script",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="example-image", python_version=(3, 13), pre=True)
)

@env.task
async def create_dataset(n_samples: int = 100) -> flyte.io.DataFrame:
    import pandas as pd
    from sklearn.datasets import make_classification

    X, y = make_classification(n_samples=n_samples, n_features=10, n_classes=2)
    df = pd.DataFrame(X)
    df["target"] = y
    fdf = flyte.io.DataFrame.from_df(df)
    return fdf

@env.task
async def train_model(dataset: flyte.io.DataFrame) -> flyte.io.File:
    from sklearn.ensemble import RandomForestClassifier
    import joblib

    model = RandomForestClassifier()
    model.fit(dataset.drop(columns=["target"]), dataset["target"])
    file = flyte.io.File.new_remote()
    with open(file.path, "wb") as f:
        joblib.dump(model, f)
    return file

@env.task
async def main() -> flyte.io.File:
    dataset = await create_dataset()
    model = await train_model(dataset)
    return model

if __name__ == "__main__":
    import argparse
    import os
    from flyte.remote import auth_metadata

    # THIS IS IMPORTANT: it makes sure the script can be both built and run on the MCP server
    parser = argparse.ArgumentParser()
    parser.add_argument("--build", action="store_true")
    args = parser.parse_args()

    # THIS IS IMPORTANT: it makes sure the script can be run on the MCP server
    flyte.init_passthrough(
        project=os.getenv("FLYTE_INTERNAL_EXECUTION_PROJECT"),
        domain=os.getenv("FLYTE_INTERNAL_EXECUTION_DOMAIN"),
    )
    with auth_metadata(("authorization", os.environ["FLYTE_PASSTHROUGH_API_KEY"])):
        if args.build:
            uri = flyte.build(env.image, wait=False)
            print(f"build run url: {uri}")
        else:
            run = flyte.with_runcontext(mode="remote").run(main)
            print(run.url)
```
""".strip()