new error codes, code refactor, added getMessages function

This commit is contained in:
slawk0
2025-02-07 22:59:36 +01:00
parent 51b426ba54
commit 7342c5b483
9 changed files with 189 additions and 31 deletions

View File

@@ -313,7 +313,7 @@ func GetContacts(db *sql.DB, userID uuid.UUID) ([]*model.Contact, error) {
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, helpers.NewError(helpers.ErrInternal, "Failed to process contacts", err) return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process contacts: %w", err))
} }
return contacts, nil return contacts, nil
@@ -342,7 +342,7 @@ func ContactSuggestion(db *sql.DB, contactUsername string) ([]string, error) {
suggestions = append(suggestions, suggestion) suggestions = append(suggestions, suggestion)
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, helpers.NewError(helpers.ErrInternal, "Error processing suggestions", err) return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process suggestions"))
} }
return suggestions, nil return suggestions, nil
} }

109
database/messages.go Normal file
View File

@@ -0,0 +1,109 @@
package database
import (
"database/sql"
"fmt"
"github.com/google/uuid"
"relay-server/helpers"
"relay-server/model"
)
func GetMessages(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, limit int, cursor int) ([]*model.Message, error) {
_, err := checkMembership(db, userID, conversationID)
if err != nil {
return nil, err
}
var query string
var messages []*model.Message
if cursor != 0 {
query = `
SELECT
m.message_id,
m.content AS message,
m.sent_at,
m.attachment_urls,
a.username AS sender
FROM Messages m
JOIN Accounts a ON m.user_id = a.user_id
WHERE m.conversation_id = $1
AND m.message_id < $2
ORDER BY m.message_id DESC
LIMIT $3;
`
rows, err := db.Query(query, conversationID, cursor, limit)
if err != nil {
return nil, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err))
}
for rows.Next() {
message := &model.Message{}
err = rows.Scan(&message.MessageID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender)
if err != nil {
return nil, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err))
}
messages = append(messages, message)
}
if err = rows.Err(); err != nil {
return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err))
}
} else {
query = `
SELECT
m.message_id,
m.content AS message,
m.sent_at,
m.attachment_urls,
a.username AS sender
FROM Messages m
JOIN Accounts a ON m.user_id = a.user_id
WHERE m.conversation_id = $1
ORDER BY m.message_id DESC
LIMIT $2;
`
rows, err := db.Query(query, conversationID, cursor, limit)
if err != nil {
return nil, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err))
}
for rows.Next() {
message := &model.Message{}
err = rows.Scan(&message.MessageID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender)
if err != nil {
return nil, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err))
}
messages = append(messages, message)
}
if err = rows.Err(); err != nil {
return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err))
}
}
if cursor != 0 {
// Reverse first messages
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
messages[i], messages[j] = messages[j], messages[i]
}
}
return messages, nil
}
func checkMembership(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (bool, error) {
query := `
SELECT EXISTS (
SELECT 1
FROM Memberships
WHERE user_id = $1
AND conversation_id = $2
) AS is_member;
`
var isMember bool
err := db.QueryRow(query, userID, conversationID).Scan(&isMember)
if err != nil {
return false, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to check membership: %w", err))
}
if !isMember {
return false, helpers.NewError(helpers.ErrForbidden, "You are member of the conversation", nil)
}
return isMember, nil
}

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"os" "os"
"relay-server/config" "relay-server/config"
@@ -158,12 +159,13 @@ func Login(c *fiber.Ctx) error {
func ValidateToken(c *fiber.Ctx) error { func ValidateToken(c *fiber.Ctx) error {
username, ok := c.Locals("username").(string) username, ok := c.Locals("username").(string)
if !ok { if !ok {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid token: missing username", nil) return helpers.NewError(helpers.ErrInvalidInput, "Invalid token: missing username", fmt.Errorf("missing username: %v", c.Locals("username")))
} }
userID, ok := c.Locals("userID").(string) userIDVal := c.Locals("userID")
userID, ok := userIDVal.(uuid.UUID)
if !ok { if !ok {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid token: missing user ID", nil) return helpers.NewError(helpers.ErrUnauthorized, "unauthorized", fmt.Errorf("missing/invalid userID type: %T, value: %v\n", userIDVal, userIDVal))
} }
return c.Status(fiber.StatusOK).JSON(fiber.Map{ return c.Status(fiber.StatusOK).JSON(fiber.Map{

View File

@@ -35,19 +35,10 @@ func DeleteContact(c *fiber.Ctx) error {
func InsertContact(c *fiber.Ctx) error { func InsertContact(c *fiber.Ctx) error {
type params struct { type params struct {
ContactUsername string `params:"contact_username"` ContactUsername string `params:"contactUsername"`
} }
userIDVal := c.Locals("userID") userID := c.Locals("userID").(uuid.UUID)
userIDStr, ok := userIDVal.(string)
if !ok {
return helpers.NewError(helpers.ErrInternal, "Internal server error", nil)
}
userID, err := uuid.Parse(userIDStr)
if err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid user ID format", err)
}
p := new(params) p := new(params)
if err := c.ParamsParser(p); err != nil { if err := c.ParamsParser(p); err != nil {
@@ -76,16 +67,7 @@ func InsertContact(c *fiber.Ctx) error {
} }
func GetContacts(c *fiber.Ctx) error { func GetContacts(c *fiber.Ctx) error {
userIDVal := c.Locals("userID") userID := c.Locals("userID").(uuid.UUID)
userIDStr, ok := userIDVal.(string)
if !ok {
return helpers.NewError(helpers.ErrInternal, "Invalid user session", nil)
}
userID, err := uuid.Parse(userIDStr)
if err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid user ID format", err)
}
contacts, err := database.GetContacts(database.DB, userID) contacts, err := database.GetContacts(database.DB, userID)
if err != nil { if err != nil {

41
handlers/messages.go Normal file
View File

@@ -0,0 +1,41 @@
package handlers
import (
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"relay-server/database"
"relay-server/helpers"
)
func GetMessages(c *fiber.Ctx) error {
type params struct {
conversationID uuid.UUID `params:"conversationID"`
}
type query struct {
limit int `query:"limit"`
cursor int `query:"cursor"`
}
userID := c.Locals("userID").(uuid.UUID)
p := new(params)
q := new(query)
if err := c.ParamsParser(p); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err)
}
if err := c.QueryParser(q); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid query", err)
}
if p.conversationID == uuid.Nil {
return helpers.NewError(helpers.ErrInvalidInput, "conversation ID is empty", nil)
}
messages, err := database.GetMessages(database.DB, userID, p.conversationID, q.limit, q.cursor)
if err != nil {
return err
}
if len(messages) == 0 {
return c.JSON(fiber.Map{"message": "No messages found"})
}
return c.JSON(fiber.Map{"messages": messages})
}

View File

@@ -11,6 +11,8 @@ const (
ErrInternal ErrorCode = "INTERNAL_ERROR" ErrInternal ErrorCode = "INTERNAL_ERROR"
ErrNotFound ErrorCode = "NOT_FOUND" ErrNotFound ErrorCode = "NOT_FOUND"
ErrInvalidInput ErrorCode = "INVALID_INPUT" ErrInvalidInput ErrorCode = "INVALID_INPUT"
ErrForbidden ErrorCode = "FORBIDDEN"
ErrUnauthorized ErrorCode = "UNAUTHORIZED"
) )
type Error struct { type Error struct {
@@ -40,6 +42,10 @@ func NewError(code ErrorCode, userMsg string, internalErr error) *Error {
statusCode = fiber.StatusBadRequest statusCode = fiber.StatusBadRequest
case ErrInternal: case ErrInternal:
statusCode = fiber.StatusInternalServerError statusCode = fiber.StatusInternalServerError
case ErrForbidden:
statusCode = fiber.StatusForbidden
case ErrUnauthorized:
statusCode = fiber.StatusUnauthorized
} }
return &Error{ return &Error{

View File

@@ -1,10 +1,13 @@
package middleware package middleware
import ( import (
"fmt"
jwtware "github.com/gofiber/contrib/jwt" jwtware "github.com/gofiber/contrib/jwt"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"os" "os"
"relay-server/helpers"
"relay-server/model" "relay-server/model"
) )
@@ -17,8 +20,11 @@ func Protected() fiber.Handler {
SuccessHandler: func(c *fiber.Ctx) error { SuccessHandler: func(c *fiber.Ctx) error {
user := c.Locals("user").(*jwt.Token) user := c.Locals("user").(*jwt.Token)
claims := user.Claims.(*model.UserClaims) claims := user.Claims.(*model.UserClaims)
userID, err := uuid.Parse(claims.UserID)
c.Locals("userID", claims.UserID) if err != nil {
return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to parse user ID: %w", err))
}
c.Locals("userID", userID)
c.Locals("username", claims.Username) c.Locals("username", claims.Username)
return c.Next() return c.Next()
}, },

View File

@@ -25,3 +25,11 @@ type Contact struct {
type ContactSuggestion struct { type ContactSuggestion struct {
Username string `json:"username"` Username string `json:"username"`
} }
type Message struct {
MessageID int `json:"message_id"`
Message string `json:"message"`
SentAt string `json:"sent_at"`
Sender string `json:"sender"`
AttachmentUrl string `json:"attachment_url"`
}

View File

@@ -21,9 +21,13 @@ func SetupRoutes(app *fiber.App) {
auth.Get("/validate", middleware.Protected(), handlers.ValidateToken) auth.Get("/validate", middleware.Protected(), handlers.ValidateToken)
// Contacts group // Contacts group
contacts := chat.Group("/contacts", middleware.Protected(), logger.New()) contacts := chat.Group("/contact", middleware.Protected(), logger.New())
contacts.Delete("/:contact_id/:conversation_id", handlers.DeleteContact) contacts.Delete("/:contactID/:conversation_id", handlers.DeleteContact)
contacts.Post("/:contact_username", handlers.InsertContact) contacts.Post("/:contactUsername", handlers.InsertContact)
contacts.Get("/", handlers.GetContacts) contacts.Get("/", handlers.GetContacts)
contacts.Get("/suggestions/:contactUsername", handlers.GetContactSuggestions) contacts.Get("/suggestions/:contactUsername", handlers.GetContactSuggestions)
// Messages group
messages := chat.Group("/messages", middleware.Protected(), logger.New())
messages.Get("/:conversationID", handlers.GetMessages)
} }