This commit is contained in:
Ian Gulliver
2024-12-05 00:07:04 -08:00
parent 592198cbdd
commit 0b800b767b
3 changed files with 193 additions and 14 deletions

76
main.go
View File

@@ -18,6 +18,7 @@ import (
type ShortLinks struct {
tmpl *template.Template
help *template.Template
list *template.Template
mux *http.ServeMux
db *sql.DB
r *rand.Rand
@@ -38,10 +39,16 @@ type suggestResponse struct {
Domain string `json:"domain"`
}
func NewShortLinks(db *sql.DB, domainAliases map[string]string, writableDomains map[string]bool) (*ShortLinks, error) {
tmpl := template.New("index.html")
type link struct {
Short string `json:"short"`
Long string `json:"long"`
Domain string `json:"domain"`
URL string `json:"url"`
Generated bool `json:"generated"`
}
tmpl, err := tmpl.ParseFiles("static/index.html")
func NewShortLinks(db *sql.DB, domainAliases map[string]string, writableDomains map[string]bool) (*ShortLinks, error) {
tmpl, err := template.New("index.html").ParseFiles("static/index.html")
if err != nil {
return nil, fmt.Errorf("static/index.html: %w", err)
}
@@ -51,6 +58,11 @@ func NewShortLinks(db *sql.DB, domainAliases map[string]string, writableDomains
return nil, fmt.Errorf("static/help.html: %w", err)
}
list, err := template.New("list.html").ParseFiles("static/list.html")
if err != nil {
return nil, fmt.Errorf("static/list.html: %w", err)
}
oai, err := newOAIClientFromEnv()
if err != nil {
return nil, fmt.Errorf("newOAIClientFromEnv: %w", err)
@@ -59,6 +71,7 @@ func NewShortLinks(db *sql.DB, domainAliases map[string]string, writableDomains
sl := &ShortLinks{
tmpl: tmpl,
help: help,
list: list,
mux: http.NewServeMux(),
db: db,
r: rand.New(rand.NewSource(uint64(time.Now().UnixNano()))),
@@ -69,8 +82,9 @@ func NewShortLinks(db *sql.DB, domainAliases map[string]string, writableDomains
}
sl.mux.HandleFunc("GET /{$}", sl.serveRoot)
sl.mux.HandleFunc("GET /_help", sl.serveHelp)
sl.mux.HandleFunc("GET /_favicon.png", sl.serveFavicon)
sl.mux.HandleFunc("GET /_help", sl.serveHelp)
sl.mux.HandleFunc("GET /_list", sl.serveList)
sl.mux.HandleFunc("GET /{short}", sl.serveShort)
sl.mux.HandleFunc("POST /{$}", sl.serveSet)
sl.mux.HandleFunc("QUERY /{$}", sl.serveSuggest)
@@ -91,26 +105,26 @@ func (sl *ShortLinks) serveRoot(w http.ResponseWriter, r *http.Request) {
}
if sl.isWritable(r.Host) {
sl.serveRootWithPath(w, r, "")
sl.serveRootWithShort(w, r, r.Form.Get("short"))
return
}
parts := strings.SplitN(r.Host, ".", 2)
if len(parts) != 2 {
sl.serveRootWithPath(w, r, "")
sl.serveRootWithShort(w, r, r.Form.Get("short"))
return
}
long, err := sl.getLong(parts[0], sl.getDomain(parts[1]))
if err != nil {
sl.serveRootWithPath(w, r, "")
sl.serveRootWithShort(w, r, r.Form.Get("short"))
return
}
http.Redirect(w, r, long, http.StatusTemporaryRedirect)
}
func (sl *ShortLinks) serveRootWithPath(w http.ResponseWriter, r *http.Request, path string) {
func (sl *ShortLinks) serveRootWithShort(w http.ResponseWriter, r *http.Request, short string) {
err := sl.initRequest(w, r)
if err != nil {
sendError(w, http.StatusBadRequest, "init request: %s", err)
@@ -123,9 +137,9 @@ func (sl *ShortLinks) serveRootWithPath(w http.ResponseWriter, r *http.Request,
}
err = sl.tmpl.Execute(w, map[string]any{
"path": path,
"host": sl.getDomain(r.Host),
"long": r.Form.Get("long"),
"short": short,
"host": sl.getDomain(r.Host),
"long": r.Form.Get("long"),
})
if err != nil {
sendError(w, http.StatusInternalServerError, "error executing template: %s", err)
@@ -144,7 +158,7 @@ func (sl *ShortLinks) serveShort(w http.ResponseWriter, r *http.Request) {
long, err := sl.getLong(short, sl.getDomain(r.Host))
if err != nil {
sl.serveRootWithPath(w, r, short)
sl.serveRootWithShort(w, r, short)
return
}
@@ -303,6 +317,44 @@ func (sl *ShortLinks) serveFavicon(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, "static/favicon.png")
}
func (sl *ShortLinks) serveList(w http.ResponseWriter, r *http.Request) {
err := sl.initRequest(w, r)
if err != nil {
sendError(w, http.StatusBadRequest, "init request: %s", err)
return
}
rows, err := sl.db.Query("SELECT short, long, domain, generated FROM links WHERE domain = $1 ORDER BY short ASC", sl.getDomain(r.Host))
if err != nil {
sendError(w, http.StatusInternalServerError, "select links: %s", err)
return
}
defer rows.Close()
links := []link{}
for rows.Next() {
link := link{}
err := rows.Scan(&link.Short, &link.Long, &link.Domain, &link.Generated)
if err != nil {
sendError(w, http.StatusInternalServerError, "scan link: %s", err)
return
}
link.URL = fmt.Sprintf("https://%s/%s", link.Domain, link.Short)
links = append(links, link)
}
err = sl.list.Execute(w, map[string]any{
"links": links,
})
if err != nil {
sendError(w, http.StatusInternalServerError, "error executing template: %s", err)
return
}
}
func (sl *ShortLinks) getDomain(host string) string {
if alias, ok := sl.domainAliases[host]; ok {
return alias