LLM OCP API: Implement private backend code + add ILanguageModelTask

Signed-off-by: Marcel Klehr <mklehr@gmx.net>
(cherry picked from commit 3413873653)
This commit is contained in:
Marcel Klehr 2023-06-16 13:06:47 +02:00
parent 457f1eb407
commit d20ee42580
12 changed files with 452 additions and 46 deletions

View file

@ -0,0 +1,58 @@
<?php
namespace OC\LanguageModel\Db;
use OCP\AppFramework\Db\Entity;
use OCP\LanguageModel\ILanguageModelTask;
/**
* @method setType(string $type)
* @method string getType()
* @method setInput(string $type)
* @method string getInput()
* @method setStatus(int $type)
* @method int getStatus()
* @method setUserId(string $type)
* @method string getuserId()
* @method setAppId(string $type)
* @method string getAppId()
*/
class Task extends Entity {
protected $type;
protected $input;
protected $status;
protected $userId;
protected $appId;
/**
* @var string[]
*/
public static array $columns = ['id', 'type', 'input', 'status', 'user_id', 'app_id'];
/**
* @var string[]
*/
public static array $fields = ['id', 'type', 'input', 'status', 'userId', 'appId'];
public function __construct() {
// add types in constructor
$this->addType('id', 'integer');
$this->addType('type', 'string');
$this->addType('input', 'string');
$this->addType('status', 'integer');
$this->addType('userId', 'string');
$this->addType('appId', 'string');
}
public static function fromLanguageModelTask(ILanguageModelTask $task): Task {
return Task::fromParams([
'type' => $task->getType(),
'status' => ILanguageModelTask::STATUS_UNKNOWN,
'input' => $task->getInput(),
'userId' => $task->getUserId(),
'appId' => $task->getAppId(),
]);
}
}

View file

@ -0,0 +1,34 @@
<?php
namespace OC\LanguageModel\Db;
use OCP\AppFramework\Db\DoesNotExistException;
use OCP\AppFramework\Db\MultipleObjectsReturnedException;
use OCP\AppFramework\Db\QBMapper;
use OCP\DB\Exception;
use OCP\IDBConnection;
/**
* @extends QBMapper<Task>
*/
class TaskMapper extends QBMapper {
public function __construct(IDBConnection $db) {
parent::__construct($db, 'oc_llm_tasks', Task::class);
}
/**
* @param int $id
* @return Task
* @throws Exception
* @throws DoesNotExistException
* @throws MultipleObjectsReturnedException
*/
public function find(int $id): Task {
$qb = $this->db->getQueryBuilder();
$qb->select(Task::$columns)
->from($this->tableName)
->where($qb->expr()->eq('id', $qb->createPositionalParameter($id)));
return $this->findEntity($qb);
}
}

View file

@ -0,0 +1,162 @@
<?php
namespace OC\LanguageModel;
use OC\AppFramework\Bootstrap\Coordinator;
use OC\LanguageModel\Db\Task;
use OC\LanguageModel\Db\TaskMapper;
use OCP\LanguageModel\AbstractLanguageModelTask;
use OCP\LanguageModel\FreePromptTask;
use OCP\LanguageModel\SummaryTask;
use OCP\AppFramework\Db\DoesNotExistException;
use OCP\AppFramework\Db\MultipleObjectsReturnedException;
use OCP\BackgroundJob\IJobList;
use OCP\DB\Exception;
use OCP\IServerContainer;
use OCP\LanguageModel\ILanguageModelManager;
use OCP\LanguageModel\ILanguageModelProvider;
use OCP\LanguageModel\ILanguageModelTask;
use OCP\LanguageModel\ISummaryProvider;
use OCP\PreConditionNotMetException;
use Psr\Container\ContainerExceptionInterface;
use Psr\Container\NotFoundExceptionInterface;
use Psr\Log\LoggerInterface;
use RuntimeException;
use Throwable;
class LanguageModelManager implements ILanguageModelManager {
/** @var ?ILanguageModelProvider[] */
private ?array $providers = null;
public function __construct(
private IServerContainer $serverContainer,
private Coordinator $coordinator,
private LoggerInterface $logger,
private IJobList $jobList,
private TaskMapper $taskMapper,
) {
}
public function getProviders(): array {
$context = $this->coordinator->getRegistrationContext();
if ($context === null) {
return [];
}
if ($this->providers !== null) {
return $this->providers;
}
$this->providers = [];
foreach ($context->getSpeechToTextProviders() as $providerServiceRegistration) {
$class = $providerServiceRegistration->getService();
try {
$this->providers[$class] = $this->serverContainer->get($class);
} catch (NotFoundExceptionInterface|ContainerExceptionInterface|Throwable $e) {
$this->logger->error('Failed to load LanguageModel provider ' . $class, [
'exception' => $e,
]);
}
}
return $this->providers;
}
public function hasProviders(): bool {
$context = $this->coordinator->getRegistrationContext();
if ($context === null) {
return false;
}
return !empty($context->getSpeechToTextProviders());
}
/**
* @inheritDoc
*/
public function getAvailableTasks(): array {
$tasks = [];
foreach ($this->getProviders() as $provider) {
$tasks[FreePromptTask::class] = true;
if ($provider instanceof ISummaryProvider) {
$tasks[SummaryTask::class] = true;
}
}
return array_keys($tasks);
}
public function canHandleTask(ILanguageModelTask $task): bool {
return !empty(array_filter($this->getAvailableTasks(), fn ($class) => $task instanceof $class));
}
/**
* @inheritDoc
*/
public function runTask(ILanguageModelTask $task): string {
if (!$this->canHandleTask($task)) {
throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
}
foreach ($this->getProviders() as $provider) {
if (!$task->canUseProvider($provider)) {
continue;
}
try {
$task->setStatus(ILanguageModelTask::STATUS_RUNNING);
$this->taskMapper->update(Task::fromLanguageModelTask($task));
$output = $task->visitProvider($provider);
$task->setStatus(ILanguageModelTask::STATUS_SUCCESSFUL);
$this->taskMapper->update(Task::fromLanguageModelTask($task));
return $output;
} catch (\RuntimeException $e) {
$this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]);
$task->setStatus(ILanguageModelTask::STATUS_FAILED);
$this->taskMapper->update(Task::fromLanguageModelTask($task));
throw $e;
} catch (\Throwable $e) {
$this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]);
$task->setStatus(ILanguageModelTask::STATUS_FAILED);
$this->taskMapper->update(Task::fromLanguageModelTask($task));
throw new RuntimeException('LanguageModel call using provider ' . $provider->getName() . ' failed: ' . $e->getMessage());
}
}
throw new RuntimeException('Could not transcribe file');
}
/**
* @inheritDoc
* @throws Exception
*/
public function scheduleTask(ILanguageModelTask $task): void {
if (!$this->canHandleTask($task)) {
throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
}
$taskEntity = Task::fromLanguageModelTask($task);
$this->taskMapper->insert($taskEntity);
$task->setId($taskEntity->getId());
$task->setStatus(ILanguageModelTask::STATUS_SCHEDULED);
$this->jobList->add(TaskBackgroundJob::class, [
'taskId' => $task->getId()
]);
}
/**
* @param int $id The id of the task
* @return ILanguageModelTask
* @throws RuntimeException If the query failed
* @throws \ValueError If the task could not be found
*/
public function getTask(int $id): ILanguageModelTask {
try {
$taskEntity = $this->taskMapper->find($id);
return AbstractLanguageModelTask::fromTaskEntity($taskEntity);
} catch (DoesNotExistException $e) {
throw new \ValueError('Could not find task with the provided id');
} catch (MultipleObjectsReturnedException $e) {
throw new RuntimeException('Could not uniquely identify task with given id');
} catch (Exception $e) {
throw new RuntimeException('Failure while trying to find task by id: '.$e->getMessage());
}
}
}

View file

@ -0,0 +1,72 @@
<?php
declare(strict_types=1);
/**
* @copyright Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net>
*
* @author Marcel Klehr <mklehr@gmx.net>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
namespace OC\LanguageModel;
use OC\User\NoUserException;
use OCP\AppFramework\Utility\ITimeFactory;
use OCP\BackgroundJob\QueuedJob;
use OCP\EventDispatcher\IEventDispatcher;
use OCP\Files\File;
use OCP\Files\IRootFolder;
use OCP\Files\NotFoundException;
use OCP\Files\NotPermittedException;
use OCP\LanguageModel\Events\TaskFailedEvent;
use OCP\LanguageModel\Events\TaskSuccessfulEvent;
use OCP\LanguageModel\ILanguageModelManager;
use OCP\PreConditionNotMetException;
use OCP\SpeechToText\Events\TranscriptionFailedEvent;
use OCP\SpeechToText\Events\TranscriptionSuccessfulEvent;
use OCP\SpeechToText\ISpeechToTextManager;
use Psr\Log\LoggerInterface;
class TaskBackgroundJob extends QueuedJob {
public function __construct(
ITimeFactory $timeFactory,
private ILanguageModelManager $languageModelManager,
private IEventDispatcher $eventDispatcher,
) {
parent::__construct($timeFactory);
$this->setAllowParallelRuns(false);
}
/**
* @param array{taskId: int} $argument
* @inheritDoc
*/
protected function run($argument) {
$taskId = $argument['taskId'];
$task = $this->languageModelManager->getTask($taskId);
try {
$output = $this->languageModelManager->runTask($task);
$event = new TaskSuccessfulEvent($task, $output);
} catch (\RuntimeException|PreConditionNotMetException $e) {
$event = new TaskFailedEvent($task, $e->getMessage());
}
$this->eventDispatcher->dispatchTyped($event);
}
}

View file

@ -2,71 +2,91 @@
namespace OCP\LanguageModel;
abstract class AbstractLanguageModelTask {
public const STATUS_UNKNOWN = 0;
public const STATUS_RUNNING = 1;
public const STATUS_SUCCESSFUL = 2;
public const STATUS_FAILED = 4;
use OC\LanguageModel\Db\Task;
abstract class AbstractLanguageModelTask implements ILanguageModelTask {
protected ?int $id;
protected int $status = self::STATUS_UNKNOWN;
protected int $status = ILanguageModelTask::STATUS_UNKNOWN;
public function __construct(
public final function __construct(
protected string $input,
protected string $appId,
protected ?string $userId,
) {
}
/**
* @param ILanguageModelProvider $provider
* @return string
* @throws \RuntimeException
*/
abstract public function visitProvider(ILanguageModelProvider $provider): string;
abstract public function canUseProvider(ILanguageModelProvider $provider): bool;
abstract public function getType(): string;
/**
* @return int
*/
public function getStatus(): int {
public final function getStatus(): int {
return $this->status;
}
/**
* @param int $status
*/
public function setStatus(int $status): void {
public final function setStatus(int $status): void {
$this->status = $status;
}
/**
* @return int|null
*/
public function getId(): ?int {
public final function getId(): ?int {
return $this->id;
}
/**
* @param int|null $id
*/
public function setId(?int $id): void {
public final function setId(?int $id): void {
$this->id = $id;
}
/**
* @return string
*/
public function getInput(): string {
public final function getInput(): string {
return $this->input;
}
/**
* @return string
*/
public function getAppId(): string {
public final function getAppId(): string {
return $this->appId;
}
/**
* @return string|null
*/
public function getUserId(): ?string {
public final function getUserId(): ?string {
return $this->userId;
}
public final static function fromTaskEntity(Task $taskEntity): ILanguageModelTask {
$task = self::factory($taskEntity->getType(), $taskEntity->getInput(), $taskEntity->getuserId(), $taskEntity->getAppId());
$task->setId($taskEntity->getId());
$task->setStatus($taskEntity->getStatus());
return $task;
}
public final static function factory(string $type, string $input, ?string $userId, string $appId): ILanguageModelTask {
if (!in_array($type, self::TYPES)) {
throw new \InvalidArgumentException('Unknown task type');
}
return new ILanguageModelTask::TYPES[$type]($input, $userId, $appId);
}
}

View file

@ -26,7 +26,7 @@ declare(strict_types=1);
namespace OCP\LanguageModel\Events;
use OCP\EventDispatcher\Event;
use OCP\LanguageModel\AbstractLanguageModelTask;
use OCP\LanguageModel\ILanguageModelTask;
/**
* @since 28.0.0
@ -36,16 +36,16 @@ abstract class AbstractLanguageModelEvent extends Event {
* @since 28.0.0
*/
public function __construct(
private AbstractLanguageModelTask $task
private ILanguageModelTask $task
) {
parent::__construct();
}
/**
* @return AbstractLanguageModelTask
* @return ILanguageModelTask
* @since 28.0.0
*/
public function getTask(): AbstractLanguageModelTask {
public function getTask(): ILanguageModelTask {
return $this->task;
}
}

View file

@ -2,14 +2,14 @@
namespace OCP\LanguageModel\Events;
use OCP\LanguageModel\AbstractLanguageModelTask;
use OCP\LanguageModel\ILanguageModelTask;
/**
* @since 28.0.0
*/
class TaskFailedEvent extends AbstractLanguageModelEvent {
public function __construct(AbstractLanguageModelTask $task,
public function __construct(ILanguageModelTask $task,
private string $errorMessage) {
parent::__construct($task);
}

View file

@ -2,14 +2,14 @@
namespace OCP\LanguageModel\Events;
use OCP\LanguageModel\AbstractLanguageModelTask;
use OCP\LanguageModel\ILanguageModelTask;
/**
* @since 28.0.0
*/
class TaskSuccessfulEvent extends AbstractLanguageModelEvent {
public function __construct(AbstractLanguageModelTask $task,
public function __construct(ILanguageModelTask $task,
private string $output) {
parent::__construct($task);
}

View file

@ -4,7 +4,8 @@ namespace OCP\LanguageModel;
use RuntimeException;
class FreePromptTask extends AbstractLanguageModelTask {
final class FreePromptTask extends AbstractLanguageModelTask {
public const TYPE = 'free_prompt';
/**
* @param ILanguageModelProvider $provider
@ -12,14 +13,14 @@ class FreePromptTask extends AbstractLanguageModelTask {
* @return string
*/
public function visitProvider(ILanguageModelProvider $provider): string {
$this->setStatus(self::STATUS_RUNNING);
try {
$output = $provider->prompt($this->getInput());
} catch (RuntimeException $e) {
$this->setStatus(self::STATUS_FAILED);
throw $e;
}
$this->setStatus(self::STATUS_SUCCESSFUL);
return $output;
return $provider->prompt($this->getInput());
}
public function canUseProvider(ILanguageModelProvider $provider): bool {
return true;
}
public function getType(): string {
return self::TYPE;
}
}

View file

@ -27,6 +27,7 @@ declare(strict_types=1);
namespace OCP\LanguageModel;
use InvalidArgumentException;
use OCP\LanguageModel\AbstractLanguageModelTask;
use OCP\LanguageModel\Events\AbstractLanguageModelEvent;
use OCP\PreConditionNotMetException;
use RuntimeException;
@ -45,11 +46,10 @@ interface ILanguageModelManager {
/**
* @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called
* @throws InvalidArgumentException If the file could not be found or is not of a supported type
* @throws RuntimeException If the transcription failed for other reasons
* @throws RuntimeException If something else failed
* @since 28.0.0
*/
public function runTask(AbstractLanguageModelTask $task): AbstractLanguageModelEvent;
public function runTask(ILanguageModelTask $task): string;
/**
* Will schedule an LLM inference process in the background. The result will become available
@ -58,5 +58,7 @@ interface ILanguageModelManager {
* @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called
* @since 28.0.0
*/
public function scheduleTask(AbstractLanguageModelTask $task) : void;
public function scheduleTask(ILanguageModelTask $task) : void;
public function getTask(int $id): ILanguageModelTask;
}

View file

@ -0,0 +1,56 @@
<?php
namespace OCP\LanguageModel;
interface ILanguageModelTask {
public const STATUS_FAILED = 4;
public const STATUS_SUCCESSFUL = 3;
public const STATUS_RUNNING = 2;
public const STATUS_SCHEDULED = 1;
public const STATUS_UNKNOWN = 0;
public const TYPES = [
SummaryTask::TYPE => SummaryTask::class,
FreePromptTask::TYPE => FreePromptTask::class,
];
/**
* @return string
*/
public function getType(): string;
/**
* @return int
*/
public function getStatus(): int;
/**
* @param int $status
*/
public function setStatus(int $status): void;
/**
* @param int|null $id
*/
public function setId(?int $id): void;
/**
* @return int|null
*/
public function getId(): ?int;
/**
* @return string
*/
public function getInput(): string;
/**
* @return string
*/
public function getAppId(): string;
/**
* @return string|null
*/
public function getUserId(): ?string;
}

View file

@ -4,7 +4,8 @@ namespace OCP\LanguageModel;
use RuntimeException;
class SummaryTask extends AbstractLanguageModelTask {
final class SummaryTask extends AbstractLanguageModelTask {
public const TYPE = 'summarize';
/**
* @param ILanguageModelProvider&ISummaryProvider $provider
@ -15,14 +16,14 @@ class SummaryTask extends AbstractLanguageModelTask {
if (!$provider instanceof ISummaryProvider) {
throw new \RuntimeException('SummaryTask#visitProvider expects ISummaryProvider');
}
$this->setStatus(self::STATUS_RUNNING);
try {
$output = $provider->summarize($this->getInput());
} catch (RuntimeException $e) {
$this->setStatus(self::STATUS_FAILED);
throw $e;
}
$this->setStatus(self::STATUS_SUCCESSFUL);
return $output;
return $provider->summarize($this->getInput());
}
public function canUseProvider(ILanguageModelProvider $provider): bool {
return $provider instanceof ISummaryProvider;
}
public function getType(): string {
return self::TYPE;
}
}