Files
web_server/backend/services/websocket.go
Adam French 29350af2e0
Some checks failed
Deploy with Docker Compose / deploy (push) Has been cancelled
Fix WebSocket 403 in dev mode by allowing localhost origins
The CheckOrigin function only accepted the production domain, rejecting
localhost connections in dev. Also removed redundant error response after
a failed upgrade since the upgrader already writes its own HTTP response.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-25 16:59:13 +00:00

111 lines
2.1 KiB
Go

package services
import (
"net/http"
"strings"
"sync"
"time"
"adam-french.co.uk/backend/models"
"gorm.io/gorm"
"github.com/gorilla/websocket"
)
// maxMessages caps how many chat messages are kept in the store (older rows
// are pruned after each insert) and replayed to a newly connected client.
const maxMessages = 50

// allowedDomain is the production host accepted by the WebSocket origin
// check (bare or "www."-prefixed). InitWebSocket sets it.
var allowedDomain string
// wsOriginAllowed reports whether a WebSocket handshake carrying the given
// Origin header value may connect. It accepts:
//   - the configured production domain, bare or "www."-prefixed, and
//   - any localhost origin (any port), so development servers work.
//
// The value must have exactly one "http" or "https" scheme. An empty Origin
// is rejected. Host comparisons are case-insensitive. When allowedDomain is
// unset, only localhost matches, so a missing configuration never lets
// malformed origins such as "https://" through.
func wsOriginAllowed(origin string) bool {
	scheme, hostport, ok := strings.Cut(origin, "://")
	if !ok || (scheme != "http" && scheme != "https") || hostport == "" {
		return false
	}
	if allowedDomain != "" &&
		(strings.EqualFold(hostport, allowedDomain) || strings.EqualFold(hostport, "www."+allowedDomain)) {
		return true
	}
	// Strip the port for localhost comparisons (e.g. "localhost:80").
	host, _, _ := strings.Cut(hostport, ":")
	return strings.EqualFold(host, "localhost")
}

// Upgrader upgrades HTTP requests to WebSocket connections for the chat
// service. Only origins accepted by wsOriginAllowed may connect.
var Upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin: func(r *http.Request) bool {
		return wsOriginAllowed(r.Header.Get("Origin"))
	},
}
// Package-level state shared by every WebSocket connection.
var (
	// mu guards clients and nextAuthorID, and serializes writes to the
	// connections held in clients.
	mu sync.Mutex
	// clients is the set of connections that receive broadcast messages.
	clients = make(map[*websocket.Conn]bool)
	// wsDB is the message store; InitWebSocket sets it.
	wsDB *gorm.DB
	// nextAuthorID is the most recently issued author ID. Each new
	// connection increments it to obtain its own ID.
	nextAuthorID uint
)

// Per-connection rate limit: at most rateLimitMaxMsgs messages are accepted
// in each rateLimitWindow. Any excess messages are dropped.
const (
	rateLimitWindow  = time.Second
	rateLimitMaxMsgs = 10
)
// InitWebSocket configures the WebSocket service:
//   - domain becomes the production host accepted by the Upgrader's origin
//     check (bare or "www."-prefixed);
//   - database becomes the message store used by HandleWebSocket.
//
// Call it before serving WebSocket connections.
func InitWebSocket(database *gorm.DB, domain string) {
	allowedDomain = domain
	wsDB = database
}
// HandleWebSocket runs one chat client over an already-upgraded connection.
// It returns once the client disconnects or a read fails.
//
// On entry the connection is registered for broadcasts and given a fresh
// author ID. It is then sent up to maxMessages of the most recent stored
// messages, oldest first.
//
// After that, each JSON message read from the client is stamped with that
// author ID, stored, and broadcast to every registered client. Messages
// beyond rateLimitMaxMsgs per rateLimitWindow are dropped without reply.
//
// On return the connection is always unregistered and closed.
func HandleWebSocket(conn *websocket.Conn) {
	// Upper bound on one incoming frame. A larger frame makes ReadJSON fail
	// and ends the session instead of buffering an unbounded payload.
	const maxIncomingBytes = 64 << 10

	defer conn.Close()
	// Deferred after Close, so it runs first (defers are LIFO). It runs on
	// every exit path, including a failed history replay. Deleting a client
	// that a failed broadcast already removed is a no-op.
	defer func() {
		mu.Lock()
		delete(clients, conn)
		mu.Unlock()
	}()
	conn.SetReadLimit(maxIncomingBytes)

	mu.Lock()
	clients[conn] = true
	nextAuthorID++
	authorID := nextAuthorID
	// Replay history while still holding mu. Broadcasts also run under mu,
	// so none can be written to conn concurrently with, or ahead of, the
	// replay.
	err := wsSendHistory(conn)
	mu.Unlock()
	if err != nil {
		return
	}

	msgCount := 0
	windowStart := time.Now()
	for {
		var incoming models.Message
		if err := conn.ReadJSON(&incoming); err != nil {
			return
		}

		// Fixed-window rate limit: reset the counter once the window elapses.
		if now := time.Now(); now.Sub(windowStart) > rateLimitWindow {
			msgCount = 0
			windowStart = now
		}
		msgCount++
		if msgCount > rateLimitMaxMsgs {
			continue
		}

		// The server, not the client, decides who wrote the message.
		// NOTE(review): other client-supplied fields (e.g. a primary key or
		// timestamps, if models.Message has them) are stored as sent.
		// Confirm against the model and reset them here if so.
		incoming.AuthorID = authorID
		wsStoreAndBroadcast(&incoming)
	}
}

// wsSendHistory writes up to maxMessages of the most recently created stored
// messages to conn, oldest first. The caller must hold mu.
//
// A failed history query is treated as empty history, so the client can
// still chat. Only a write failure is returned.
func wsSendHistory(conn *websocket.Conn) error {
	var history []models.Message
	// Query newest-first so the limit keeps the latest messages rather than
	// the oldest ones.
	if err := wsDB.Order("created_at DESC").Limit(maxMessages).Find(&history).Error; err != nil {
		history = nil
	}
	for i := len(history) - 1; i >= 0; i-- {
		if err := conn.WriteJSON(history[i]); err != nil {
			return err
		}
	}
	return nil
}

// wsStoreAndBroadcast persists msg and trims the store to the newest
// maxMessages rows. It then sends msg to every registered client, closing
// and unregistering any client whose write fails.
//
// It takes mu itself, so the store, the client set and all connection writes
// stay serialized. A message that fails to persist is not broadcast.
func wsStoreAndBroadcast(msg *models.Message) {
	mu.Lock()
	defer mu.Unlock()

	if err := wsDB.Create(msg).Error; err != nil {
		return
	}
	// Best-effort pruning. If it fails, the extra rows remain until the next
	// successful prune.
	wsDB.Where("id NOT IN (?)",
		wsDB.Model(&models.Message{}).Select("id").Order("created_at DESC").Limit(maxMessages),
	).Delete(&models.Message{})

	for client := range clients {
		// Send the value (not the pointer) so encoding matches the replayed
		// history messages.
		if err := client.WriteJSON(*msg); err != nil {
			client.Close()
			delete(clients, client)
		}
	}
}