openai.go•2.33 kB
package ai
import (
"errors"
"net/http"
"net/url"
"github.com/sashabaranov/go-openai"
"k8s.io/klog/v2"
)
// openAIClientName is the backend name reported by OpenAIClient.GetName.
const openAIClientName = "openai"
// OpenAIClient is the AI backend that talks to an OpenAI-compatible API
// through the go-openai client. Its fields are populated by Configure and
// SetTools.
type OpenAIClient struct {
	// NOTE(review): nopCloser is declared elsewhere in the package; presumably
	// it supplies a no-op Close method — confirm.
	nopCloser
	// client is the go-openai client built by Configure.
	client *openai.Client
	// model, temperature and topP are copied from the configuration by Configure.
	model       string
	temperature float32
	topP        float32
	// tools is the tool list installed through SetTools.
	tools []openai.Tool
	// maxHistory is copied from the configuration by Configure.
	maxHistory int32
	// memory is created with NewMemoryService by Configure.
	memory *memoryService
	// organizationId string
}
// GetName reports the identifier of this backend, which is always
// openAIClientName.
func (c *OpenAIClient) GetName() string {
	return openAIClientName
}
// OpenAIHeaderTransport is an http.RoundTripper that adds the given headers to each request.
type OpenAIHeaderTransport struct {
	// Origin performs the actual round trip. When it is nil,
	// http.DefaultTransport is used instead.
	Origin http.RoundTripper
	// Headers holds the header sets whose values are appended to every
	// outgoing request.
	Headers []http.Header
}

// RoundTrip implements the http.RoundTripper interface.
//
// The headers are added to a clone of req, so req itself is never modified.
// Values are appended with Header.Add; a header already present on the request
// keeps its existing values and gains the configured ones after them.
func (t *OpenAIHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// Mirror http.Client, which treats a nil transport as the default one,
	// rather than panicking on a zero-value OpenAIHeaderTransport.
	origin := t.Origin
	if origin == nil {
		origin = http.DefaultTransport
	}

	// Clone the request to avoid modifying the original request.
	clonedReq := req.Clone(req.Context())
	// Clone preserves a nil Header map, and Header.Add on a nil map panics.
	if clonedReq.Header == nil {
		clonedReq.Header = make(http.Header)
	}

	for _, header := range t.Headers {
		for key, values := range header {
			// A header may legitimately carry several values (RFC 2616).
			for _, value := range values {
				clonedReq.Header.Add(key, value)
			}
		}
	}
	return origin.RoundTrip(clonedReq)
}
// SetTools replaces the client's tool list with tools. The slice is stored as
// given, without copying.
func (c *OpenAIClient) SetTools(tools []openai.Tool) {
	c.tools = tools
}
// Configure builds the underlying OpenAI client from config and stores the
// model and sampling settings on c.
//
// Requests go through config's proxy endpoint when one is set and otherwise
// through the proxy selected by the environment. Every request additionally
// carries config's custom headers. The base URL and organization ID override
// the library defaults only when they are non-empty.
//
// It returns an error when the proxy endpoint cannot be parsed or the client
// cannot be created; c is left unmodified in that case.
func (c *OpenAIClient) Configure(config IAIConfig) error {
	cfg := openai.DefaultConfig(config.GetPassword())
	if baseURL := config.GetBaseURL(); baseURL != "" {
		cfg.BaseURL = baseURL
	}
	if orgID := config.GetOrganizationId(); orgID != "" {
		cfg.OrgID = orgID
	}

	// Start from a copy of the standard library's default transport so its
	// dial and TLS handshake timeouts and idle-connection limits are kept.
	// A zero-value http.Transport has none of them, which lets a stalled
	// connect hang and keeps idle connections open indefinitely.
	transport := &http.Transport{}
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	}

	if proxyEndpoint := config.GetProxyEndpoint(); proxyEndpoint != "" {
		proxyURL, err := url.Parse(proxyEndpoint)
		if err != nil {
			return err
		}
		transport.Proxy = http.ProxyURL(proxyURL)
	} else {
		klog.V(6).Info("openai client using default proxy from environment")
		transport.Proxy = http.ProxyFromEnvironment
	}

	// Wrap the transport so the configured custom headers are added to every
	// request the client sends.
	cfg.HTTPClient = &http.Client{
		Transport: &OpenAIHeaderTransport{
			Origin:  transport,
			Headers: config.GetCustomHeaders(),
		},
	}

	client := openai.NewClientWithConfig(cfg)
	if client == nil {
		return errors.New("error creating OpenAI client")
	}

	// Assign state only after every fallible step has succeeded.
	c.client = client
	c.model = config.GetModel()
	c.temperature = config.GetTemperature()
	c.topP = config.GetTopP()
	c.maxHistory = config.GetMaxHistory()
	c.memory = NewMemoryService()
	return nil
}