diff --git a/main.go b/main.go
index 312814e..4d7cb79 100644
--- a/main.go
+++ b/main.go
@@ -26,6 +26,7 @@ type ShortLinks struct {
domainAliases map[string]string
writableDomains map[string]bool
+ responseFormat *oaiResponseFormat
}
type setResponse struct {
@@ -34,6 +35,11 @@ type setResponse struct {
URL string `json:"url"`
}
+type suggestRequest struct {
+ Shorts []string `json:"shorts,omitempty"`
+ Title string `json:"title,omitempty"`
+}
+
type suggestResponse struct {
Shorts []string `json:"shorts"`
Domain string `json:"domain"`
@@ -96,6 +102,26 @@ func NewShortLinks(db *sql.DB, domainAliases map[string]string, writableDomains
domainAliases: domainAliases,
writableDomains: writableDomains,
+ responseFormat: &oaiResponseFormat{
+ Type: "json_schema",
+ JSONSchema: map[string]any{
+ "name": "suggest_response",
+ "strict": true,
+ "schema": map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "shorts": map[string]any{
+ "type": "array",
+ "items": map[string]any{
+ "type": "string",
+ },
+ },
+ },
+ "required": []string{"shorts"},
+ "additionalProperties": false,
+ },
+ },
+ },
}
sl.mux.HandleFunc("GET /{$}", sl.serveRoot)
@@ -157,6 +183,7 @@ func (sl *ShortLinks) serveRootWithShort(w http.ResponseWriter, r *http.Request,
"short": short,
"host": sl.getDomain(r.Host),
"long": r.Form.Get("long"),
+ "title": r.Form.Get("title"),
})
if err != nil {
sendError(w, http.StatusInternalServerError, "error executing template: %s", err)
@@ -238,33 +265,32 @@ func (sl *ShortLinks) serveSuggest(w http.ResponseWriter, r *http.Request) {
return
}
- if !r.Form.Has("shorts") {
- sendError(w, http.StatusBadRequest, "shorts= param required")
+ if !r.Form.Has("shorts") && !r.Form.Has("title") {
+ sendError(w, http.StatusBadRequest, "shorts= or title= param required")
return
}
- user := strings.Join(r.Form["shorts"], "\n")
+ in := suggestRequest{
+ Shorts: r.Form["shorts"],
+ Title: r.Form.Get("title"),
+ }
- comp, err := sl.oai.completeChat(
- "You are an assistant helping a user choose useful short names for a URL shortener. The request contains a list recents names chosen by the user, separated by newlines, with the most recent names first. Respond with only a list of possible suggestions for additional short names, separated by newlines. In descending order of preference, suggestions should include: plural/singular variations, 2 and 3 letter abbreivations, conceptual variations, other variations that are likely to be useful. Your bar for suggestions should be relatively high; responding with a shorter list of high quality suggestions is preferred.",
- user,
+ out := &suggestResponse{}
+
+ err = sl.oai.completeChat(
+ "You are an assistant helping a user choose useful short names for a URL shortener. The request contains JSON object where the optional `shorts` key contains a list of recent names chosen by the user, with the most recent names first, and the optional `title` key contains a title for the URL. Respond with only a JSON object where the `shorts` key contains a list of possible suggestions for additional short names. In descending order of preference, suggestions should include: plural/singular variations, 2 and 3 letter abbreivations, conceptual variations, other variations that are likely to be useful. Your bar for suggestions should be relatively high; responding with a shorter list of high quality suggestions is preferred.",
+ in,
+ sl.responseFormat,
+ out,
)
if err != nil {
sendError(w, http.StatusInternalServerError, "oai.completeChat: %s", err)
return
}
- shorts := []string{}
- for _, short := range strings.Split(comp, "\n") {
- if short != "" {
- shorts = append(shorts, strings.TrimSpace(short))
- }
- }
+ out.Domain = sl.getDomain(r.Host)
- sendJSON(w, suggestResponse{
- Shorts: shorts,
- Domain: sl.getDomain(r.Host),
- })
+ sendJSON(w, out)
}
func (sl *ShortLinks) serveHelp(w http.ResponseWriter, r *http.Request) {
diff --git a/openai.go b/openai.go
index a450b7b..071c706 100644
--- a/openai.go
+++ b/openai.go
@@ -15,8 +15,14 @@ type oaiClient struct {
}
type oaiRequest struct {
- Model string `json:"model"`
- Messages []oaiMessage `json:"messages"`
+ Model string `json:"model"`
+ Messages []oaiMessage `json:"messages"`
+ ResponseFormat *oaiResponseFormat `json:"response_format"`
+}
+
+type oaiResponseFormat struct {
+ Type string `json:"type"`
+ JSONSchema map[string]interface{} `json:"json_schema"`
}
type oaiMessage struct {
@@ -48,23 +54,28 @@ func newOAIClientFromEnv() (*oaiClient, error) {
return newOAIClient(apiKey), nil
}
-func (oai *oaiClient) completeChat(system, user string) (string, error) {
- buf := &bytes.Buffer{}
- err := json.NewEncoder(buf).Encode(&oaiRequest{
+func (oai *oaiClient) completeChat(system string, in any, responseFormat *oaiResponseFormat, out any) error {
+ user, err := json.Marshal(in)
+ if err != nil {
+ return err
+ }
+
+ reqBody, err := json.Marshal(&oaiRequest{
Model: "gpt-4o",
Messages: []oaiMessage{
{Role: "system", Content: system},
- {Role: "user", Content: user},
+ {Role: "user", Content: string(user)},
},
+ ResponseFormat: responseFormat,
})
if err != nil {
- return "", err
+ return err
}
- req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", buf)
+ req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewReader(reqBody))
if err != nil {
- return "", err
+ return err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oai.apiKey))
@@ -72,23 +83,28 @@ func (oai *oaiClient) completeChat(system, user string) (string, error) {
resp, err := oai.c.Do(req)
if err != nil {
- return "", err
+ return err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
- return "", fmt.Errorf("%s", string(body))
+ return fmt.Errorf("%s", string(body))
}
dec := json.NewDecoder(resp.Body)
- var res oaiResponse
+ res := &oaiResponse{}
- err = dec.Decode(&res)
+ err = dec.Decode(res)
if err != nil {
- return "", err
+ return err
}
- return res.Choices[0].Message.Content, nil
+ err = json.Unmarshal([]byte(res.Choices[0].Message.Content), out)
+ if err != nil {
+ return err
+ }
+
+ return nil
}
diff --git a/static/help.html b/static/help.html
index 874342d..352fbc3 100644
--- a/static/help.html
+++ b/static/help.html
@@ -211,6 +211,7 @@ a {
{{ .readHost }}
diff --git a/static/index.html b/static/index.html
index 1184a9e..031fb5b 100644
--- a/static/index.html
+++ b/static/index.html
@@ -62,6 +62,8 @@ sl-icon[name="check-square-fill"] {
/>