diff --git a/main.go b/main.go index 312814e..4d7cb79 100644 --- a/main.go +++ b/main.go @@ -26,6 +26,7 @@ type ShortLinks struct { domainAliases map[string]string writableDomains map[string]bool + responseFormat *oaiResponseFormat } type setResponse struct { @@ -34,6 +35,11 @@ type setResponse struct { URL string `json:"url"` } +type suggestRequest struct { + Shorts []string `json:"shorts,omitempty"` + Title string `json:"title,omitempty"` +} + type suggestResponse struct { Shorts []string `json:"shorts"` Domain string `json:"domain"` @@ -96,6 +102,26 @@ func NewShortLinks(db *sql.DB, domainAliases map[string]string, writableDomains domainAliases: domainAliases, writableDomains: writableDomains, + responseFormat: &oaiResponseFormat{ + Type: "json_schema", + JSONSchema: map[string]any{ + "name": "suggest_response", + "strict": true, + "schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "shorts": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "string", + }, + }, + }, + "required": []string{"shorts"}, + "additionalProperties": false, + }, + }, + }, } sl.mux.HandleFunc("GET /{$}", sl.serveRoot) @@ -157,6 +183,7 @@ func (sl *ShortLinks) serveRootWithShort(w http.ResponseWriter, r *http.Request, "short": short, "host": sl.getDomain(r.Host), "long": r.Form.Get("long"), + "title": r.Form.Get("title"), }) if err != nil { sendError(w, http.StatusInternalServerError, "error executing template: %s", err) @@ -238,33 +265,32 @@ func (sl *ShortLinks) serveSuggest(w http.ResponseWriter, r *http.Request) { return } - if !r.Form.Has("shorts") { - sendError(w, http.StatusBadRequest, "shorts= param required") + if !r.Form.Has("shorts") && !r.Form.Has("title") { + sendError(w, http.StatusBadRequest, "shorts= or title= param required") return } - user := strings.Join(r.Form["shorts"], "\n") + in := suggestRequest{ + Shorts: r.Form["shorts"], + Title: r.Form.Get("title"), + } - comp, err := sl.oai.completeChat( - "You are an assistant helping a user choose useful short names for a URL shortener. The request contains a list recents names chosen by the user, separated by newlines, with the most recent names first. Respond with only a list of possible suggestions for additional short names, separated by newlines. In descending order of preference, suggestions should include: plural/singular variations, 2 and 3 letter abbreivations, conceptual variations, other variations that are likely to be useful. Your bar for suggestions should be relatively high; responding with a shorter list of high quality suggestions is preferred.", - user, + out := &suggestResponse{} + + err = sl.oai.completeChat( + "You are an assistant helping a user choose useful short names for a URL shortener. The request contains JSON object where the optional `shorts` key contains a list of recent names chosen by the user, with the most recent names first, and the optional `title` key contains a title for the URL. Respond with only a JSON object where the `shorts` key contains a list of possible suggestions for additional short names. In descending order of preference, suggestions should include: plural/singular variations, 2 and 3 letter abbreivations, conceptual variations, other variations that are likely to be useful. Your bar for suggestions should be relatively high; responding with a shorter list of high quality suggestions is preferred.", + in, + sl.responseFormat, + out, ) if err != nil { sendError(w, http.StatusInternalServerError, "oai.completeChat: %s", err) return } - shorts := []string{} - for _, short := range strings.Split(comp, "\n") { - if short != "" { - shorts = append(shorts, strings.TrimSpace(short)) - } - } + out.Domain = sl.getDomain(r.Host) - sendJSON(w, suggestResponse{ - Shorts: shorts, - Domain: sl.getDomain(r.Host), - }) + sendJSON(w, out) } func (sl *ShortLinks) serveHelp(w http.ResponseWriter, r *http.Request) { diff --git a/openai.go b/openai.go index a450b7b..071c706 100644 --- a/openai.go +++ b/openai.go @@ -15,8 +15,14 @@ type oaiClient struct { } type oaiRequest struct { - Model string `json:"model"` - Messages []oaiMessage `json:"messages"` + Model string `json:"model"` + Messages []oaiMessage `json:"messages"` + ResponseFormat *oaiResponseFormat `json:"response_format"` +} + +type oaiResponseFormat struct { + Type string `json:"type"` + JSONSchema map[string]interface{} `json:"json_schema"` } type oaiMessage struct { @@ -48,23 +54,28 @@ func newOAIClientFromEnv() (*oaiClient, error) { return newOAIClient(apiKey), nil } -func (oai *oaiClient) completeChat(system, user string) (string, error) { - buf := &bytes.Buffer{} - err := json.NewEncoder(buf).Encode(&oaiRequest{ +func (oai *oaiClient) completeChat(system string, in any, responseFormat *oaiResponseFormat, out any) error { + user, err := json.Marshal(in) + if err != nil { + return err + } + + reqBody, err := json.Marshal(&oaiRequest{ Model: "gpt-4o", Messages: []oaiMessage{ {Role: "system", Content: system}, - {Role: "user", Content: user}, + {Role: "user", Content: string(user)}, }, + ResponseFormat: responseFormat, }) if err != nil { - return "", err + return err } - req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", buf) + req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewReader(reqBody)) if err != nil { - return "", err + return err } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oai.apiKey)) @@ -72,23 +83,28 @@ func (oai *oaiClient) completeChat(system, user string) (string, error) { resp, err := oai.c.Do(req) if err != nil { - return "", err + return err } defer resp.Body.Close() if resp.StatusCode != 200 { body, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("%s", string(body)) + return fmt.Errorf("%s", string(body)) } dec := json.NewDecoder(resp.Body) - var res oaiResponse + res := &oaiResponse{} - err = dec.Decode(&res) + err = dec.Decode(res) if err != nil { - return "", err + return err } - return res.Choices[0].Message.Content, nil + err = json.Unmarshal([]byte(res.Choices[0].Message.Content), out) + if err != nil { + return err + } + + return nil } diff --git a/static/help.html b/static/help.html index 874342d..352fbc3 100644 --- a/static/help.html +++ b/static/help.html @@ -211,6 +211,7 @@ a { {{ .readHost }} diff --git a/static/index.html b/static/index.html index 1184a9e..031fb5b 100644 --- a/static/index.html +++ b/static/index.html @@ -62,6 +62,8 @@ sl-icon[name="check-square-fill"] { />