code refactor, implementing sockets

This commit is contained in:
slawk0
2025-02-11 16:34:59 +01:00
parent bbae0e8fbb
commit 248966d63f
18 changed files with 595 additions and 122 deletions

View File

@@ -5,7 +5,7 @@ import (
"errors" "errors"
"github.com/google/uuid" "github.com/google/uuid"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"relay-server/helpers" "relay-server/utils"
) )
func CheckUserExists(db *sql.DB, username string) (bool, error) { func CheckUserExists(db *sql.DB, username string) (bool, error) {
@@ -14,7 +14,7 @@ func CheckUserExists(db *sql.DB, username string) (bool, error) {
var count int var count int
err := db.QueryRow(query, username).Scan(&count) err := db.QueryRow(query, username).Scan(&count)
if err != nil { if err != nil {
return false, helpers.NewError(helpers.ErrInternal, "Failed to check user existence", err) return false, utils.NewError(utils.ErrInternal, "Failed to check user existence", err)
} }
return count > 0, nil return count > 0, nil
@@ -30,7 +30,7 @@ func InsertUser(db *sql.DB, username string, passwordHash string) (string, error
var userID string var userID string
err := db.QueryRow(query, username, passwordHash).Scan(&userID) err := db.QueryRow(query, username, passwordHash).Scan(&userID)
if err != nil { if err != nil {
return "", helpers.NewError(helpers.ErrInternal, "Failed to create user", err) return "", utils.NewError(utils.ErrInternal, "Failed to create user", err)
} }
return userID, nil return userID, nil
@@ -46,9 +46,9 @@ func GetPasswordHash(db *sql.DB, username string) (string, error) {
err := db.QueryRow(query, username).Scan(&passwordHash) err := db.QueryRow(query, username).Scan(&passwordHash)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return "", helpers.NewError(helpers.ErrNotFound, "User not found", nil) return "", utils.NewError(utils.ErrNotFound, "User not found", nil)
} }
return "", helpers.NewError(helpers.ErrInternal, "Failed to get user credentials", err) return "", utils.NewError(utils.ErrInternal, "Failed to get user credentials", err)
} }
return passwordHash, nil return passwordHash, nil
@@ -64,9 +64,9 @@ func GetUserID(db *sql.DB, username string) (uuid.UUID, error) {
err := db.QueryRow(query, username).Scan(&userID) err := db.QueryRow(query, username).Scan(&userID)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return uuid.Nil, helpers.NewError(helpers.ErrNotFound, "User not found", nil) return uuid.Nil, utils.NewError(utils.ErrNotFound, "User not found", nil)
} }
return uuid.Nil, helpers.NewError(helpers.ErrInternal, "Failed to get user ID", err) return uuid.Nil, utils.NewError(utils.ErrInternal, "Failed to get user ID", err)
} }
return userID, nil return userID, nil
} }

View File

@@ -5,8 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"relay-server/helpers"
"relay-server/model" "relay-server/model"
"relay-server/utils"
"strings" "strings"
) )
@@ -18,9 +18,9 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) error
).Scan(&conversationType) ).Scan(&conversationType)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return helpers.NewError(helpers.ErrNotFound, "no contacts found for this id", fmt.Errorf("no conversation found with id: %s", conversationID)) return utils.NewError(utils.ErrNotFound, "no contacts found for this id", fmt.Errorf("no conversation found with id: %s", conversationID))
} }
return helpers.NewError(helpers.ErrInternal, "Failed to check conversation", err) return utils.NewError(utils.ErrInternal, "Failed to check conversation", err)
} }
if conversationType == "group" { if conversationType == "group" {
@@ -30,15 +30,15 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) error
conversationID, userID, conversationID, userID,
) )
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "Failed to delete contact", err) return utils.NewError(utils.ErrInternal, "Failed to delete contact", err)
} }
rowsAffected, err := res.RowsAffected() rowsAffected, err := res.RowsAffected()
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to verify contact deletion: %w", err)) return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to verify contact deletion: %w", err))
} }
if rowsAffected == 0 { if rowsAffected == 0 {
return helpers.NewError(helpers.ErrNotFound, fmt.Sprintf("no matching contact found with conversation id: %s, user id: %s", conversationID, userID), nil) return utils.NewError(utils.ErrNotFound, fmt.Sprintf("no matching contact found with conversation id: %s, user id: %s", conversationID, userID), nil)
} }
// Delete from Memberships // Delete from Memberships
@@ -47,15 +47,15 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) error
conversationID, userID, conversationID, userID,
) )
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to delete membership: %w", err)) return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to delete membership: %w", err))
} }
rowsAffected, err = res.RowsAffected() rowsAffected, err = res.RowsAffected()
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to verify membership deletion: %w", err)) return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to verify membership deletion: %w", err))
} }
if rowsAffected == 0 { if rowsAffected == 0 {
return helpers.NewError(helpers.ErrNotFound, "No membership found", err) return utils.NewError(utils.ErrNotFound, "No membership found", err)
} }
} else { } else {
res, err := db.Exec( res, err := db.Exec(
@@ -63,15 +63,15 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) error
userID, conversationID, userID, conversationID,
) )
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "Failed to delete contact", err) return utils.NewError(utils.ErrInternal, "Failed to delete contact", err)
} }
rowsAffected, err := res.RowsAffected() rowsAffected, err := res.RowsAffected()
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to verify contact deletion: %w", err)) return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to verify contact deletion: %w", err))
} }
if rowsAffected == 0 { if rowsAffected == 0 {
return helpers.NewError(helpers.ErrNotFound, fmt.Sprintf("no matching contact found with user id: %s, conversation id: %s", userID, conversationID), nil) return utils.NewError(utils.ErrNotFound, fmt.Sprintf("no matching contact found with user id: %s, conversation id: %s", userID, conversationID), nil)
} }
} }
@@ -94,7 +94,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse
`, userID).Scan(&conversationID) `, userID).Scan(&conversationID)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to check existing conversation: %w", err)) return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to check existing conversation: %w", err))
} }
if conversationID == uuid.Nil { if conversationID == uuid.Nil {
@@ -104,7 +104,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse
RETURNING conversation_id; RETURNING conversation_id;
`).Scan(&conversationID) `).Scan(&conversationID)
if err != nil { if err != nil {
return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to create conversation: %w", err)) return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to create conversation: %w", err))
} }
_, err = db.Exec(` _, err = db.Exec(`
@@ -113,7 +113,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse
ON CONFLICT (conversation_id, user_id) DO NOTHING; ON CONFLICT (conversation_id, user_id) DO NOTHING;
`, conversationID, userID) `, conversationID, userID)
if err != nil { if err != nil {
return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to create membership: %w", err)) return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to create membership: %w", err))
} }
} }
} else { } else {
@@ -128,7 +128,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse
`, userID, contactID).Scan(&conversationID) `, userID, contactID).Scan(&conversationID)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to check existing conversation %w", err)) return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to check existing conversation %w", err))
} }
if conversationID == uuid.Nil { if conversationID == uuid.Nil {
@@ -138,7 +138,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse
RETURNING conversation_id; RETURNING conversation_id;
`).Scan(&conversationID) `).Scan(&conversationID)
if err != nil { if err != nil {
return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to create conversation: %w", err)) return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to create conversation: %w", err))
} }
_, err = db.Exec(` _, err = db.Exec(`
@@ -147,7 +147,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse
ON CONFLICT (conversation_id, user_id) DO NOTHING; ON CONFLICT (conversation_id, user_id) DO NOTHING;
`, conversationID, userID, contactID) `, conversationID, userID, contactID)
if err != nil { if err != nil {
return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to create memberships: %w", err)) return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to create memberships: %w", err))
} }
} }
} }
@@ -186,9 +186,9 @@ func InsertContactByID(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (
`, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID) `, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID)
if err == nil { if err == nil {
return &model.Contact{}, helpers.NewError(helpers.ErrInvalidInput, "Contact already exists", nil) return &model.Contact{}, utils.NewError(utils.ErrInvalidInput, "Contact already exists", nil)
} else if !errors.Is(err, sql.ErrNoRows) { } else if !errors.Is(err, sql.ErrNoRows) {
return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to check contact existence: %w", err)) return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to check contact existence: %w", err))
} }
// Insert new contact // Insert new contact
@@ -199,7 +199,7 @@ func InsertContactByID(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (
`, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID) `, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID)
if err != nil { if err != nil {
return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "Failed to create contact", err) return &model.Contact{}, utils.NewError(utils.ErrInternal, "Failed to create contact", err)
} }
return &contact, nil return &contact, nil
@@ -230,7 +230,7 @@ func GetLatestMessage(db *sql.DB, conversationId uuid.UUID) (*model.Contact, err
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return &model.Contact{}, nil return &model.Contact{}, nil
} }
return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "Failed to get latest message", fmt.Errorf("failed to get latest message: %w", err)) return &model.Contact{}, utils.NewError(utils.ErrInternal, "Failed to get latest message", fmt.Errorf("failed to get latest message: %w", err))
} }
return &latestMessage, nil return &latestMessage, nil
@@ -287,7 +287,7 @@ func GetContacts(db *sql.DB, userID uuid.UUID) ([]*model.Contact, error) {
rows, err := db.Query(contactsQuery, userID) rows, err := db.Query(contactsQuery, userID)
if err != nil { if err != nil {
return []*model.Contact{}, helpers.NewError(helpers.ErrInternal, "Failed to get contacts", fmt.Errorf("failed to get contacts: %w", err)) return []*model.Contact{}, utils.NewError(utils.ErrInternal, "Failed to get contacts", fmt.Errorf("failed to get contacts: %w", err))
} }
defer rows.Close() defer rows.Close()
@@ -296,7 +296,7 @@ func GetContacts(db *sql.DB, userID uuid.UUID) ([]*model.Contact, error) {
contact := &model.Contact{} contact := &model.Contact{}
err := rows.Scan(&contact.ID, &contact.UserID, &contact.Username, &contact.LastActive, &contact.ConversationID, &contact.Type, &contact.LastReadMessageID) err := rows.Scan(&contact.ID, &contact.UserID, &contact.Username, &contact.LastActive, &contact.ConversationID, &contact.Type, &contact.LastReadMessageID)
if err != nil { if err != nil {
return []*model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to scan contact: %w", err)) return []*model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to scan contact: %w", err))
} }
latestMessage, err := GetLatestMessage(db, contact.ConversationID) latestMessage, err := GetLatestMessage(db, contact.ConversationID)
@@ -313,7 +313,7 @@ func GetContacts(db *sql.DB, userID uuid.UUID) ([]*model.Contact, error) {
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return []*model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process contacts: %w", err)) return []*model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to process contacts: %w", err))
} }
return contacts, nil return contacts, nil
@@ -329,7 +329,7 @@ func ContactSuggestion(db *sql.DB, contactUsername string) ([]string, error) {
rows, err := db.Query(query, "%"+strings.ToLower(contactUsername)+"%") rows, err := db.Query(query, "%"+strings.ToLower(contactUsername)+"%")
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
return []string{}, helpers.NewError(helpers.ErrInternal, "Failed to get contact suggestions", fmt.Errorf("failed to get contact suggestions: %w", err)) return []string{}, utils.NewError(utils.ErrInternal, "Failed to get contact suggestions", fmt.Errorf("failed to get contact suggestions: %w", err))
} }
defer rows.Close() defer rows.Close()
@@ -338,12 +338,12 @@ func ContactSuggestion(db *sql.DB, contactUsername string) ([]string, error) {
var suggestion string var suggestion string
err := rows.Scan(&suggestion) err := rows.Scan(&suggestion)
if err != nil { if err != nil {
return []string{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to scan contact suggestion: %w", err)) return []string{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to scan contact suggestion: %w", err))
} }
suggestions = append(suggestions, suggestion) suggestions = append(suggestions, suggestion)
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return []string{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process suggestions: %w", err)) return []string{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to process suggestions: %w", err))
} }
return suggestions, nil return suggestions, nil
} }

View File

@@ -4,8 +4,8 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"relay-server/helpers"
"relay-server/model" "relay-server/model"
"relay-server/utils"
) )
func CreateGroup(db *sql.DB, groupName string, userID uuid.UUID) (uuid.UUID, error) { func CreateGroup(db *sql.DB, groupName string, userID uuid.UUID) (uuid.UUID, error) {
@@ -25,14 +25,14 @@ func CreateGroup(db *sql.DB, groupName string, userID uuid.UUID) (uuid.UUID, err
var groupID uuid.UUID var groupID uuid.UUID
err := db.QueryRow(createConversationQuery, groupName).Scan(&groupID) err := db.QueryRow(createConversationQuery, groupName).Scan(&groupID)
if err != nil { if err != nil {
return uuid.Nil, helpers.NewError(helpers.ErrInternal, "Failed to create group", fmt.Errorf("failed to create group: %w", err)) return uuid.Nil, utils.NewError(utils.ErrInternal, "Failed to create group", fmt.Errorf("failed to create group: %w", err))
} }
// Insert group admin (make user that created the group an admin) // Insert group admin (make user that created the group an admin)
var grantedAt string var grantedAt string
err = db.QueryRow(insertGroupAdminQuery, groupID, userID, userID).Scan(&grantedAt) err = db.QueryRow(insertGroupAdminQuery, groupID, userID, userID).Scan(&grantedAt)
if err != nil { if err != nil {
return uuid.Nil, helpers.NewError(helpers.ErrInternal, "Failed to create group", fmt.Errorf("failed to insert group admin: %w", err)) return uuid.Nil, utils.NewError(utils.ErrInternal, "Failed to create group", fmt.Errorf("failed to insert group admin: %w", err))
} }
// Add self as group member // Add self as group member
@@ -44,7 +44,7 @@ func CreateGroup(db *sql.DB, groupName string, userID uuid.UUID) (uuid.UUID, err
// Insert group contact // Insert group contact
_, err = InsertContactByID(db, userID, groupID) _, err = InsertContactByID(db, userID, groupID)
if err != nil { if err != nil {
return groupID, helpers.NewError(helpers.ErrInternal, "Failed to create group contact", fmt.Errorf("failed to insert group contact: %w", err)) return groupID, utils.NewError(utils.ErrInternal, "Failed to create group contact", fmt.Errorf("failed to insert group contact: %w", err))
} }
return groupID, nil return groupID, nil
@@ -60,7 +60,7 @@ func AddMemberToGroup(db *sql.DB, userID uuid.UUID, groupID uuid.UUID) (uuid.UUI
var memberID uuid.UUID var memberID uuid.UUID
err := db.QueryRow(query, groupID, userID).Scan(&memberID) err := db.QueryRow(query, groupID, userID).Scan(&memberID)
if err != nil { if err != nil {
return uuid.Nil, helpers.NewError(helpers.ErrInternal, "Failed to add member to group", fmt.Errorf("failed to add member to group: %w", err)) return uuid.Nil, utils.NewError(utils.ErrInternal, "Failed to add member to group", fmt.Errorf("failed to add member to group: %w", err))
} }
return memberID, nil return memberID, nil
} }
@@ -83,7 +83,7 @@ func GetMembers(db *sql.DB, groupID uuid.UUID) ([]*model.Member, error) {
rows, err := db.Query(query, groupID) rows, err := db.Query(query, groupID)
if err != nil { if err != nil {
return []*model.Member{}, helpers.NewError(helpers.ErrInternal, "Failed to get members", fmt.Errorf("failed to get members: %w", err)) return []*model.Member{}, utils.NewError(utils.ErrInternal, "Failed to get members", fmt.Errorf("failed to get members: %w", err))
} }
defer rows.Close() defer rows.Close()
@@ -92,13 +92,13 @@ func GetMembers(db *sql.DB, groupID uuid.UUID) ([]*model.Member, error) {
var member model.Member var member model.Member
err = rows.Scan(&member.UserID, &member.Username, &member.IsAdmin, &member.IsOwner) err = rows.Scan(&member.UserID, &member.Username, &member.IsAdmin, &member.IsOwner)
if err != nil { if err != nil {
return []*model.Member{}, helpers.NewError(helpers.ErrInternal, "Failed to get members", fmt.Errorf("failed to scan member: %w", err)) return []*model.Member{}, utils.NewError(utils.ErrInternal, "Failed to get members", fmt.Errorf("failed to scan member: %w", err))
} }
members = append(members, &member) members = append(members, &member)
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return []*model.Member{}, helpers.NewError(helpers.ErrInternal, "Failed to get members", fmt.Errorf("error iterating members: %w", err)) return []*model.Member{}, utils.NewError(utils.ErrInternal, "Failed to get members", fmt.Errorf("error iterating members: %w", err))
} }
return members, nil return members, nil
@@ -113,7 +113,7 @@ func IsAdmin(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (bool, erro
var count int var count int
err := db.QueryRow(query, userID, conversationID).Scan(&count) err := db.QueryRow(query, userID, conversationID).Scan(&count)
if err != nil { if err != nil {
return false, helpers.NewError(helpers.ErrInternal, "Failed to check admin status", fmt.Errorf("failed to check admin status: %w", err)) return false, utils.NewError(utils.ErrInternal, "Failed to check admin status", fmt.Errorf("failed to check admin status: %w", err))
} }
return count > 0, nil return count > 0, nil
} }
@@ -127,7 +127,83 @@ func IsMember(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (bool, err
var count int var count int
err := db.QueryRow(query, userID, conversationID).Scan(&count) err := db.QueryRow(query, userID, conversationID).Scan(&count)
if err != nil { if err != nil {
return false, helpers.NewError(helpers.ErrInternal, "Failed to check membership", fmt.Errorf("failed to check membership: %w", err)) return false, utils.NewError(utils.ErrInternal, "Failed to check membership", fmt.Errorf("failed to check membership: %w", err))
} }
return count > 0, nil return count > 0, nil
} }
func GetUserConversations(db *sql.DB, userID string) ([]string, error) {
query := `
SELECT DISTINCT m.conversation_id
FROM Memberships m
JOIN Conversations c ON m.conversation_id = c.conversation_id
WHERE m.user_id = $1
AND c.conversation_type = 'group';
`
var conversations []string
rows, err := db.Query(query, userID)
if err != nil {
return []string{}, fmt.Errorf("failed to get user conversations: %w", err)
}
defer rows.Close()
for rows.Next() {
var conversationID string
err = rows.Scan(&conversationID)
if err != nil {
return []string{}, fmt.Errorf("failed to scan conversation: %w", err)
}
conversations = append(conversations, conversationID)
}
if err := rows.Err(); err != nil {
return []string{}, fmt.Errorf("error iterating conversations: %w", err)
}
return conversations, nil
}
func RemoveUserFromGroup(db *sql.DB, userID uuid.UUID, groupID uuid.UUID) (string, error) {
removeUserFromGroupQuery := `
DELETE FROM Memberships
WHERE conversation_id = $1 AND user_id = $2;
`
isOwner, err := IsGroupOwner(db, userID, groupID)
if err != nil {
return "Failed to remove user from group", fmt.Errorf("failed to check if user is group owner: %w", err)
}
if !isOwner {
return "Cannot remove group owner", nil
}
res, err := db.Exec(removeUserFromGroupQuery, groupID, userID)
if err != nil {
return "Failed to remove user from group", fmt.Errorf("failed to remove user from group: %w", err)
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return "Failed to remove user from group", fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return "User is not a member of the group", nil
}
return "Successfully removed user from group", nil
}
func IsGroupOwner(db *sql.DB, userID uuid.UUID, groupID uuid.UUID) (bool, error) {
query := `
SELECT EXISTS (
SELECT 1 FROM GroupAdmins
WHERE conversation_id = $1
AND user_id = $2
AND is_owner = true
) AS is_owner;
`
var isOwner bool
err := db.QueryRow(query, groupID, userID).Scan(&isOwner)
if err != nil {
return false, fmt.Errorf("failed to check if user is group owner: %w", err)
}
return isOwner, nil
}

View File

@@ -4,8 +4,8 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"relay-server/helpers"
"relay-server/model" "relay-server/model"
"relay-server/utils"
) )
func GetMessages(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, limit int, cursor int) ([]*model.Message, error) { func GetMessages(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, limit int, cursor int) ([]*model.Message, error) {
@@ -33,20 +33,20 @@ func GetMessages(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, limit i
` `
rows, err := db.Query(query, conversationID, cursor, limit) rows, err := db.Query(query, conversationID, cursor, limit)
if err != nil { if err != nil {
return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err)) return []*model.Message{}, utils.NewError(utils.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err))
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
message := &model.Message{} message := &model.Message{}
err = rows.Scan(&message.MessageID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender) err = rows.Scan(&message.ID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender)
if err != nil { if err != nil {
return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err)) return []*model.Message{}, utils.NewError(utils.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err))
} }
messages = append(messages, message) messages = append(messages, message)
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err)) return []*model.Message{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err))
} }
} else { } else {
@@ -65,20 +65,20 @@ func GetMessages(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, limit i
` `
rows, err := db.Query(query, conversationID, cursor, limit) rows, err := db.Query(query, conversationID, cursor, limit)
if err != nil { if err != nil {
return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err)) return []*model.Message{}, utils.NewError(utils.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err))
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
message := &model.Message{} message := &model.Message{}
err = rows.Scan(&message.MessageID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender) err = rows.Scan(&message.ID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender)
if err != nil { if err != nil {
return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err)) return []*model.Message{}, utils.NewError(utils.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err))
} }
messages = append(messages, message) messages = append(messages, message)
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err)) return []*model.Message{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err))
} }
} }
if cursor != 0 { if cursor != 0 {
@@ -102,10 +102,10 @@ func checkMembership(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (bo
var isMember bool var isMember bool
err := db.QueryRow(query, userID, conversationID).Scan(&isMember) err := db.QueryRow(query, userID, conversationID).Scan(&isMember)
if err != nil { if err != nil {
return false, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to check membership: %w", err)) return false, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to check membership: %w", err))
} }
if !isMember { if !isMember {
return false, helpers.NewError(helpers.ErrForbidden, "You are member of the conversation", nil) return false, utils.NewError(utils.ErrForbidden, "You are member of the conversation", nil)
} }
return isMember, nil return isMember, nil
} }
@@ -121,7 +121,7 @@ func DeleteMessage(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, messa
var messageOwnerID uuid.UUID var messageOwnerID uuid.UUID
err := db.QueryRow(checkMessageOwnershipQuery, messageID).Scan(&messageOwnerID) err := db.QueryRow(checkMessageOwnershipQuery, messageID).Scan(&messageOwnerID)
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "Failed to delete message", fmt.Errorf("failed to check message ownership: %w", err)) return utils.NewError(utils.ErrInternal, "Failed to delete message", fmt.Errorf("failed to check message ownership: %w", err))
} }
var isSelfMessage bool var isSelfMessage bool
@@ -134,20 +134,51 @@ func DeleteMessage(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, messa
return err return err
} }
if !isSelfMessage && !isAdmin { if !isSelfMessage && !isAdmin {
return helpers.NewError(helpers.ErrForbidden, "You don't have permissions to delete that message ", nil) return utils.NewError(utils.ErrForbidden, "You don't have permissions to delete that message ", nil)
} }
row, err := db.Exec(deleteMessageQuery, messageID) row, err := db.Exec(deleteMessageQuery, messageID)
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "Failed to delete message", fmt.Errorf("failed to delete message: %w", err)) return utils.NewError(utils.ErrInternal, "Failed to delete message", fmt.Errorf("failed to delete message: %w", err))
} }
rowsAffected, err := row.RowsAffected() rowsAffected, err := row.RowsAffected()
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "Failed to delete message", fmt.Errorf("failed to get rows affected: %w", err)) return utils.NewError(utils.ErrInternal, "Failed to delete message", fmt.Errorf("failed to get rows affected: %w", err))
} }
if rowsAffected == 0 { if rowsAffected == 0 {
return helpers.NewError(helpers.ErrNotFound, "Message not found", nil) return utils.NewError(utils.ErrNotFound, "Message not found", nil)
} }
return nil return nil
} }
func InsertMessage(db *sql.DB, senderID uuid.UUID, conversationID uuid.UUID, message string, attachmentUrls []string) (*model.Message, error) {
isMember, err := checkMembership(db, senderID, conversationID)
if err != nil {
return &model.Message{}, fmt.Errorf("failed to check membership: %w", err)
}
if !isMember {
return &model.Message{}, fmt.Errorf("user is not a member of the conversation")
}
query := `
INSERT INTO Messages (conversation_id, user_id, content, attachment_urls)
VALUES ($1, $2, $3, $4)
RETURNING message_id, content AS message, sent_at, attachment_urls, user_id AS sender_id, conversation_id;
`
var msg model.Message
err = db.QueryRow(query, conversationID, senderID, message, attachmentUrls).Scan(
&msg.ID,
&msg.Message,
&msg.SentAt,
&msg.AttachmentUrl,
&msg.SenderID,
&msg.ConversationID,
)
if err != nil {
return &model.Message{}, fmt.Errorf("failed to insert message: %w", err)
}
return &msg, nil
}

11
go.mod
View File

@@ -12,16 +12,25 @@ require (
require ( require (
github.com/MicahParks/keyfunc/v2 v2.1.0 // indirect github.com/MicahParks/keyfunc/v2 v2.1.0 // indirect
github.com/andybalholm/brotli v1.1.1 // indirect github.com/andybalholm/brotli v1.1.1 // indirect
github.com/fasthttp/websocket v1.5.12 // indirect
github.com/gofiber/contrib/jwt v1.0.10 // indirect github.com/gofiber/contrib/jwt v1.0.10 // indirect
github.com/gofiber/contrib/socketio v1.1.4 // indirect
github.com/gofiber/contrib/websocket v1.3.3 // indirect
github.com/gofrs/uuid v4.0.0+incompatible // indirect
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
github.com/gomodule/redigo v1.8.4 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/googollee/go-socket.io v1.7.0 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/compress v1.17.11 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.58.0 // indirect github.com/valyala/fasthttp v1.58.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect
golang.org/x/sys v0.29.0 // indirect golang.org/x/net v0.34.0 // indirect
golang.org/x/sys v0.30.0 // indirect
) )

28
go.sum
View File

@@ -4,14 +4,30 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOUInlE=
github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg=
github.com/gofiber/contrib/jwt v1.0.10 h1:/ilGepl6i0Bntl0Zcd+lAzagY8BiS1+fEiAj32HMApk= github.com/gofiber/contrib/jwt v1.0.10 h1:/ilGepl6i0Bntl0Zcd+lAzagY8BiS1+fEiAj32HMApk=
github.com/gofiber/contrib/jwt v1.0.10/go.mod h1:1qBENE6sZ6PPT4xIpBzx1VxeyROQO7sj48OlM1I9qdU= github.com/gofiber/contrib/jwt v1.0.10/go.mod h1:1qBENE6sZ6PPT4xIpBzx1VxeyROQO7sj48OlM1I9qdU=
github.com/gofiber/contrib/socketio v1.1.4 h1:XoS4N4yvbVJeFOfzFOiHKRGn++Vax+doQhJLZEMnc5M=
github.com/gofiber/contrib/socketio v1.1.4/go.mod h1:ZqUgo7SYEp7TJFH+BwpW2F0Xo4BdIYLLPkbxxnK3KnY=
github.com/gofiber/contrib/websocket v1.3.3 h1:R6DlDKieGPMiDrqYNyobsHbvjqvxMHeCj/lLaca4jg8=
github.com/gofiber/contrib/websocket v1.3.3/go.mod h1:07u6QGMsvX+sx7iGNCl5xhzuUVArWwLQ3tBIH24i+S8=
github.com/gofiber/fiber/v2 v2.52.6 h1:Rfp+ILPiYSvvVuIPvxrBns+HJp8qGLDnLJawAu27XVI= github.com/gofiber/fiber/v2 v2.52.6 h1:Rfp+ILPiYSvvVuIPvxrBns+HJp8qGLDnLJawAu27XVI=
github.com/gofiber/fiber/v2 v2.52.6/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/gofiber/fiber/v2 v2.52.6/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/gomodule/redigo v1.8.4 h1:Z5JUg94HMTR1XpwBaSH4vq3+PNSIykBLxMdglbw10gg=
github.com/gomodule/redigo v1.8.4/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googollee/go-socket.io v1.7.0 h1:ODcQSAvVIPvKozXtUGuJDV3pLwdpBLDs1Uoq/QHIlY8=
github.com/googollee/go-socket.io v1.7.0/go.mod h1:0vGP8/dXR9SZUMMD4+xxaGo/lohOw3YWMh2WRiWeKxg=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
@@ -29,10 +45,15 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc=
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
@@ -46,9 +67,16 @@ golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@@ -9,7 +9,7 @@ import (
"os" "os"
"relay-server/config" "relay-server/config"
"relay-server/database" "relay-server/database"
"relay-server/helpers" "relay-server/utils"
"time" "time"
) )
@@ -21,21 +21,21 @@ func Signup(c *fiber.Ctx) error {
u := new(SignupStruct) u := new(SignupStruct)
if err := c.BodyParser(u); err != nil { if err := c.BodyParser(u); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid request body", err) return utils.NewError(utils.ErrInvalidInput, "Invalid request body", err)
} }
// Validate input // Validate input
if u.Username == "" { if u.Username == "" {
return helpers.NewError(helpers.ErrInvalidInput, "Username is empty", nil) return utils.NewError(utils.ErrInvalidInput, "Username is empty", nil)
} }
if u.Password == "" { if u.Password == "" {
return helpers.NewError(helpers.ErrInvalidInput, "Password is empty", nil) return utils.NewError(utils.ErrInvalidInput, "Password is empty", nil)
} }
if !helpers.IsValidPassword(u.Password) { if !utils.IsValidPassword(u.Password) {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid password", nil) return utils.NewError(utils.ErrInvalidInput, "Invalid password", nil)
} }
if !helpers.IsValidUsername(u.Username) { if !utils.IsValidUsername(u.Username) {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid username", nil) return utils.NewError(utils.ErrInvalidInput, "Invalid username", nil)
} }
// Check if user exists // Check if user exists
@@ -44,13 +44,13 @@ func Signup(c *fiber.Ctx) error {
return err return err
} }
if exist { if exist {
return helpers.NewError(helpers.ErrInvalidInput, "User already exists", nil) return utils.NewError(utils.ErrInvalidInput, "User already exists", nil)
} }
// Create password hash // Create password hash
passwordHash, err := bcrypt.GenerateFromPassword([]byte(u.Password), config.BCRYPT_COST) passwordHash, err := bcrypt.GenerateFromPassword([]byte(u.Password), config.BCRYPT_COST)
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to generate password hash: %w", err)) return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to generate password hash: %w", err))
} }
// Insert user // Insert user
@@ -66,7 +66,7 @@ func Signup(c *fiber.Ctx) error {
}) })
signedToken, err := token.SignedString([]byte(os.Getenv("JWT_SECRET"))) signedToken, err := token.SignedString([]byte(os.Getenv("JWT_SECRET")))
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "Failed to generate auth token", err) return utils.NewError(utils.ErrInternal, "Failed to generate auth token", err)
} }
// Set token cookie // Set token cookie
@@ -91,21 +91,21 @@ func Login(c *fiber.Ctx) error {
u := new(loginStruct) u := new(loginStruct)
if err := c.BodyParser(u); err != nil { if err := c.BodyParser(u); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid request body", err) return utils.NewError(utils.ErrInvalidInput, "Invalid request body", err)
} }
// Validate input // Validate input
if u.Username == "" { if u.Username == "" {
return helpers.NewError(helpers.ErrInvalidInput, "Username is empty", nil) return utils.NewError(utils.ErrInvalidInput, "Username is empty", nil)
} }
if u.Password == "" { if u.Password == "" {
return helpers.NewError(helpers.ErrInvalidInput, "Password is empty", nil) return utils.NewError(utils.ErrInvalidInput, "Password is empty", nil)
} }
if !helpers.IsValidUsername(u.Username) { if !utils.IsValidUsername(u.Username) {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid username", nil) return utils.NewError(utils.ErrInvalidInput, "Invalid username", nil)
} }
if !helpers.IsValidPassword(u.Password) { if !utils.IsValidPassword(u.Password) {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid password", nil) return utils.NewError(utils.ErrInvalidInput, "Invalid password", nil)
} }
// Check if user exists // Check if user exists
@@ -114,7 +114,7 @@ func Login(c *fiber.Ctx) error {
return err return err
} }
if !exist { if !exist {
return helpers.NewError(helpers.ErrNotFound, "User does not exist", nil) return utils.NewError(utils.ErrNotFound, "User does not exist", nil)
} }
// Verify password // Verify password
@@ -123,7 +123,7 @@ func Login(c *fiber.Ctx) error {
return err return err
} }
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(u.Password)); err != nil { if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(u.Password)); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid password", nil) return utils.NewError(utils.ErrInvalidInput, "Invalid password", nil)
} }
// Get user ID // Get user ID
@@ -139,7 +139,7 @@ func Login(c *fiber.Ctx) error {
}) })
signedToken, err := token.SignedString([]byte(os.Getenv("JWT_SECRET"))) signedToken, err := token.SignedString([]byte(os.Getenv("JWT_SECRET")))
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "Failed to generate token", err) return utils.NewError(utils.ErrInternal, "Failed to generate token", err)
} }
// Set token cookie // Set token cookie
@@ -159,13 +159,13 @@ func Login(c *fiber.Ctx) error {
func ValidateToken(c *fiber.Ctx) error { func ValidateToken(c *fiber.Ctx) error {
username, ok := c.Locals("username").(string) username, ok := c.Locals("username").(string)
if !ok { if !ok {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid token: missing username", fmt.Errorf("missing username: %v", c.Locals("username"))) return utils.NewError(utils.ErrInvalidInput, "Invalid token: missing username", fmt.Errorf("missing username: %v", c.Locals("username")))
} }
userIDVal := c.Locals("userID") userIDVal := c.Locals("userID")
userID, ok := userIDVal.(uuid.UUID) userID, ok := userIDVal.(uuid.UUID)
if !ok { if !ok {
return helpers.NewError(helpers.ErrUnauthorized, "unauthorized", fmt.Errorf("missing/invalid userID type: %T, value: %v\n", userIDVal, userIDVal)) return utils.NewError(utils.ErrUnauthorized, "unauthorized", fmt.Errorf("missing/invalid userID type: %T, value: %v\n", userIDVal, userIDVal))
} }
return c.Status(fiber.StatusOK).JSON(fiber.Map{ return c.Status(fiber.StatusOK).JSON(fiber.Map{

View File

@@ -4,7 +4,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"relay-server/database" "relay-server/database"
"relay-server/helpers" "relay-server/utils"
) )
func DeleteContact(c *fiber.Ctx) error { func DeleteContact(c *fiber.Ctx) error {
@@ -15,14 +15,14 @@ func DeleteContact(c *fiber.Ctx) error {
p := new(params) p := new(params)
if err := c.ParamsParser(p); err != nil { if err := c.ParamsParser(p); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) return utils.NewError(utils.ErrInvalidInput, "Invalid params", err)
} }
if p.ContactID == uuid.Nil { if p.ContactID == uuid.Nil {
return helpers.NewError(helpers.ErrInvalidInput, "contact ID is empty", nil) return utils.NewError(utils.ErrInvalidInput, "contact ID is empty", nil)
} }
if p.ConversationID == uuid.Nil { if p.ConversationID == uuid.Nil {
return helpers.NewError(helpers.ErrInvalidInput, "conversation ID is empty", nil) return utils.NewError(utils.ErrInvalidInput, "conversation ID is empty", nil)
} }
err := database.DeleteContact(database.DB, p.ContactID, p.ConversationID) err := database.DeleteContact(database.DB, p.ContactID, p.ConversationID)
@@ -42,15 +42,15 @@ func InsertContact(c *fiber.Ctx) error {
p := new(params) p := new(params)
if err := c.ParamsParser(p); err != nil { if err := c.ParamsParser(p); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) return utils.NewError(utils.ErrInvalidInput, "Invalid params", err)
} }
if p.ContactUsername == "" { if p.ContactUsername == "" {
return helpers.NewError(helpers.ErrInvalidInput, "contact username is empty", nil) return utils.NewError(utils.ErrInvalidInput, "contact username is empty", nil)
} }
if !helpers.IsValidUsername(p.ContactUsername) { if !utils.IsValidUsername(p.ContactUsername) {
return helpers.NewError(helpers.ErrInvalidInput, "invalid username", nil) return utils.NewError(utils.ErrInvalidInput, "invalid username", nil)
} }
contactID, err := database.GetUserID(database.DB, p.ContactUsername) contactID, err := database.GetUserID(database.DB, p.ContactUsername)
@@ -84,10 +84,10 @@ func GetContactSuggestions(c *fiber.Ctx) error {
p := new(params) p := new(params)
if err := c.ParamsParser(p); err != nil { if err := c.ParamsParser(p); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) return utils.NewError(utils.ErrInvalidInput, "Invalid params", err)
} }
if p.ContactUsername == "" { if p.ContactUsername == "" {
return helpers.NewError(helpers.ErrInvalidInput, "contact username is empty", nil) return utils.NewError(utils.ErrInvalidInput, "contact username is empty", nil)
} }
suggestions, err := database.ContactSuggestion(database.DB, p.ContactUsername) suggestions, err := database.ContactSuggestion(database.DB, p.ContactUsername)

View File

@@ -6,7 +6,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"log" "log"
"relay-server/database" "relay-server/database"
"relay-server/helpers" "relay-server/utils"
) )
func CreateGroup(c *fiber.Ctx) error { func CreateGroup(c *fiber.Ctx) error {
@@ -18,11 +18,11 @@ func CreateGroup(c *fiber.Ctx) error {
var req createGroupRequest var req createGroupRequest
if err := c.BodyParser(&req); err != nil { if err := c.BodyParser(&req); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid request body", err) return utils.NewError(utils.ErrInvalidInput, "Invalid request body", err)
} }
if req.GroupName == "" { if req.GroupName == "" {
return helpers.NewError(helpers.ErrInvalidInput, "Group name is empty", nil) return utils.NewError(utils.ErrInvalidInput, "Group name is empty", nil)
} }
groupID, err := database.CreateGroup(database.DB, req.GroupName, userID) groupID, err := database.CreateGroup(database.DB, req.GroupName, userID)
@@ -41,13 +41,13 @@ func AddMemberToGroup(c *fiber.Ctx) error {
} }
var req addMemberToGroupRequest var req addMemberToGroupRequest
if err := c.BodyParser(&req); err != nil { if err := c.BodyParser(&req); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid request body", err) return utils.NewError(utils.ErrInvalidInput, "Invalid request body", err)
} }
if req.GroupID == uuid.Nil { if req.GroupID == uuid.Nil {
return helpers.NewError(helpers.ErrInvalidInput, "Group ID is empty", nil) return utils.NewError(utils.ErrInvalidInput, "Group ID is empty", nil)
} }
if req.UserID == uuid.Nil { if req.UserID == uuid.Nil {
return helpers.NewError(helpers.ErrInvalidInput, "User ID is empty", nil) return utils.NewError(utils.ErrInvalidInput, "User ID is empty", nil)
} }
isAdmin, err := database.IsAdmin(database.DB, req.UserID, req.GroupID) isAdmin, err := database.IsAdmin(database.DB, req.UserID, req.GroupID)
@@ -55,7 +55,7 @@ func AddMemberToGroup(c *fiber.Ctx) error {
return err return err
} }
if !isAdmin { if !isAdmin {
return helpers.NewError(helpers.ErrUnauthorized, "You are not a group administrator", nil) return utils.NewError(utils.ErrUnauthorized, "You are not a group administrator", nil)
} }
_, err = database.AddMemberToGroup(database.DB, req.GroupID, req.UserID) _, err = database.AddMemberToGroup(database.DB, req.GroupID, req.UserID)
@@ -75,15 +75,15 @@ func GetMembers(c *fiber.Ctx) error {
var p params var p params
if err := c.ParamsParser(&p); err != nil { if err := c.ParamsParser(&p); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) return utils.NewError(utils.ErrInvalidInput, "Invalid params", err)
} }
isMember, err := database.IsMember(database.DB, p.GroupID, c.Locals("userID").(uuid.UUID)) isMember, err := database.IsMember(database.DB, c.Locals("userID").(uuid.UUID), p.GroupID)
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "Failed to get members", fmt.Errorf("failed to check if user is a member: %w", err)) return utils.NewError(utils.ErrInternal, "Failed to get members", fmt.Errorf("failed to check if user is a member: %w", err))
} }
if !isMember { if !isMember {
return helpers.NewError(helpers.ErrForbidden, "You are not a member of this group", nil) return utils.NewError(utils.ErrForbidden, "You are not a member of this group", nil)
} }
members, err := database.GetMembers(database.DB, p.GroupID) members, err := database.GetMembers(database.DB, p.GroupID)
@@ -91,5 +91,5 @@ func GetMembers(c *fiber.Ctx) error {
return err return err
} }
return c.Status(fiber.StatusOK).JSON(fiber.Map{"members": members}) return c.Status(fiber.StatusOK).JSON(members)
} }

View File

@@ -4,7 +4,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"relay-server/database" "relay-server/database"
"relay-server/helpers" "relay-server/utils"
) )
func GetMessages(c *fiber.Ctx) error { func GetMessages(c *fiber.Ctx) error {
@@ -20,13 +20,13 @@ func GetMessages(c *fiber.Ctx) error {
p := new(params) p := new(params)
q := new(query) q := new(query)
if err := c.ParamsParser(p); err != nil { if err := c.ParamsParser(p); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) return utils.NewError(utils.ErrInvalidInput, "Invalid params", err)
} }
if err := c.QueryParser(q); err != nil { if err := c.QueryParser(q); err != nil {
return helpers.NewError(helpers.ErrInvalidInput, "Invalid query", err) return utils.NewError(utils.ErrInvalidInput, "Invalid query", err)
} }
if p.conversationID == uuid.Nil { if p.conversationID == uuid.Nil {
return helpers.NewError(helpers.ErrInvalidInput, "conversation ID is empty", nil) return utils.NewError(utils.ErrInvalidInput, "conversation ID is empty", nil)
} }
messages, err := database.GetMessages(database.DB, userID, p.conversationID, q.limit, q.cursor) messages, err := database.GetMessages(database.DB, userID, p.conversationID, q.limit, q.cursor)

14
main.go
View File

@@ -1,17 +1,27 @@
package main package main
import ( import (
"github.com/gofiber/contrib/websocket"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"log" "log"
"relay-server/database" "relay-server/database"
"relay-server/helpers"
"relay-server/router" "relay-server/router"
"relay-server/utils"
) )
func main() { func main() {
app := fiber.New(fiber.Config{ app := fiber.New(fiber.Config{
ErrorHandler: helpers.ErrorHandler, ErrorHandler: utils.ErrorHandler,
}) })
app.Use(func(c *fiber.Ctx) error {
if websocket.IsWebSocketUpgrade(c) {
c.Locals("allowed", true)
return c.Next()
}
return fiber.ErrUpgradeRequired
})
db, err := database.Init() db, err := database.Init()
if err != nil { if err != nil {
log.Fatal("Failed to initialize database") log.Fatal("Failed to initialize database")

View File

@@ -7,8 +7,8 @@ import (
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid" "github.com/google/uuid"
"os" "os"
"relay-server/helpers"
"relay-server/model" "relay-server/model"
"relay-server/utils"
) )
func Protected() fiber.Handler { func Protected() fiber.Handler {
@@ -22,7 +22,7 @@ func Protected() fiber.Handler {
claims := user.Claims.(*model.UserClaims) claims := user.Claims.(*model.UserClaims)
userID, err := uuid.Parse(claims.UserID) userID, err := uuid.Parse(claims.UserID)
if err != nil { if err != nil {
return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to parse user ID: %w", err)) return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to parse user ID: %w", err))
} }
c.Locals("userID", userID) c.Locals("userID", userID)
c.Locals("username", claims.Username) c.Locals("username", claims.Username)

View File

@@ -30,11 +30,13 @@ type ContactSuggestion struct {
} }
type Message struct { type Message struct {
MessageID int `json:"message_id"` ID int `json:"message_id"`
Message string `json:"message"` Message string `json:"message"`
SentAt time.Time `json:"sent_at"` SentAt time.Time `json:"sent_at"`
Sender string `json:"sender"` Sender string `json:"sender"`
AttachmentUrl *string `json:"attachment_url"` SenderID uuid.UUID `json:"sender_id"`
AttachmentUrl *string `json:"attachment_url"`
ConversationID uuid.UUID `json:"conversation_id"`
} }
type CreateGroupResponse struct { type CreateGroupResponse struct {

View File

@@ -8,6 +8,7 @@ import (
) )
func SetupRoutes(app *fiber.App) { func SetupRoutes(app *fiber.App) {
app.Get("/", middleware.Protected(), func(c *fiber.Ctx) error { app.Get("/", middleware.Protected(), func(c *fiber.Ctx) error {
return c.SendString("Hello, World!") return c.SendString("Hello, World!")
}) })
@@ -36,4 +37,8 @@ func SetupRoutes(app *fiber.App) {
groups.Post("/create", handlers.CreateGroup) groups.Post("/create", handlers.CreateGroup)
groups.Post("/addMember", handlers.AddMemberToGroup) groups.Post("/addMember", handlers.AddMemberToGroup)
groups.Get("/getMembers/:groupID", handlers.GetMembers) groups.Get("/getMembers/:groupID", handlers.GetMembers)
// Socket group
socket := chat.Group("/ws", middleware.Protected(), logger.New())
socket.Get("/:id")
} }

258
socket/socket.go Normal file
View File

@@ -0,0 +1,258 @@
package socket
import (
"fmt"
"github.com/google/uuid"
"log"
"relay-server/database"
"relay-server/utils"
socketio "github.com/googollee/go-socket.io"
)
// Message represents the chat message structure
type Message struct {
Content string `json:"message"`
Recipient string `json:"recipient"`
RecipientID uuid.UUID `json:"recipient_id"`
AttachmentURLs []string `json:"attachment_urls"`
}
// DeleteMessageData represents the message deletion structure
type DeleteMessageData struct {
ConversationID uuid.UUID `json:"conversation_id"`
MessageID int `json:"message_id"`
}
// GroupUserData represents group-user operations data
type GroupUserData struct {
GroupID uuid.UUID `json:"group_id"`
UserID uuid.UUID `json:"user_id"`
}
// MessageReadData represents message read status data
type MessageReadData struct {
ConversationID uuid.UUID `json:"conversation_id"`
MessageID int `json:"message_id"`
}
// SocketResponse represents a standard socket response
type SocketResponse struct {
Status string `json:"status"`
Message string `json:"message"`
}
// InitializeSocket sets up and configures the Socket.IO server
func InitializeSocket() (*socketio.Server, error) {
server := socketio.NewServer(nil)
// Middleware for authentication
server.OnConnect("/", func(s socketio.Conn) error {
token := s.RemoteHeader().Get("Authorization")
if token == "" {
log.Println("(socket) Not logged in")
return fmt.Errorf("not logged in")
}
c, err := utils.ValidateToken(token)
if err != nil {
log.Printf("(socket) Token verification failed: %v\n", err)
return fmt.Errorf("invalid token")
}
if !utils.IsValidUsername(c.Username) {
log.Println("(socket) Invalid username")
return fmt.Errorf("invalid username")
}
s.SetContext(map[string]interface{}{
"username": c.Username,
"user_id": c.UserID,
})
log.Printf("(socket) socket id: %s, username: %s, user_id: %s\n", s.ID(), c.Username, c.UserID)
return nil
})
// Handle connection
server.OnEvent("/", "connection", func(s socketio.Conn) {
ctx := s.Context().(map[string]interface{})
username := ctx["username"].(string)
userID := ctx["user_id"].(string)
if !utils.IsValidUsername(username) {
s.Close()
return
}
conversations, err := database.GetUserConversations(database.DB, userID)
if err != nil {
log.Printf("(socket) Failed to get user conversations: %v\n", err)
return
}
// Join all conversations
for _, conv := range conversations {
s.Join(conv)
}
s.Join(userID) // Join user's personal room
log.Printf("User: %s joined to: %v\n", username, conversations)
})
// Handle chat message
server.OnEvent("/", "chat message", func(s socketio.Conn, msg Message) SocketResponse {
ctx := s.Context().(map[string]interface{})
username := ctx["username"].(string)
userIDstr := ctx["user_id"].(string)
conversationIDstr := ctx["recipient_id"].(string)
userID, err := uuid.Parse(userIDstr)
if err != nil {
return SocketResponse{Status: "error", Message: "Invalid user id"}
}
conversationID, err := uuid.Parse(conversationIDstr)
if err != nil {
return SocketResponse{Status: "error", Message: "Invalid conversation id"}
}
if msg.Content == "" && len(msg.AttachmentURLs) == 0 {
return SocketResponse{Status: "error", Message: "No message or attachment provided"}
}
if msg.Recipient == "" {
return SocketResponse{Status: "error", Message: "No recipient provided"}
}
insertedMsg, err := database.InsertMessage(database.DB, userID, conversationID, msg.Content, msg.AttachmentURLs)
if err != nil {
log.Printf("(socket) Failed to insert message: %v\n", err)
return SocketResponse{Status: "error", Message: "Failed to insert message"}
}
// Emit message to recipients
server.BroadcastToRoom("", msg.Recipient, "chat message", map[string]interface{}{
"sender": username,
"message": insertedMsg.Message,
"attachment_urls": msg.AttachmentURLs,
"recipient": msg.Recipient,
"message_id": insertedMsg.ID,
"sender_id": userID,
"sent_at": insertedMsg.SentAt,
"conversation_id": insertedMsg.ConversationID,
})
return SocketResponse{Status: "ok", Message: "Received message"}
})
// Handle delete message
server.OnEvent("/", "delete message", func(s socketio.Conn, data DeleteMessageData) SocketResponse {
ctx := s.Context().(map[string]interface{})
userIDstr := ctx["user_id"].(string)
userID, err := uuid.Parse(userIDstr)
if err != nil {
return SocketResponse{Status: "error", Message: "Invalid user id"}
}
if data.MessageID == 0 {
return SocketResponse{Status: "error", Message: "No message id provided"}
}
if data.ConversationID == uuid.Nil {
return SocketResponse{Status: "error", Message: "No conversation id provided"}
}
err = database.DeleteMessage(database.DB, userID, data.ConversationID, data.MessageID)
if err != nil {
return SocketResponse{Status: "error", Message: err.Error()}
}
conversationIDstr := data.ConversationID.String()
server.BroadcastToRoom("", conversationIDstr, "delete message", data)
return SocketResponse{Status: "ok", Message: "Successfully deleted message"}
})
// Handle remove user from group
server.OnEvent("/", "remove user from group", func(s socketio.Conn, data GroupUserData) SocketResponse {
if data.GroupID == uuid.Nil {
return SocketResponse{Status: "error", Message: "No group id provided"}
}
if data.UserID == uuid.Nil {
return SocketResponse{Status: "error", Message: "No user id provided"}
}
err := removeUserFromGroupById(data.GroupID, data.UserID)
if err != nil {
return SocketResponse{Status: "error", Message: err.Error()}
}
// Remove user from room
sockets := server.Sockets(data.GroupID)
for _, socket := range sockets {
if socket.Context().(map[string]interface{})["user_id"] == data.UserID {
socket.Leave(data.GroupID)
}
}
server.BroadcastToRoom("", data.GroupID, "left group", data)
server.BroadcastToRoom("", data.UserID, "left group", data)
return SocketResponse{Status: "ok", Message: "Successfully removed user from group"}
})
// Handle administrator operations
server.OnEvent("/", "added administrator", func(s socketio.Conn, data GroupUserData) SocketResponse {
ctx := s.Context().(map[string]interface{})
userID := ctx["user_id"].(string)
if data.GroupID == "" {
return SocketResponse{Status: "error", Message: "No conversation id provided"}
}
if data.UserID == "" {
return SocketResponse{Status: "error", Message: "No user id provided"}
}
isAdmin, err := isAdmin(userID, data.GroupID)
if err != nil || !isAdmin {
return SocketResponse{Status: "error", Message: "You are not an administrator"}
}
err = addAdministrator(data.GroupID, data.UserID, userID)
if err != nil {
return SocketResponse{Status: "error", Message: err.Error()}
}
server.BroadcastToRoom("", data.GroupID, "added administrator", data)
return SocketResponse{Status: "ok", Message: "Successfully added administrator"}
})
// Handle message read status
server.OnEvent("/", "message read", func(s socketio.Conn, data MessageReadData) {
ctx := s.Context().(map[string]interface{})
userID := ctx["user_id"].(string)
if data.ConversationID == "" || data.MessageID == "" {
return
}
err := updateContactStatus(userID, data.ConversationID, data.MessageID)
if err != nil {
log.Printf("Failed to update message read status: %v\n", err)
}
})
// Handle disconnection
server.OnDisconnect("/", func(s socketio.Conn, reason string) {
log.Printf("(socket) %s disconnected due to: %s\n", s.ID(), reason)
})
return server, nil
}
// The following functions would need to be implemented according to your specific needs:
// - verifyJwtToken
// - isValidUsername
// - getConversationsForUser
// - insertMessage
// - deleteMessage
// - removeUserFromGroupById
// - isAdmin
// - addAdministrator
// - updateContactStatus

View File

@@ -1,4 +1,4 @@
package helpers package utils
import ( import (
"errors" "errors"

View File

@@ -1,4 +1,4 @@
package helpers package utils
import ( import (
"regexp" "regexp"

54
utils/token.go Normal file
View File

@@ -0,0 +1,54 @@
package utils
import (
"errors"
"fmt"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"os"
"relay-server/model"
)
func GenerateToken(userID uuid.UUID, username string) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": userID,
"username": username,
})
signedToken, err := token.SignedString([]byte(os.Getenv("JWT_SECRET")))
if err != nil {
return "", NewError(ErrInternal, "Failed to generate token", err)
}
return signedToken, nil
}
func ValidateToken(tokenString string) (*model.UserClaims, error) {
secretKey := os.Getenv("JWT_SECRET")
// Parse the token
token, err := jwt.ParseWithClaims(tokenString, &model.UserClaims{}, func(token *jwt.Token) (interface{}, error) {
// Validate the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(secretKey), nil
})
if err != nil {
return nil, fmt.Errorf("failed to parse token: %v", err)
}
// Check if the token is valid
if !token.Valid {
return nil, errors.New("invalid token")
}
// Type assert the claims
claims, ok := token.Claims.(*model.UserClaims)
if !ok {
return nil, errors.New("failed to parse claims")
}
return claims, nil
}