1
0
voyager-api/ScrapperAPI/AgentGrpc/AgentServiceImpl.cs

200 lines
7.4 KiB
C#

using Grpc.Core;
using Grpc.AspNetCore.Server;
using Microsoft.Extensions.Options;
using ScrapperAPI.AgentGrpc;
using ScrapperAPI.Interfaces;
using ScrapperAPI.Options;
namespace ScrapperAPI.AgentGrpc;
/// <summary>
/// gRPC endpoint used by remote agents to register, send heartbeats, lease queue
/// items and submit the results of processing them.
/// </summary>
/// <remarks>
/// When <c>Agents.RequireMutualTls</c> is enabled, the thumbprint of the client
/// certificate presented at registration is stored with the agent, and later calls
/// must present a certificate with the same thumbprint.
/// </remarks>
public sealed class AgentServiceImpl : AgentService.AgentServiceBase
{
    /// <summary>Upper bound on the number of items a single LeaseWork call may request.</summary>
    private const int MaxLeaseCapacity = 1000;

    /// <summary>Lower bound (in seconds) applied to the configured lease duration.</summary>
    private const int MinLeaseSeconds = 5;

    private readonly IAgentRepository _agents;
    private readonly IQueueRepository _queue;
    private readonly IContentRepository _content;
    private readonly WorkerOptions _opts;

    public AgentServiceImpl(
        IAgentRepository agents,
        IQueueRepository queue,
        IContentRepository content,
        IOptions<WorkerOptions> options)
    {
        _agents = agents;
        _queue = queue;
        _content = content;
        _opts = options.Value;
    }

    /// <summary>
    /// Creates or updates the agent row for <c>agent_id</c>, storing the trimmed display
    /// name (null when blank) and the caller's certificate thumbprint ("" when mTLS is
    /// not required).
    /// </summary>
    /// <exception cref="RpcException">
    /// <c>Unavailable</c> when agents are disabled; <c>InvalidArgument</c> when
    /// <c>agent_id</c> is blank; <c>Unauthenticated</c> when mTLS is required and no
    /// client certificate was presented.
    /// </exception>
    public override async Task<RegisterAgentResponse> RegisterAgent(RegisterAgentRequest request, ServerCallContext context)
    {
        EnsureAgentsEnabled();
        var agentId = RequireAgentId(request.AgentId);
        var displayName = request.DisplayName?.Trim();
        var thumbprint = GetClientCertThumbprint(context);

        // NOTE(review): the upsert stores whatever thumbprint the caller presents, so any
        // client with a trusted certificate can re-register an existing agent_id under its
        // own certificate. Consider rejecting when the existing row has a different
        // non-empty thumbprint — confirm how certificate rotation is meant to work first.
        await _agents.UpsertAsync(
            agentId,
            string.IsNullOrWhiteSpace(displayName) ? null : displayName,
            thumbprint,
            context.CancellationToken);

        return new RegisterAgentResponse { Ok = true };
    }

    /// <summary>Authenticates the agent and records that it is alive.</summary>
    /// <exception cref="RpcException">
    /// <c>Unavailable</c>, <c>InvalidArgument</c>, <c>Unauthenticated</c> or
    /// <c>PermissionDenied</c> when the agent cannot be authenticated.
    /// </exception>
    public override async Task<HeartbeatResponse> Heartbeat(HeartbeatRequest request, ServerCallContext context)
    {
        EnsureAgentsEnabled();
        await AuthenticateAgentAsync(request.AgentId, context);
        return new HeartbeatResponse { Ok = true };
    }

    /// <summary>
    /// Leases up to <c>capacity</c> (clamped to 0..<see cref="MaxLeaseCapacity"/>) queue
    /// items of <c>session_id</c> to the calling agent for the configured lease duration
    /// (at least <see cref="MinLeaseSeconds"/> seconds).
    /// </summary>
    /// <returns>
    /// The leased items (possibly none) and the server's current UTC time in Unix
    /// milliseconds. In <c>LocalOnly</c> mode, or when capacity is 0, no items are returned.
    /// </returns>
    /// <exception cref="RpcException">
    /// <c>Unavailable</c>, <c>InvalidArgument</c>, <c>Unauthenticated</c> or
    /// <c>PermissionDenied</c> when the agent cannot be authenticated.
    /// </exception>
    public override async Task<LeaseWorkResponse> LeaseWork(LeaseWorkRequest request, ServerCallContext context)
    {
        EnsureAgentsEnabled();

        // LocalOnly mode never hands work to agents; this short-circuits before the agent
        // itself is validated or touched.
        if (_opts.Mode == DistributedMode.LocalOnly)
        {
            return EmptyLeaseResponse();
        }

        var agentId = await AuthenticateAgentAsync(request.AgentId, context);

        var capacity = Math.Clamp(request.Capacity, 0, MaxLeaseCapacity);
        if (capacity == 0)
        {
            return EmptyLeaseResponse();
        }

        var workerId = WorkerIdFor(agentId);
        var leaseFor = TimeSpan.FromSeconds(Math.Max(MinLeaseSeconds, _opts.LeaseSeconds));

        // Capture the expiry once, before leasing. Computing it after the repository call
        // (and again per item) reported a slightly later expiry than the lease itself —
        // presumably stamped by the repository at lease time; verify — and gave items of
        // one batch different expiries.
        var leaseExpiresUtcMs = DateTimeOffset.UtcNow.Add(leaseFor).ToUnixTimeMilliseconds();

        var batch = await _queue.LeaseBatchAsync(request.SessionId, workerId, capacity, leaseFor, context.CancellationToken);

        var resp = EmptyLeaseResponse();
        foreach (var item in batch)
        {
            resp.Items.Add(new WorkItem
            {
                QueueId = item.Id,
                SessionId = item.SessionId,
                Url = item.Url,
                LeaseExpiresUtcMs = leaseExpiresUtcMs
            });
        }

        return resp;
    }

    /// <summary>
    /// Stores the outcome of a leased item. On success, the content is saved: the
    /// compressed bytes when present, otherwise the plain text. The item is then marked
    /// done. On failure, the item is marked failed with the supplied error (or
    /// "unknown error").
    /// </summary>
    /// <returns>
    /// <c>Ok = false</c> when the repository reports the lease is not held by this agent;
    /// otherwise <c>Ok = true</c>.
    /// </returns>
    /// <exception cref="RpcException">
    /// Authentication failures as for <c>Heartbeat</c>; <c>InvalidArgument</c> for a
    /// non-positive <c>queue_id</c>; <c>Internal</c> when persisting the result throws.
    /// </exception>
    public override async Task<SubmitResultResponse> SubmitResult(SubmitResultRequest request, ServerCallContext context)
    {
        EnsureAgentsEnabled();
        var agentId = await AuthenticateAgentAsync(request.AgentId, context);

        if (request.QueueId <= 0)
        {
            throw new RpcException(new Status(StatusCode.InvalidArgument, "queue_id must be > 0"));
        }

        var workerId = WorkerIdFor(agentId);
        try
        {
            if (request.Success)
            {
                // NOTE(review): content is written before MarkDoneAsync verifies the lease, so
                // an agent that does not hold the lease can still overwrite stored content.
                // Checking lease ownership first needs a repository method not visible here.
                if (request.ContentBytes is { Length: > 0 })
                {
                    var encoding = string.IsNullOrWhiteSpace(request.ContentEncoding) ? "gzip" : request.ContentEncoding;
                    var originalLength = request.OriginalLength > 0 ? request.OriginalLength : 0;
                    var compressedLength = request.CompressedLength > 0 ? request.CompressedLength : request.ContentBytes.Length;
                    await _content.SaveCompressedAsync(
                        request.QueueId,
                        encoding,
                        request.ContentBytes.ToByteArray(),
                        originalLength,
                        compressedLength,
                        context.CancellationToken);
                }
                else
                {
                    await _content.SaveAsync(request.QueueId, request.ContentText ?? string.Empty, context.CancellationToken);
                }

                var done = await _queue.MarkDoneAsync(request.QueueId, workerId, context.CancellationToken);
                return done
                    ? new SubmitResultResponse { Ok = true, Message = "Stored" }
                    : new SubmitResultResponse { Ok = false, Message = "Lease is not valid for this agent" };
            }

            var error = string.IsNullOrWhiteSpace(request.Error) ? "unknown error" : request.Error;
            var failed = await _queue.MarkFailedAsync(request.QueueId, workerId, error, context.CancellationToken);
            return failed
                ? new SubmitResultResponse { Ok = true, Message = "Marked failed" }
                : new SubmitResultResponse { Ok = false, Message = "Lease is not valid for this agent" };
        }
        catch (Exception ex) when (ex is not RpcException and not OperationCanceledException)
        {
            // Cancellation and deliberate RpcExceptions propagate unchanged instead of being
            // reported as INTERNAL. The original exception is kept as the status's debug
            // exception so it remains available server-side.
            // NOTE(review): ex.Message is still sent to the agent; consider a generic message
            // if repository errors can contain sensitive details.
            throw new RpcException(new Status(StatusCode.Internal, ex.Message, ex));
        }
    }

    /// <summary>Throws <c>Unavailable</c> when the agent feature is switched off in configuration.</summary>
    private void EnsureAgentsEnabled()
    {
        if (!_opts.Agents.Enabled)
        {
            throw new RpcException(new Status(StatusCode.Unavailable, "Agents are disabled"));
        }
    }

    /// <summary>
    /// Trims the agent id, validates it against the agent table and the client certificate,
    /// then records activity for it.
    /// </summary>
    /// <returns>The trimmed agent id.</returns>
    private async Task<string> AuthenticateAgentAsync(string? rawAgentId, ServerCallContext context)
    {
        var agentId = RequireAgentId(rawAgentId);
        await ValidateAgentAsync(agentId, context);
        await _agents.TouchAsync(agentId, context.CancellationToken);
        return agentId;
    }

    /// <summary>Returns the trimmed agent id, or throws <c>InvalidArgument</c> when it is null or blank.</summary>
    private static string RequireAgentId(string? agentId)
    {
        var trimmed = agentId?.Trim();
        if (string.IsNullOrWhiteSpace(trimmed))
        {
            throw new RpcException(new Status(StatusCode.InvalidArgument, "agent_id is required"));
        }

        return trimmed;
    }

    /// <summary>
    /// Worker id under which an agent's leases are recorded. LeaseWork and SubmitResult
    /// must use the same value for lease ownership checks to succeed.
    /// </summary>
    private static string WorkerIdFor(string agentId) => $"agent:{agentId}";

    /// <summary>Creates a response carrying only the current server time, with no items.</summary>
    private static LeaseWorkResponse EmptyLeaseResponse() => new()
    {
        ServerTimeUtcMs = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds()
    };

    /// <summary>
    /// Throws <c>PermissionDenied</c> when the agent is unknown or disabled, or — when mTLS
    /// is required — when the presented certificate's thumbprint differs from the stored
    /// one (case-insensitive).
    /// </summary>
    private async Task ValidateAgentAsync(string agentId, ServerCallContext context)
    {
        var row = await _agents.GetAsync(agentId, context.CancellationToken);
        if (row is null)
        {
            throw new RpcException(new Status(StatusCode.PermissionDenied, "Agent not registered"));
        }

        if (!row.IsEnabled)
        {
            throw new RpcException(new Status(StatusCode.PermissionDenied, "Agent disabled"));
        }

        // Without mTLS every caller's "thumbprint" is "", so the comparison adds no security.
        // It would only lock out agents whose stored thumbprint is null or was recorded while
        // mTLS was enabled.
        if (!_opts.Agents.RequireMutualTls)
        {
            return;
        }

        var thumbprint = GetClientCertThumbprint(context);
        if (!string.Equals(row.CertThumbprint, thumbprint, StringComparison.OrdinalIgnoreCase))
        {
            throw new RpcException(new Status(StatusCode.PermissionDenied, "Client certificate does not match agent"));
        }
    }

    /// <summary>
    /// Returns the caller's client-certificate thumbprint with spaces removed, or "" when
    /// mTLS is not required.
    /// </summary>
    /// <exception cref="RpcException"><c>Unauthenticated</c> when mTLS is required but no certificate is present.</exception>
    private string GetClientCertThumbprint(ServerCallContext context)
    {
        if (!_opts.Agents.RequireMutualTls)
        {
            return "";
        }

        var cert = context.GetHttpContext().Connection.ClientCertificate;
        if (cert is null)
        {
            throw new RpcException(new Status(StatusCode.Unauthenticated, "Client certificate is required"));
        }

        return (cert.Thumbprint ?? string.Empty).Replace(" ", string.Empty);
    }
}