feat(textprocessing): TextProcessingManager::runTask calls TaskProcessingManager::runTask

Signed-off-by: Julien Veyssier <julien-nc@posteo.net>
This commit is contained in:
Julien Veyssier 2024-08-28 11:50:23 +02:00
parent a7406306c7
commit 17e48ee1a7
No known key found for this signature in database
GPG key ID: 4141FEE162030638
2 changed files with 84 additions and 3 deletions

View file

@ -87,7 +87,6 @@ class Manager implements IManager {
private IEventDispatcher $dispatcher,
IAppDataFactory $appDataFactory,
private IRootFolder $rootFolder,
private \OCP\TextProcessing\IManager $textProcessingManager,
private \OCP\TextToImage\IManager $textToImageManager,
private \OCP\SpeechToText\ISpeechToTextManager $speechToTextManager,
private IUserMountCache $userMountCache,
@ -98,8 +97,34 @@ class Manager implements IManager {
}
/**
* This is almost a copy of textProcessingManager->getProviders
* to avoid a dependency cycle between TextProcessingManager and TaskProcessingManager
*/
private function _getRawTextProcessingProviders(): array {
$context = $this->coordinator->getRegistrationContext();
if ($context === null) {
return [];
}
$providers = [];
foreach ($context->getTextProcessingProviders() as $providerServiceRegistration) {
$class = $providerServiceRegistration->getService();
try {
$providers[$class] = $this->serverContainer->get($class);
} catch (\Throwable $e) {
$this->logger->error('Failed to load Text processing provider ' . $class, [
'exception' => $e,
]);
}
}
return $providers;
}
private function _getTextProcessingProviders(): array {
$oldProviders = $this->textProcessingManager->getProviders();
$oldProviders = $this->_getRawTextProcessingProviders();
$newProviders = [];
foreach ($oldProviders as $oldProvider) {
$provider = new class($oldProvider) implements IProvider, ISynchronousProvider {
@ -190,7 +215,7 @@ class Manager implements IManager {
* @return ITaskType[]
*/
private function _getTextProcessingTaskTypes(): array {
$oldProviders = $this->textProcessingManager->getProviders();
$oldProviders = $this->_getRawTextProcessingProviders();
$newTaskTypes = [];
foreach ($oldProviders as $oldProvider) {
// These are already implemented in the TaskProcessing realm

View file

@ -20,13 +20,22 @@ use OCP\DB\Exception;
use OCP\IConfig;
use OCP\IServerContainer;
use OCP\PreConditionNotMetException;
use OCP\TaskProcessing\IManager as TaskProcessingIManager;
use OCP\TaskProcessing\TaskTypes\TextToText;
use OCP\TaskProcessing\TaskTypes\TextToTextHeadline;
use OCP\TaskProcessing\TaskTypes\TextToTextSummary;
use OCP\TaskProcessing\TaskTypes\TextToTextTopics;
use OCP\TextProcessing\Exception\TaskFailureException;
use OCP\TextProcessing\FreePromptTaskType;
use OCP\TextProcessing\HeadlineTaskType;
use OCP\TextProcessing\IManager;
use OCP\TextProcessing\IProvider;
use OCP\TextProcessing\IProviderWithExpectedRuntime;
use OCP\TextProcessing\IProviderWithId;
use OCP\TextProcessing\SummaryTaskType;
use OCP\TextProcessing\Task;
use OCP\TextProcessing\Task as OCPTask;
use OCP\TextProcessing\TopicsTaskType;
use Psr\Log\LoggerInterface;
use RuntimeException;
use Throwable;
@ -42,6 +51,7 @@ class Manager implements IManager {
private IJobList $jobList,
private TaskMapper $taskMapper,
private IConfig $config,
private TaskProcessingIManager $taskProcessingManager,
) {
}
@ -98,6 +108,52 @@ class Manager implements IManager {
* @inheritDoc
*/
public function runTask(OCPTask $task): string {
// try to run a task processing task if possible
$taskTypeClass = $task->getType();
$taskProcessingCompatibleTaskTypes = [
FreePromptTaskType::class => TextToText::ID,
HeadlineTaskType::class => TextToTextHeadline::ID,
SummaryTaskType::class => TextToTextSummary::ID,
TopicsTaskType::class => TextToTextTopics::ID,
];
if (isset($taskProcessingCompatibleTaskTypes[$taskTypeClass])) {
try {
$taskProcessingTaskTypeId = $taskProcessingCompatibleTaskTypes[$taskTypeClass];
$taskProcessingTask = new \OCP\TaskProcessing\Task(
$taskProcessingTaskTypeId,
['input' => $task->getInput()],
$task->getAppId(),
$task->getUserId(),
$task->getIdentifier(),
);
$task->setStatus(OCPTask::STATUS_RUNNING);
if ($task->getId() === null) {
$taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task));
$task->setId($taskEntity->getId());
} else {
$this->taskMapper->update(DbTask::fromPublicTask($task));
}
$this->logger->debug('Running a TextProcessing (' . $taskTypeClass . ') task with TaskProcessing');
$taskProcessingResultTask = $this->taskProcessingManager->runTask($taskProcessingTask);
if ($taskProcessingResultTask->getStatus() === \OCP\TaskProcessing\Task::STATUS_SUCCESSFUL) {
$task->setOutput($taskProcessingResultTask->getOutput()['output'] ?? '');
$task->setStatus(OCPTask::STATUS_SUCCESSFUL);
$this->taskMapper->update(DbTask::fromPublicTask($task));
return $task->getOutput();
}
} catch (\Throwable $e) {
$this->logger->error('TextProcessing to TaskProcessing failed', ['exception' => $e]);
$task->setStatus(OCPTask::STATUS_FAILED);
$this->taskMapper->update(DbTask::fromPublicTask($task));
throw new TaskFailureException('TextProcessing to TaskProcessing failed: ' . $e->getMessage(), 0, $e);
}
$task->setStatus(OCPTask::STATUS_FAILED);
$this->taskMapper->update(DbTask::fromPublicTask($task));
throw new TaskFailureException('Could not run task');
}
// try to run the text processing task
if (!$this->canHandleTask($task)) {
throw new PreConditionNotMetException('No text processing provider is installed that can handle this task');
}