From 9648518a6943a530b23e79a3f4b8689b804544ab Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Mon, 2 Dec 2024 22:54:22 -0800 Subject: [PATCH] Wildcard domain support --- main.go | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/main.go b/main.go index ad64919..e235882 100644 --- a/main.go +++ b/main.go @@ -72,7 +72,19 @@ func (sl *ShortLinks) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (sl *ShortLinks) serveRoot(w http.ResponseWriter, r *http.Request) { log.Printf("%s %s %s %s %s", r.RemoteAddr, r.Method, r.Host, sl.getDomain(r.Host), r.URL) - sl.serveRootWithPath(w, r, "") + parts := strings.SplitN(r.Host, ".", 2) + if len(parts) != 2 { + sl.serveRootWithPath(w, r, "") + return + } + + long, err := sl.getLong(parts[0], sl.getDomain(parts[1])) + if err != nil { + sl.serveRootWithPath(w, r, "") + return + } + + http.Redirect(w, r, long, http.StatusTemporaryRedirect) } func (sl *ShortLinks) serveRootWithPath(w http.ResponseWriter, r *http.Request, path string) { @@ -81,6 +93,7 @@ func (sl *ShortLinks) serveRootWithPath(w http.ResponseWriter, r *http.Request, sendError(w, http.StatusBadRequest, "Parse form: %s", err) return } + log.Printf("%s %s %s %s %s", r.RemoteAddr, r.Method, r.Host, sl.getDomain(r.Host), r.URL) if !sl.isWritable(r.Host) { @@ -104,9 +117,7 @@ func (sl *ShortLinks) serveShort(w http.ResponseWriter, r *http.Request) { short := r.PathValue("short") - row := sl.db.QueryRow(`SELECT long FROM links WHERE short = $1 AND domain = $2`, short, sl.getDomain(r.Host)) - var long string - err := row.Scan(&long) + long, err := sl.getLong(short, sl.getDomain(r.Host)) if err != nil { sl.serveRootWithPath(w, r, short) return @@ -237,6 +248,16 @@ func (sl *ShortLinks) isWritable(host string) bool { return sl.writableDomains[host] } +func (sl *ShortLinks) getLong(short, domain string) (string, error) { + var long string + err := sl.db.QueryRow("SELECT long FROM links WHERE short = $1 AND domain = $2", short, domain).Scan(&long) + if err != nil { + return "", err + } + + return long, nil +} + func main() { port := os.Getenv("PORT") if port == "" {