Some checks failed
Deploy with Docker Compose / deploy (push) Has been cancelled
109 lines
1.9 KiB
Go
109 lines
1.9 KiB
Go
package services
|
|
|
|
import (
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"adam-french.co.uk/backend/models"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// maxMessages bounds both the history replayed to a newly connected client
// and the number of messages kept in the database after each insert.
const (
	maxMessages = 50
)

// allowedDomain is the host the WebSocket origin check accepts (together with
// its "www." form). It is populated by InitWebSocket.
var (
	allowedDomain string
)
|
|
|
|
var Upgrader = websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
return false
|
|
}
|
|
origin = strings.TrimPrefix(origin, "https://")
|
|
origin = strings.TrimPrefix(origin, "http://")
|
|
return origin == allowedDomain || origin == "www."+allowedDomain
|
|
},
|
|
}
|
|
|
|
var (
|
|
clients = make(map[*websocket.Conn]bool)
|
|
mu sync.Mutex
|
|
wsDB *gorm.DB
|
|
nextAuthorID uint
|
|
)
|
|
|
|
const (
|
|
rateLimitWindow = time.Second
|
|
rateLimitMaxMsgs = 10
|
|
)
|
|
|
|
func InitWebSocket(database *gorm.DB, domain string) {
|
|
wsDB = database
|
|
allowedDomain = domain
|
|
}
|
|
|
|
func HandleWebSocket(conn *websocket.Conn) {
|
|
defer conn.Close()
|
|
|
|
mu.Lock()
|
|
clients[conn] = true
|
|
nextAuthorID++
|
|
authorID := nextAuthorID
|
|
|
|
var history []models.Message
|
|
wsDB.Order("created_at ASC").Limit(maxMessages).Find(&history)
|
|
|
|
for _, msg := range history {
|
|
if err := conn.WriteJSON(msg); err != nil {
|
|
mu.Unlock()
|
|
return
|
|
}
|
|
}
|
|
mu.Unlock()
|
|
|
|
msgCount := 0
|
|
windowStart := time.Now()
|
|
|
|
for {
|
|
var incoming models.Message
|
|
if err := conn.ReadJSON(&incoming); err != nil {
|
|
break
|
|
}
|
|
|
|
now := time.Now()
|
|
if now.Sub(windowStart) > rateLimitWindow {
|
|
msgCount = 0
|
|
windowStart = now
|
|
}
|
|
msgCount++
|
|
if msgCount > rateLimitMaxMsgs {
|
|
continue
|
|
}
|
|
|
|
incoming.AuthorID = authorID
|
|
|
|
mu.Lock()
|
|
wsDB.Create(&incoming)
|
|
wsDB.Where("id NOT IN (?)",
|
|
wsDB.Model(&models.Message{}).Select("id").Order("created_at DESC").Limit(maxMessages),
|
|
).Delete(&models.Message{})
|
|
|
|
for client := range clients {
|
|
if err := client.WriteJSON(incoming); err != nil {
|
|
client.Close()
|
|
delete(clients, client)
|
|
}
|
|
}
|
|
mu.Unlock()
|
|
}
|
|
|
|
mu.Lock()
|
|
delete(clients, conn)
|
|
mu.Unlock()
|
|
}
|