feat: Security
This commit is contained in:
120
server/src/websocket/handler-auth.spec.ts
Normal file
120
server/src/websocket/handler-auth.spec.ts
Normal file
@@ -0,0 +1,120 @@
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
vi
|
||||
} from 'vitest';
|
||||
import { WebSocket } from 'ws';
|
||||
import { connectedUsers } from './state';
|
||||
import { ConnectedUser } from './types';
|
||||
import { handleWebSocketMessage } from './handler';
|
||||
|
||||
vi.mock('../services/server-access.service', () => ({
|
||||
authorizeWebSocketJoin: vi.fn(async () => ({ allowed: true as const })),
|
||||
findServerMembership: vi.fn(async () => ({ id: 'membership-1' })),
|
||||
usersShareServerMembership: vi.fn(async () => false)
|
||||
}));
|
||||
|
||||
vi.mock('../services/session-auth.service', () => ({
|
||||
consumeSessionToken: vi.fn(async (token: string) => {
|
||||
if (token !== 'valid-token') {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
token,
|
||||
user: {
|
||||
id: 'user-1',
|
||||
username: 'alice',
|
||||
displayName: 'Alice',
|
||||
passwordHash: 'hash',
|
||||
createdAt: Date.now()
|
||||
},
|
||||
issuedAt: Date.now(),
|
||||
expiresAt: Date.now() + 60_000
|
||||
};
|
||||
})
|
||||
}));
|
||||
|
||||
function createMockWs(): WebSocket & { sentMessages: string[] } {
|
||||
const sent: string[] = [];
|
||||
const ws = {
|
||||
readyState: WebSocket.OPEN,
|
||||
send: (data: string) => { sent.push(data); },
|
||||
close: () => {},
|
||||
sentMessages: sent
|
||||
} as unknown as WebSocket & { sentMessages: string[] };
|
||||
|
||||
return ws;
|
||||
}
|
||||
|
||||
function createConnectedUser(connectionId: string): ConnectedUser {
|
||||
const ws = createMockWs();
|
||||
const user: ConnectedUser = {
|
||||
oderId: connectionId,
|
||||
ws,
|
||||
authenticated: false,
|
||||
serverIds: new Set(),
|
||||
displayName: 'Test User',
|
||||
lastPong: Date.now()
|
||||
};
|
||||
|
||||
connectedUsers.set(connectionId, user);
|
||||
|
||||
return user;
|
||||
}
|
||||
|
||||
describe('server websocket handler - authentication', () => {
|
||||
beforeEach(() => {
|
||||
connectedUsers.clear();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('rejects non-identify messages until the connection is authenticated', async () => {
|
||||
createConnectedUser('conn-1');
|
||||
|
||||
await handleWebSocketMessage('conn-1', { type: 'typing', serverId: 'server-1' });
|
||||
|
||||
const user = connectedUsers.get('conn-1');
|
||||
const sentMessages = (user?.ws as WebSocket & { sentMessages: string[] }).sentMessages;
|
||||
const response = JSON.parse(sentMessages[0]) as { type: string };
|
||||
|
||||
expect(response.type).toBe('auth_required');
|
||||
expect(user?.authenticated).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects identify without a session token', async () => {
|
||||
createConnectedUser('conn-1');
|
||||
|
||||
await handleWebSocketMessage('conn-1', {
|
||||
type: 'identify',
|
||||
oderId: 'user-1',
|
||||
displayName: 'Alice'
|
||||
});
|
||||
|
||||
const user = connectedUsers.get('conn-1');
|
||||
const sentMessages = (user?.ws as WebSocket & { sentMessages: string[] }).sentMessages;
|
||||
const response = JSON.parse(sentMessages[0]) as { type: string; code: string };
|
||||
|
||||
expect(response.type).toBe('auth_error');
|
||||
expect(response.code).toBe('MISSING_TOKEN');
|
||||
expect(user?.authenticated).toBe(false);
|
||||
});
|
||||
|
||||
it('binds identify to the authenticated user id from the token', async () => {
|
||||
createConnectedUser('conn-1');
|
||||
|
||||
await handleWebSocketMessage('conn-1', {
|
||||
type: 'identify',
|
||||
token: 'valid-token',
|
||||
oderId: 'user-1',
|
||||
displayName: 'Alice'
|
||||
});
|
||||
|
||||
const user = connectedUsers.get('conn-1');
|
||||
|
||||
expect(user?.authenticated).toBe(true);
|
||||
expect(user?.oderId).toBe('user-1');
|
||||
});
|
||||
});
|
||||
@@ -63,6 +63,7 @@ function createConnectedUser(
|
||||
displayName: `User ${oderId}`,
|
||||
lastPong: Date.now(),
|
||||
oderId,
|
||||
authenticated: true,
|
||||
serverIds: new Set(),
|
||||
ws: createMockWs(),
|
||||
...overrides
|
||||
|
||||
@@ -14,6 +14,41 @@ vi.mock('../services/server-access.service', () => ({
|
||||
authorizeWebSocketJoin: vi.fn(async () => ({ allowed: true as const }))
|
||||
}));
|
||||
|
||||
let authenticatedUserId = 'user-1';
|
||||
|
||||
vi.mock('../services/session-auth.service', () => ({
|
||||
consumeSessionToken: vi.fn(async (token: string) => {
|
||||
if (token !== 'test-token') {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
token,
|
||||
user: {
|
||||
id: authenticatedUserId,
|
||||
username: 'test-user',
|
||||
displayName: 'Test User',
|
||||
passwordHash: 'hash',
|
||||
createdAt: Date.now()
|
||||
},
|
||||
issuedAt: Date.now(),
|
||||
expiresAt: Date.now() + 60_000
|
||||
};
|
||||
})
|
||||
}));
|
||||
|
||||
vi.mock('../services/plugin-support.service', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('../services/plugin-support.service')>();
|
||||
|
||||
return {
|
||||
...actual,
|
||||
getPluginRequirementsSnapshot: vi.fn(async () => ({
|
||||
requirements: [],
|
||||
eventDefinitions: []
|
||||
}))
|
||||
};
|
||||
});
|
||||
|
||||
/**
|
||||
* Minimal mock WebSocket that records sent messages.
|
||||
*/
|
||||
@@ -38,6 +73,7 @@ function createConnectedUser(
|
||||
const user: ConnectedUser = {
|
||||
oderId,
|
||||
ws,
|
||||
authenticated: true,
|
||||
serverIds: new Set(),
|
||||
displayName: 'Test User',
|
||||
lastPong: Date.now(),
|
||||
@@ -168,7 +204,8 @@ describe('server websocket handler - status_update', () => {
|
||||
getSentMessagesStore(user2).sentMessages.length = 0;
|
||||
|
||||
// Identify first (required for handler)
|
||||
await handleWebSocketMessage('conn-1', { type: 'identify', oderId: 'user-1', displayName: 'User 1' });
|
||||
authenticatedUserId = 'user-1';
|
||||
await handleWebSocketMessage('conn-1', { type: 'identify', token: 'test-token', oderId: 'user-1', displayName: 'User 1' });
|
||||
|
||||
// user-2 joins server -> should receive server_users with user-1's status
|
||||
getSentMessagesStore(user2).sentMessages.length = 0;
|
||||
@@ -201,7 +238,8 @@ describe('server websocket handler - user_joined includes status', () => {
|
||||
getRequiredConnectedUser('conn-1').status = 'busy';
|
||||
|
||||
// Identify user-1
|
||||
await handleWebSocketMessage('conn-1', { type: 'identify', oderId: 'user-1', displayName: 'User 1' });
|
||||
authenticatedUserId = 'user-1';
|
||||
await handleWebSocketMessage('conn-1', { type: 'identify', token: 'test-token', oderId: 'user-1', displayName: 'User 1' });
|
||||
|
||||
getSentMessagesStore(user2).sentMessages.length = 0;
|
||||
|
||||
@@ -237,8 +275,10 @@ describe('server websocket handler - profile metadata in presence messages', ()
|
||||
bob.serverIds.add('server-1');
|
||||
getSentMessagesStore(bob).sentMessages.length = 0;
|
||||
|
||||
authenticatedUserId = 'user-1';
|
||||
await handleWebSocketMessage('conn-1', {
|
||||
type: 'identify',
|
||||
token: 'test-token',
|
||||
oderId: 'user-1',
|
||||
displayName: 'Alice Updated',
|
||||
description: 'Updated bio',
|
||||
@@ -261,8 +301,10 @@ describe('server websocket handler - profile metadata in presence messages', ()
|
||||
alice.serverIds.add('server-1');
|
||||
bob.serverIds.add('server-1');
|
||||
|
||||
authenticatedUserId = 'user-1';
|
||||
await handleWebSocketMessage('conn-1', {
|
||||
type: 'identify',
|
||||
token: 'test-token',
|
||||
oderId: 'user-1',
|
||||
displayName: 'Alice',
|
||||
description: 'Alice bio',
|
||||
@@ -291,8 +333,10 @@ describe('server websocket handler - profile metadata in presence messages', ()
|
||||
alice.serverIds.add('server-1');
|
||||
bob.serverIds.add('server-1');
|
||||
|
||||
authenticatedUserId = 'user-1';
|
||||
await handleWebSocketMessage('conn-1', {
|
||||
type: 'identify',
|
||||
token: 'test-token',
|
||||
oderId: 'user-1',
|
||||
displayName: 'Alice',
|
||||
homeSignalServerUrl: 'http://signal.example.com:3001/'
|
||||
|
||||
@@ -7,7 +7,12 @@ import {
|
||||
getUniqueUsersInServer,
|
||||
isOderIdConnectedToServer
|
||||
} from './broadcast';
|
||||
import { authorizeWebSocketJoin } from '../services/server-access.service';
|
||||
import {
|
||||
authorizeWebSocketJoin,
|
||||
findServerMembership,
|
||||
usersShareServerMembership
|
||||
} from '../services/server-access.service';
|
||||
import { consumeSessionToken } from '../services/session-auth.service';
|
||||
import {
|
||||
getPluginRequirementsSnapshot,
|
||||
PluginSupportError,
|
||||
@@ -131,8 +136,67 @@ async function sendPluginRequirements(user: ConnectedUser, serverId: string): Pr
|
||||
}
|
||||
}
|
||||
|
||||
function handleIdentify(user: ConnectedUser, message: WsMessage, connectionId: string): void {
|
||||
const newOderId = readMessageId(message['oderId']) ?? connectionId;
|
||||
const DIRECT_SIGNALING_TYPES = new Set([
|
||||
'direct-message',
|
||||
'direct-message-status',
|
||||
'direct-message-mutation',
|
||||
'direct-message-typing',
|
||||
'direct-message-sync-request',
|
||||
'direct-message-sync',
|
||||
'direct-call'
|
||||
]);
|
||||
const SERVER_SCOPED_SIGNALING_TYPES = new Set([
|
||||
'server_icon_peer_request',
|
||||
'server_icon_peer_data',
|
||||
'server_icon_available',
|
||||
'server_icon_sync_request'
|
||||
]);
|
||||
|
||||
function sendAuthRequired(user: ConnectedUser): void {
|
||||
user.ws.send(JSON.stringify({
|
||||
type: 'auth_required',
|
||||
message: 'identify with a valid session token before sending messages'
|
||||
}));
|
||||
}
|
||||
|
||||
async function handleIdentify(user: ConnectedUser, message: WsMessage, connectionId: string): Promise<void> {
|
||||
const token = typeof message['token'] === 'string' ? message['token'].trim() : '';
|
||||
|
||||
if (!token) {
|
||||
user.ws.send(JSON.stringify({
|
||||
type: 'auth_error',
|
||||
code: 'MISSING_TOKEN',
|
||||
message: 'identify requires a session token'
|
||||
}));
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const session = await consumeSessionToken(token);
|
||||
|
||||
if (!session) {
|
||||
user.ws.send(JSON.stringify({
|
||||
type: 'auth_error',
|
||||
code: 'INVALID_TOKEN',
|
||||
message: 'invalid or expired session token'
|
||||
}));
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const claimedOderId = readMessageId(message['oderId']);
|
||||
|
||||
if (claimedOderId && claimedOderId !== session.user.id) {
|
||||
user.ws.send(JSON.stringify({
|
||||
type: 'auth_error',
|
||||
code: 'USER_ID_MISMATCH',
|
||||
message: 'oderId must match the authenticated user'
|
||||
}));
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const newOderId = session.user.id;
|
||||
const newScope = typeof message['connectionScope'] === 'string' ? message['connectionScope'] : undefined;
|
||||
const previousDisplayName = normalizeDisplayName(user.displayName);
|
||||
const previousDescription = user.description;
|
||||
@@ -140,6 +204,7 @@ function handleIdentify(user: ConnectedUser, message: WsMessage, connectionId: s
|
||||
const previousHomeSignalServerUrl = user.homeSignalServerUrl;
|
||||
|
||||
user.oderId = newOderId;
|
||||
user.authenticated = true;
|
||||
user.displayName = normalizeDisplayName(message['displayName'], normalizeDisplayName(user.displayName));
|
||||
|
||||
if (Object.prototype.hasOwnProperty.call(message, 'description')) {
|
||||
@@ -277,11 +342,45 @@ function handleLeaveServer(user: ConnectedUser, message: WsMessage, connectionId
|
||||
);
|
||||
}
|
||||
|
||||
function forwardRtcMessage(user: ConnectedUser, message: WsMessage): void {
|
||||
async function canForwardRtcMessage(user: ConnectedUser, message: WsMessage, targetUserId: string): Promise<boolean> {
|
||||
if (!targetUserId || targetUserId === user.oderId) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DIRECT_SIGNALING_TYPES.has(message.type)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (SERVER_SCOPED_SIGNALING_TYPES.has(message.type)) {
|
||||
const serverId = readMessageId(message['serverId']);
|
||||
|
||||
if (!serverId) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const senderMembership = await findServerMembership(serverId, user.oderId);
|
||||
const targetMembership = await findServerMembership(serverId, targetUserId);
|
||||
|
||||
return !!senderMembership && !!targetMembership;
|
||||
}
|
||||
|
||||
if (message.type === 'offer' || message.type === 'answer' || message.type === 'ice_candidate') {
|
||||
return true;
|
||||
}
|
||||
|
||||
return usersShareServerMembership(user.oderId, targetUserId);
|
||||
}
|
||||
|
||||
async function forwardRtcMessage(user: ConnectedUser, message: WsMessage): Promise<void> {
|
||||
const targetUserId = readMessageId(message['targetUserId']) ?? '';
|
||||
|
||||
console.log(`Forwarding ${message.type} from ${user.oderId} to ${targetUserId}`);
|
||||
|
||||
if (!(await canForwardRtcMessage(user, message, targetUserId))) {
|
||||
console.log(`Blocked ${message.type} relay from ${user.oderId} to ${targetUserId}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const targetUser = findUserByOderId(targetUserId);
|
||||
|
||||
if (targetUser) {
|
||||
@@ -482,13 +581,18 @@ export async function handleWebSocketMessage(connectionId: string, message: WsMe
|
||||
user.lastPong = Date.now();
|
||||
connectedUsers.set(connectionId, user);
|
||||
|
||||
if (!user.authenticated && message.type !== 'identify' && message.type !== 'keepalive') {
|
||||
sendAuthRequired(user);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (message.type) {
|
||||
case 'keepalive':
|
||||
user.ws.send(JSON.stringify({ type: 'keepalive_ack', serverTime: Date.now() }));
|
||||
break;
|
||||
|
||||
case 'identify':
|
||||
handleIdentify(user, message, connectionId);
|
||||
await handleIdentify(user, message, connectionId);
|
||||
break;
|
||||
|
||||
case 'join_server':
|
||||
@@ -515,7 +619,7 @@ export async function handleWebSocketMessage(connectionId: string, message: WsMe
|
||||
case 'direct-call':
|
||||
case 'server_icon_peer_request':
|
||||
case 'server_icon_peer_data':
|
||||
forwardRtcMessage(user, message);
|
||||
await forwardRtcMessage(user, message);
|
||||
break;
|
||||
|
||||
case 'chat_message':
|
||||
|
||||
@@ -80,7 +80,13 @@ export function setupWebSocket(server: Server<typeof IncomingMessage, typeof Ser
|
||||
const connectionId = uuidv4();
|
||||
const now = Date.now();
|
||||
|
||||
connectedUsers.set(connectionId, { oderId: connectionId, ws, serverIds: new Set(), lastPong: now });
|
||||
connectedUsers.set(connectionId, {
|
||||
oderId: connectionId,
|
||||
ws,
|
||||
authenticated: false,
|
||||
serverIds: new Set(),
|
||||
lastPong: now
|
||||
});
|
||||
|
||||
ws.on('pong', () => {
|
||||
const user = connectedUsers.get(connectionId);
|
||||
|
||||
@@ -3,6 +3,7 @@ import { WebSocket } from 'ws';
|
||||
export interface ConnectedUser {
|
||||
oderId: string;
|
||||
ws: WebSocket;
|
||||
authenticated: boolean;
|
||||
serverIds: Set<string>;
|
||||
viewedServerId?: string;
|
||||
displayName?: string;
|
||||
|
||||
Reference in New Issue
Block a user