diff --git a/main.go b/main.go index e4c50d0..0e0c1ec 100644 --- a/main.go +++ b/main.go @@ -100,15 +100,6 @@ type room struct { present map[*presentState]bool } -type watchState struct { - responseWriter http.ResponseWriter - flusher http.Flusher - room *room - client *client - admin bool - eventChan chan *event -} - type presentState struct { responseWriter http.ResponseWriter flusher http.Flusher @@ -422,31 +413,45 @@ func reset(w http.ResponseWriter, r *http.Request) { } func watch(w http.ResponseWriter, r *http.Request) { - ws := newWatchState(w, r) - if ws == nil { + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming unsupported", http.StatusBadRequest) return } + client, eventChan := registerWatch(w, r) + if client == nil { + return + } + + // TODO: refcount client so it stays alive from just a watch + // add in registerwatch, sub in defer here + closeChan := w.(http.CloseNotifier).CloseNotify() ticker := time.NewTicker(15 * time.Second) - ws.sendInitial() + writeInitial(client, w, flusher) for { select { case <-closeChan: - ws.close() - return + close(eventChan) + + mu.Lock() + if client.eventChan == eventChan { + client.eventChan = nil + } + mu.Unlock() case <-ticker.C: - mu.Lock() - ws.sendHeartbeat() - mu.Unlock() + writeHeartbeat(w, flusher) - case event := <-ws.eventChan: - mu.Lock() - ws.sendEvent(event) - mu.Unlock() + case event, ok := <-eventChan: + if ok { + writeEvent(event, w, flusher) + } else { + return + } } } } @@ -545,89 +550,71 @@ func (rm *room) updateAllClients() { } } -func newWatchState(w http.ResponseWriter, r *http.Request) *watchState { +func registerWatch(w http.ResponseWriter, r *http.Request) (*client, chan *event) { mu.Lock() defer mu.Unlock() - ws := &watchState{ - responseWriter: w, - eventChan: make(chan *event, 100), - } - - var ok bool - ws.flusher, ok = w.(http.Flusher) - if !ok { - http.Error(ws.responseWriter, "streaming unsupported", http.StatusBadRequest) - return nil - } - roomId := r.URL.Query().Get("room_id") - ws.room = getRoom(roomId) + room := getRoom(roomId) clientId := r.URL.Query().Get("client_id") - ws.client = ws.room.getClient(clientId) + client := room.getClient(clientId) adminSecret := r.URL.Query().Get("admin_secret") if adminSecret != "" { - if adminSecret == ws.room.adminSecret() { - ws.admin = true + if adminSecret == room.adminSecret() { + client.Admin = true } else { http.Error(w, "invalid admin_secret", http.StatusBadRequest) - return nil + return nil, nil } } - ws.client.eventChan = ws.eventChan - ws.client.update() + if client.eventChan != nil { + close(client.eventChan) + } + + client.eventChan = make(chan *event, 100) + + client.update() w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - return ws + // Return eventChan because we're reading it with the lock held + return client, client.eventChan } -func (ws *watchState) sendInitial() { +func writeInitial(client *client, w http.ResponseWriter, flusher http.Flusher) { mu.Lock() defer mu.Unlock() - if !ws.admin { + if !client.Admin { return } - for _, client := range ws.room.clientById { - ws.sendEvent(&event{ + for _, iter := range client.room.clientById { + writeEvent(&event{ AdminEvent: &adminEvent{ - Client: client, + Client: iter, }, - }) + }, w, flusher) } - - ws.flusher.Flush() } -func (ws *watchState) sendHeartbeat() { - fmt.Fprintf(ws.responseWriter, ":\n\n") - ws.flusher.Flush() +func writeHeartbeat(w http.ResponseWriter, flusher http.Flusher) { + fmt.Fprintf(w, ":\n\n") + flusher.Flush() } -func (ws *watchState) sendEvent(e *event) { +func writeEvent(e *event, w http.ResponseWriter, flusher http.Flusher) { j, err := json.Marshal(e) if err != nil { log.Fatal(err) } - fmt.Fprintf(ws.responseWriter, "data: %s\n\n", j) - ws.flusher.Flush() -} - -func (ws *watchState) close() { - mu.Lock() - defer mu.Unlock() - - if ws.client.eventChan == ws.eventChan { - ws.client.eventChan = nil - close(ws.eventChan) - } + fmt.Fprintf(w, "data: %s\n\n", j) + flusher.Flush() } func newPresentState(w http.ResponseWriter, r *http.Request) *presentState {