prompt.rs•5.24 kB
use std::{borrow::Cow, sync::Arc};
use futures::future::BoxFuture;
use crate::{
handler::server::prompt::{DynGetPromptHandler, GetPromptHandler, PromptContext},
model::{GetPromptResult, Prompt},
};
/// A single registered prompt: its metadata plus the handler that produces its content.
///
/// `PromptRouter::get_prompt` looks a route up by name and invokes `get`;
/// `PromptRouter::list_all` returns clones of `attr`.
pub struct PromptRoute<S> {
    // NOTE(review): `DynGetPromptHandler<S>` is presumably already a type alias, so this
    // allow may no longer be needed — confirm with `cargo clippy` before removing it.
    #[allow(clippy::type_complexity)]
    /// Type-erased handler called with a `PromptContext` to build a `GetPromptResult`.
    /// Shared via `Arc`, so cloning a route does not clone the handler itself.
    pub get: Arc<DynGetPromptHandler<S>>,
    /// Prompt metadata; its `name` is the key the route is registered under.
    pub attr: crate::model::Prompt,
}
/// Debug output shows only the prompt metadata; the `get` handler is not printed.
impl<S> std::fmt::Debug for PromptRoute<S> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let attr = &self.attr;
        let mut builder = formatter.debug_struct("PromptRoute");
        builder.field("name", &attr.name);
        builder.field("description", &attr.description);
        builder.field("arguments", &attr.arguments);
        builder.finish()
    }
}
impl<S> Clone for PromptRoute<S> {
fn clone(&self) -> Self {
Self {
get: self.get.clone(),
attr: self.attr.clone(),
}
}
}
impl<S: Send + Sync + 'static> PromptRoute<S> {
pub fn new<H, A: 'static>(attr: impl Into<Prompt>, handler: H) -> Self
where
H: GetPromptHandler<S, A> + Send + Sync + Clone + 'static,
{
Self {
get: Arc::new(move |context: PromptContext<S>| {
let handler = handler.clone();
handler.handle(context)
}),
attr: attr.into(),
}
}
pub fn new_dyn<H>(attr: impl Into<Prompt>, handler: H) -> Self
where
H: for<'a> Fn(
PromptContext<'a, S>,
) -> BoxFuture<'a, Result<GetPromptResult, crate::ErrorData>>
+ Send
+ Sync
+ 'static,
{
Self {
get: Arc::new(handler),
attr: attr.into(),
}
}
pub fn name(&self) -> &str {
&self.attr.name
}
}
/// Conversion into a [`PromptRoute`]; this is what `PromptRouter::with_route` accepts.
///
/// The `A` parameter carries no data. It appears to act as a marker (`A`, `()`,
/// `PromptAttrGenerateFunctionAdapter` in the impls of this module) so the blanket
/// impls do not overlap — TODO confirm.
pub trait IntoPromptRoute<S, A> {
    /// Consumes `self` and returns the route to register.
    fn into_prompt_route(self) -> PromptRoute<S>;
}
/// A `(metadata, handler)` pair converts into a route via [`PromptRoute::new`].
impl<S, H, A, P> IntoPromptRoute<S, A> for (P, H)
where
    S: Send + Sync + 'static,
    A: 'static,
    H: GetPromptHandler<S, A> + Send + Sync + Clone + 'static,
    P: Into<Prompt>,
{
    fn into_prompt_route(self) -> PromptRoute<S> {
        let (attr, handler) = self;
        PromptRoute::new(attr, handler)
    }
}
/// An already-built route converts into itself unchanged.
impl<S> IntoPromptRoute<S, ()> for PromptRoute<S>
where
    S: Send + Sync + 'static,
{
    fn into_prompt_route(self) -> PromptRoute<S> {
        std::convert::identity(self)
    }
}
/// Marker selecting the [`IntoPromptRoute`] impl for zero-argument functions that
/// return a [`PromptRoute`], such as the functions generated by the `#[prompt]` macro.
pub struct PromptAttrGenerateFunctionAdapter;

impl<S, F> IntoPromptRoute<S, PromptAttrGenerateFunctionAdapter> for F
where
    S: Send + Sync + 'static,
    F: Fn() -> PromptRoute<S>,
{
    /// Calls the function once and returns the route it builds.
    fn into_prompt_route(self) -> PromptRoute<S> {
        let build_route = self;
        build_route()
    }
}
/// A registry of [`PromptRoute`]s keyed by prompt name, used to list prompts and to
/// dispatch `get_prompt` requests.
pub struct PromptRouter<S> {
    // Kept from the original; the HashMap type is presumably below clippy's complexity
    // threshold, so this allow may be removable — confirm with `cargo clippy`.
    #[allow(clippy::type_complexity)]
    /// Routes keyed by name; `add_route` uses each route's `attr.name` as the key.
    pub map: std::collections::HashMap<Cow<'static, str>, PromptRoute<S>>,
}

/// Written by hand instead of `#[derive(Debug)]`. The derive would add an `S: Debug`
/// bound even though `PromptRoute<S>` is `Debug` for every `S`. The output is the same
/// as the derived form: `PromptRouter { map: { .. } }`.
impl<S> std::fmt::Debug for PromptRouter<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PromptRouter")
            .field("map", &self.map)
            .finish()
    }
}
/// An empty router. Written by hand because `#[derive(Default)]` would add an
/// `S: Default` bound.
impl<S> Default for PromptRouter<S> {
    fn default() -> Self {
        let map = std::collections::HashMap::default();
        Self { map }
    }
}
impl<S> Clone for PromptRouter<S> {
fn clone(&self) -> Self {
Self {
map: self.map.clone(),
}
}
}
impl<S> IntoIterator for PromptRouter<S> {
type Item = PromptRoute<S>;
type IntoIter = std::collections::hash_map::IntoValues<Cow<'static, str>, PromptRoute<S>>;
fn into_iter(self) -> Self::IntoIter {
self.map.into_values()
}
}
impl<S> PromptRouter<S>
where
    S: Send + Sync + 'static,
{
    /// Creates an empty router.
    pub fn new() -> Self {
        Self {
            map: std::collections::HashMap::new(),
        }
    }

    /// Builder-style registration. Converts `route` via [`IntoPromptRoute`], registers
    /// it (replacing any route with the same name) and returns the router.
    pub fn with_route<R, A: 'static>(mut self, route: R) -> Self
    where
        R: IntoPromptRoute<S, A>,
    {
        self.add_route(route.into_prompt_route());
        self
    }

    /// Registers `item` under its `attr.name`, replacing any route already stored under
    /// that name.
    pub fn add_route(&mut self, item: PromptRoute<S>) {
        self.map.insert(item.attr.name.clone().into(), item);
    }

    /// Moves every route from `other` into `self`. On a name collision the route from
    /// `other` replaces the existing one.
    pub fn merge(&mut self, other: PromptRouter<S>) {
        for item in other.map.into_values() {
            self.add_route(item);
        }
    }

    /// Removes the route registered under `name`; does nothing if there is none.
    pub fn remove_route(&mut self, name: &str) {
        self.map.remove(name);
    }

    /// Returns `true` if a route is registered under `name`.
    pub fn has_route(&self, name: &str) -> bool {
        self.map.contains_key(name)
    }

    /// Dispatches `context` to the route registered under `context.name`.
    ///
    /// # Errors
    ///
    /// If no route matches, this returns an `invalid_params` error. Its data lists the
    /// registered prompt names, sorted, under `"available_prompts"`. Otherwise it
    /// returns whatever the matched handler returns.
    pub async fn get_prompt(
        &self,
        context: PromptContext<'_, S>,
    ) -> Result<GetPromptResult, crate::ErrorData> {
        let route = self.map.get(context.name.as_str()).ok_or_else(|| {
            crate::ErrorData::invalid_params(
                format!("prompt '{}' not found", context.name),
                Some(serde_json::json!({
                    // Borrow names instead of cloning every `Prompt` (description,
                    // arguments, ...) just to report them; sorted for a stable message.
                    "available_prompts": self.sorted_names()
                })),
            )
        })?;
        (route.get)(context).await
    }

    /// Returns a clone of every registered prompt's metadata, sorted by name so that
    /// listings are deterministic across calls (the underlying map is unordered).
    pub fn list_all(&self) -> Vec<crate::model::Prompt> {
        let mut prompts: Vec<crate::model::Prompt> =
            self.map.values().map(|item| item.attr.clone()).collect();
        prompts.sort_by(|a, b| a.name.cmp(&b.name));
        prompts
    }

    /// Registered prompt names in sorted order, borrowed from the routes.
    fn sorted_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.map.values().map(|route| route.name()).collect();
        names.sort_unstable();
        names
    }
}
/// `a + b` yields a router containing the routes of both. On a name collision the
/// route from `b` wins, per `PromptRouter::merge`.
impl<S> std::ops::Add<PromptRouter<S>> for PromptRouter<S>
where
    S: Send + Sync + 'static,
{
    type Output = Self;

    fn add(self, other: PromptRouter<S>) -> Self::Output {
        let mut combined = self;
        combined.merge(other);
        combined
    }
}
impl<S> std::ops::AddAssign<PromptRouter<S>> for PromptRouter<S>
where
S: Send + Sync + 'static,
{
fn add_assign(&mut self, other: PromptRouter<S>) {
self.merge(other);
}
}