Limit number of rate limit requests

This commit is contained in:
Julia Johannesen 2024-08-16 17:13:20 -04:00
parent 4cd44130e0
commit 6d3f9503ed
No known key found for this signature in database
GPG Key ID: 4A1377AF3E7FBC46
2 changed files with 29 additions and 2 deletions

View File

@ -26,12 +26,15 @@ import proxyAddr from 'proxy-addr';
import ms from 'ms'; import ms from 'ms';
import type * as http from 'node:http'; import type * as http from 'node:http';
import type { IEndpointMeta } from './endpoints.js'; import type { IEndpointMeta } from './endpoints.js';
import { LoggerService } from '@/core/LoggerService.js';
import type Logger from '@/logger.js';
@Injectable() @Injectable()
export class StreamingApiServerService { export class StreamingApiServerService {
#wss: WebSocket.WebSocketServer; #wss: WebSocket.WebSocketServer;
#connections = new Map<WebSocket.WebSocket, number>(); #connections = new Map<WebSocket.WebSocket, number>();
#cleanConnectionsIntervalId: NodeJS.Timeout | null = null; #cleanConnectionsIntervalId: NodeJS.Timeout | null = null;
#logger: Logger;
constructor( constructor(
@Inject(DI.redisForSub) @Inject(DI.redisForSub)
@ -49,6 +52,7 @@ export class StreamingApiServerService {
private channelFollowingService: ChannelFollowingService, private channelFollowingService: ChannelFollowingService,
private rateLimiterService: RateLimiterService, private rateLimiterService: RateLimiterService,
private roleService: RoleService, private roleService: RoleService,
private loggerService: LoggerService,
) { ) {
} }
@ -155,6 +159,7 @@ export class StreamingApiServerService {
this.notificationService, this.notificationService,
this.cacheService, this.cacheService,
this.channelFollowingService, this.channelFollowingService,
this.loggerService,
user, app, user, app,
rateLimiter, rateLimiter,
); );

View File

@ -17,6 +17,8 @@ import { ChannelFollowingService } from '@/core/ChannelFollowingService.js';
import type { ChannelsService } from './ChannelsService.js'; import type { ChannelsService } from './ChannelsService.js';
import type { EventEmitter } from 'events'; import type { EventEmitter } from 'events';
import type Channel from './channel.js'; import type Channel from './channel.js';
import { LoggerService } from '@/core/LoggerService.js';
import type Logger from '@/logger.js';
/** /**
* Main stream connection * Main stream connection
@ -39,6 +41,9 @@ export default class Connection {
public userIdsWhoMeMutingRenotes: Set<string> = new Set(); public userIdsWhoMeMutingRenotes: Set<string> = new Set();
public userMutedInstances: Set<string> = new Set(); public userMutedInstances: Set<string> = new Set();
private fetchIntervalId: NodeJS.Timeout | null = null; private fetchIntervalId: NodeJS.Timeout | null = null;
private activeRateLimitRequests: number = 0;
private closingConnection: boolean = false;
private logger: Logger;
constructor( constructor(
private channelsService: ChannelsService, private channelsService: ChannelsService,
@ -46,6 +51,7 @@ export default class Connection {
private notificationService: NotificationService, private notificationService: NotificationService,
private cacheService: CacheService, private cacheService: CacheService,
private channelFollowingService: ChannelFollowingService, private channelFollowingService: ChannelFollowingService,
private loggerService: LoggerService,
user: MiUser | null | undefined, user: MiUser | null | undefined,
token: MiAccessToken | null | undefined, token: MiAccessToken | null | undefined,
@ -54,6 +60,8 @@ export default class Connection {
if (user) this.user = user; if (user) this.user = user;
if (token) this.token = token; if (token) this.token = token;
if (rateLimiter) this.rateLimiter = rateLimiter; if (rateLimiter) this.rateLimiter = rateLimiter;
this.logger = loggerService.getLogger('streaming', 'coral', false);
} }
@bindThis @bindThis
@ -106,9 +114,23 @@ export default class Connection {
private async onWsConnectionMessage(data: WebSocket.RawData) { private async onWsConnectionMessage(data: WebSocket.RawData) {
let obj: Record<string, any>; let obj: Record<string, any>;
if (this.rateLimiter && await this.rateLimiter()) { if (this.closingConnection) return;
if (this.rateLimiter) {
if (this.activeRateLimitRequests <= 128) {
this.activeRateLimitRequests++;
const shouldRateLimit = await this.rateLimiter();
this.activeRateLimitRequests--;
if (shouldRateLimit) return;
if (this.closingConnection) return;
} else {
this.logger.warn('Closing a connection due to an excessive influx of messages.');
this.closingConnection = true;
this.wsConnection.close(1008, 'Please stop spamming the streaming API.');
return; return;
} }
}
try { try {
obj = JSON.parse(data.toString()); obj = JSON.parse(data.toString());