rough rate limiting for websockets

This commit is contained in:
dakkar 2024-08-15 11:35:51 +01:00
parent 01958da57f
commit 311a31da58
2 changed files with 55 additions and 0 deletions

View File

@ -19,7 +19,12 @@ import { ChannelFollowingService } from '@/core/ChannelFollowingService.js';
import { AuthenticateService, AuthenticationError } from './AuthenticateService.js'; import { AuthenticateService, AuthenticationError } from './AuthenticateService.js';
import MainStreamConnection from './stream/Connection.js'; import MainStreamConnection from './stream/Connection.js';
import { ChannelsService } from './stream/ChannelsService.js'; import { ChannelsService } from './stream/ChannelsService.js';
import { RateLimiterService } from './RateLimiterService.js';
import { RoleService } from '@/core/RoleService.js';
import { getIpHash } from '@/misc/get-ip-hash.js';
import ms from 'ms';
import type * as http from 'node:http'; import type * as http from 'node:http';
import type { IEndpointMeta } from './endpoints.js';
@Injectable() @Injectable()
export class StreamingApiServerService { export class StreamingApiServerService {
@ -41,9 +46,32 @@ export class StreamingApiServerService {
private notificationService: NotificationService, private notificationService: NotificationService,
private usersService: UserService, private usersService: UserService,
private channelFollowingService: ChannelFollowingService, private channelFollowingService: ChannelFollowingService,
private rateLimiterService: RateLimiterService,
private roleService: RoleService,
) { ) {
} }
@bindThis
private async rateLimitThis(
user: MiLocalUser | null | undefined,
requestIp: string | undefined,
limit: IEndpointMeta['limit'] & { key: NonNullable<string> },
) : Promise<boolean> {
let limitActor: string;
if (user) {
limitActor = user.id;
} else {
limitActor = getIpHash(requestIp || 'wtf');
}
const factor = user ? (await this.roleService.getUserPolicies(user.id)).rateLimitFactor : 1;
if (factor <= 0) return false;
// Rate limit
return await this.rateLimiterService.limit(limit, limitActor, factor).then(() => { return false }).catch(err => { return true });
}
@bindThis @bindThis
public attach(server: http.Server): void { public attach(server: http.Server): void {
this.#wss = new WebSocket.WebSocketServer({ this.#wss = new WebSocket.WebSocketServer({
@ -57,6 +85,17 @@ export class StreamingApiServerService {
return; return;
} }
if (await this.rateLimitThis(null, request.socket.remoteAddress, {
key: 'wsconnect',
duration: ms('1min'),
max: 20,
minInterval: ms('1sec'),
})) {
socket.write('HTTP/1.1 429 Rate Limit Exceeded\r\n\r\n');
socket.destroy();
return;
}
const q = new URL(request.url, `http://${request.headers.host}`).searchParams; const q = new URL(request.url, `http://${request.headers.host}`).searchParams;
let user: MiLocalUser | null = null; let user: MiLocalUser | null = null;
@ -94,6 +133,14 @@ export class StreamingApiServerService {
return; return;
} }
const rateLimiter = () => {
return this.rateLimitThis(user, request.socket.remoteAddress, {
key: 'wsmessage',
duration: ms('1sec'),
max: 100,
});
};
const stream = new MainStreamConnection( const stream = new MainStreamConnection(
this.channelsService, this.channelsService,
this.noteReadService, this.noteReadService,
@ -101,6 +148,7 @@ export class StreamingApiServerService {
this.cacheService, this.cacheService,
this.channelFollowingService, this.channelFollowingService,
user, app, user, app,
rateLimiter,
); );
await stream.init(); await stream.init();

View File

@ -25,6 +25,7 @@ import type Channel from './channel.js';
export default class Connection { export default class Connection {
public user?: MiUser; public user?: MiUser;
public token?: MiAccessToken; public token?: MiAccessToken;
private rateLimiter?: () => Promise<boolean>;
private wsConnection: WebSocket.WebSocket; private wsConnection: WebSocket.WebSocket;
public subscriber: StreamEventEmitter; public subscriber: StreamEventEmitter;
private channels: Channel[] = []; private channels: Channel[] = [];
@ -48,9 +49,11 @@ export default class Connection {
user: MiUser | null | undefined, user: MiUser | null | undefined,
token: MiAccessToken | null | undefined, token: MiAccessToken | null | undefined,
rateLimiter: () => Promise<boolean>,
) { ) {
if (user) this.user = user; if (user) this.user = user;
if (token) this.token = token; if (token) this.token = token;
if (rateLimiter) this.rateLimiter = rateLimiter;
} }
@bindThis @bindThis
@ -103,6 +106,10 @@ export default class Connection {
private async onWsConnectionMessage(data: WebSocket.RawData) { private async onWsConnectionMessage(data: WebSocket.RawData) {
let obj: Record<string, any>; let obj: Record<string, any>;
if (this.rateLimiter && await this.rateLimiter()) {
return;
}
try { try {
obj = JSON.parse(data.toString()); obj = JSON.parse(data.toString());
} catch (e) { } catch (e) {