diff --git a/main.go b/main.go index 72f3fc3..2968be6 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,9 @@ type ShortLinks struct { db *sql.DB r *rand.Rand oai *oaiClient + + domainAliases map[string]string + writableDomains map[string]bool } type setResponse struct { @@ -30,7 +33,7 @@ type suggestResponse struct { Shorts []string `json:"shorts"` } -func NewShortLinks(db *sql.DB) (*ShortLinks, error) { +func NewShortLinks(db *sql.DB, domainAliases map[string]string, writableDomains map[string]bool) (*ShortLinks, error) { tmpl := template.New("index.html") tmpl, err := tmpl.ParseFiles("static/index.html") @@ -49,6 +52,9 @@ func NewShortLinks(db *sql.DB) (*ShortLinks, error) { db: db, r: rand.New(rand.NewSource(uint64(time.Now().UnixNano()))), oai: oai, + + domainAliases: domainAliases, + writableDomains: writableDomains, } sl.mux.HandleFunc("GET /{$}", sl.serveRoot) @@ -64,13 +70,22 @@ func (sl *ShortLinks) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (sl *ShortLinks) serveRoot(w http.ResponseWriter, r *http.Request) { - log.Printf("%s %s %s %s", r.RemoteAddr, r.Method, r.Host, r.URL) + log.Printf("%s %s %s %s %s", r.RemoteAddr, r.Method, r.Host, sl.getDomain(r.Host), r.URL) + sl.serveRootWithPath(w, r, "") } func (sl *ShortLinks) serveRootWithPath(w http.ResponseWriter, r *http.Request, path string) { + log.Printf("%s %s %s %s %s", r.RemoteAddr, r.Method, r.Host, sl.getDomain(r.Host), r.URL) + + if !sl.isWritable(r.Host) { + sendError(w, http.StatusForbidden, "not writable") + return + } + err := sl.tmpl.Execute(w, map[string]any{ "path": path, + "host": sl.getDomain(r.Host), }) if err != nil { sendError(w, http.StatusInternalServerError, "error executing template: %s", err) @@ -79,11 +94,11 @@ func (sl *ShortLinks) serveRootWithPath(w http.ResponseWriter, r *http.Request, } func (sl *ShortLinks) serveShort(w http.ResponseWriter, r *http.Request) { - log.Printf("%s %s %s %s", r.RemoteAddr, r.Method, r.Host, r.URL) + log.Printf("%s %s %s %s %s", r.RemoteAddr, r.Method, r.Host, sl.getDomain(r.Host), r.URL) short := r.PathValue("short") - row := sl.db.QueryRow(`SELECT long FROM links WHERE short = $1 AND domain = $2`, short, r.Host) + row := sl.db.QueryRow(`SELECT long FROM links WHERE short = $1 AND domain = $2`, short, sl.getDomain(r.Host)) var long string err := row.Scan(&long) if err != nil { @@ -101,13 +116,18 @@ func (sl *ShortLinks) serveSet(w http.ResponseWriter, r *http.Request) { return } - log.Printf("%s %s %s %s %s", r.RemoteAddr, r.Method, r.Host, r.URL, r.Form.Encode()) + log.Printf("%s %s %s %s %s", r.RemoteAddr, r.Method, r.Host, sl.getDomain(r.Host), r.URL) + + if !sl.isWritable(r.Host) { + sendError(w, http.StatusForbidden, "not writable") + return + } short := r.Form.Get("short") generated := false if short == "" { - short, err = sl.genShort(r.Host) + short, err = sl.genShort(sl.getDomain(r.Host)) if err != nil { sendError(w, http.StatusInternalServerError, "genShort: %s", err) return @@ -122,7 +142,7 @@ func (sl *ShortLinks) serveSet(w http.ResponseWriter, r *http.Request) { return } - _, err = sl.db.Exec(`SELECT update_link($1, $2, $3, $4);`, short, long, r.Host, generated) + _, err = sl.db.Exec(`SELECT update_link($1, $2, $3, $4);`, short, long, sl.getDomain(r.Host), generated) if err != nil { sendError(w, http.StatusInternalServerError, "update_link: %s", err) return @@ -140,7 +160,12 @@ func (sl *ShortLinks) serveSuggest(w http.ResponseWriter, r *http.Request) { return } - log.Printf("%s %s %s %s %s", r.RemoteAddr, r.Method, r.Host, r.URL, r.Form.Encode()) + log.Printf("%s %s %s %s %s", r.RemoteAddr, r.Method, r.Host, sl.getDomain(r.Host), r.URL) + + if !sl.isWritable(r.Host) { + sendError(w, http.StatusForbidden, "not writable") + return + } if !r.Form.Has("short") { sendError(w, http.StatusBadRequest, "short= param required") @@ -194,6 +219,18 @@ func (sl *ShortLinks) genShort(domain string) (string, error) { return "", fmt.Errorf("no available short link found") } +func (sl *ShortLinks) getDomain(host string) string { + if alias, ok := sl.domainAliases[host]; ok { + return alias + } + + return host +} + +func (sl *ShortLinks) isWritable(host string) bool { + return sl.writableDomains[host] +} + func main() { port := os.Getenv("PORT") if port == "" { @@ -266,7 +303,17 @@ func main() { } } - sl, err := NewShortLinks(db) + domainAliases, err := loadDomainAliases() + if err != nil { + log.Fatalf("Failed to load domain aliases: %v", err) + } + + writableDomains, err := loadWritableDomains() + if err != nil { + log.Fatalf("Failed to load writable domains: %v", err) + } + + sl, err := NewShortLinks(db, domainAliases, writableDomains) if err != nil { log.Fatalf("Failed to create shortlinks: %v", err) } @@ -280,3 +327,38 @@ func main() { log.Fatalf("listen: %s", err) } } + +func loadDomainAliases() (map[string]string, error) { + ret := map[string]string{} + + s := os.Getenv("DOMAIN_ALIASES") + if s == "" { + return ret, nil + } + + for _, pair := range strings.Split(s, ",") { + parts := strings.SplitN(pair, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid domain alias: %s", pair) + } + + ret[parts[0]] = parts[1] + } + + return ret, nil +} + +func loadWritableDomains() (map[string]bool, error) { + ret := map[string]bool{} + + s := os.Getenv("WRITABLE_DOMAINS") + if s == "" { + return ret, nil + } + + for _, domain := range strings.Split(s, ",") { + ret[domain] = true + } + + return ret, nil +} diff --git a/static/index.html b/static/index.html index 7de585f..f815165 100644 --- a/static/index.html +++ b/static/index.html @@ -222,8 +222,6 @@ document.addEventListener('DOMContentLoaded', async () => { customElements.whenDefined('sl-tree'), ]); - document.getElementById('short').setAttribute('label', `${window.location.host}/`); - let shortPaste = false; document.getElementById('short').addEventListener('sl-input', async () => { @@ -295,7 +293,7 @@ document.addEventListener('DOMContentLoaded', async () => {