multimodalart (HF Staff) committed
Commit 1125d73 · verified · 1 Parent(s): 2ed6d27

Upload 99 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. src/app/api/auth/hf/callback/route.ts +112 -0
  2. src/app/api/auth/hf/login/route.ts +36 -0
  3. src/app/api/auth/hf/validate/route.ts +22 -0
  4. src/app/api/auth/route.ts +6 -0
  5. src/app/api/caption/get/route.ts +46 -0
  6. src/app/api/datasets/create/route.tsx +25 -0
  7. src/app/api/datasets/delete/route.tsx +24 -0
  8. src/app/api/datasets/list/route.ts +25 -0
  9. src/app/api/datasets/listImages/route.ts +61 -0
  10. src/app/api/datasets/upload/route.ts +57 -0
  11. src/app/api/files/[...filePath]/route.ts +116 -0
  12. src/app/api/gpu/route.ts +121 -0
  13. src/app/api/hf-hub/route.ts +165 -0
  14. src/app/api/hf-jobs/route.ts +761 -0
  15. src/app/api/img/[...imagePath]/route.ts +78 -0
  16. src/app/api/img/caption/route.ts +29 -0
  17. src/app/api/img/delete/route.ts +34 -0
  18. src/app/api/img/upload/route.ts +58 -0
  19. src/app/api/jobs/[jobID]/delete/route.ts +32 -0
  20. src/app/api/jobs/[jobID]/files/route.ts +48 -0
  21. src/app/api/jobs/[jobID]/log/route.ts +35 -0
  22. src/app/api/jobs/[jobID]/samples/route.ts +40 -0
  23. src/app/api/jobs/[jobID]/start/route.ts +215 -0
  24. src/app/api/jobs/[jobID]/stop/route.ts +23 -0
  25. src/app/api/jobs/route.ts +67 -0
  26. src/app/api/settings/route.ts +59 -0
  27. src/app/api/zip/route.ts +78 -0
  28. src/app/apple-icon.png +0 -0
  29. src/app/dashboard/page.tsx +85 -0
  30. src/app/datasets/[datasetName]/page.tsx +190 -0
  31. src/app/datasets/page.tsx +217 -0
  32. src/app/favicon.ico +0 -0
  33. src/app/globals.css +72 -0
  34. src/app/icon.png +0 -0
  35. src/app/icon.svg +0 -0
  36. src/app/jobs/[jobID]/page.tsx +147 -0
  37. src/app/jobs/new/AdvancedJob.tsx +146 -0
  38. src/app/jobs/new/SimpleJob.tsx +973 -0
  39. src/app/jobs/new/jobConfig.ts +167 -0
  40. src/app/jobs/new/options.ts +441 -0
  41. src/app/jobs/new/page.tsx +306 -0
  42. src/app/jobs/page.tsx +49 -0
  43. src/app/layout.tsx +50 -0
  44. src/app/manifest.json +21 -0
  45. src/app/page.tsx +5 -0
  46. src/app/settings/page.tsx +264 -0
  47. src/components/AddImagesModal.tsx +152 -0
  48. src/components/AddSingleImageModal.tsx +141 -0
  49. src/components/AuthWrapper.tsx +166 -0
  50. src/components/Card.tsx +15 -0
src/app/api/auth/hf/callback/route.ts ADDED
@@ -0,0 +1,112 @@
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { cookies } from 'next/headers';
3
+
4
+ const TOKEN_ENDPOINT = 'https://huggingface.co/oauth/token';
5
+ const USERINFO_ENDPOINT = 'https://huggingface.co/oauth/userinfo';
6
+ const STATE_COOKIE = 'hf_oauth_state';
7
+
8
+ function htmlResponse(script: string) {
9
+ return new NextResponse(
10
+ `<!DOCTYPE html><html><body><script>${script}</script></body></html>`,
11
+ {
12
+ headers: { 'Content-Type': 'text/html; charset=utf-8' },
13
+ },
14
+ );
15
+ }
16
+
17
+ export async function GET(request: NextRequest) {
18
+ const clientId = process.env.HF_OAUTH_CLIENT_ID || process.env.NEXT_PUBLIC_HF_OAUTH_CLIENT_ID;
19
+ const clientSecret = process.env.HF_OAUTH_CLIENT_SECRET;
20
+
21
+ if (!clientId || !clientSecret) {
22
+ return NextResponse.json({ error: 'OAuth application is not configured' }, { status: 500 });
23
+ }
24
+
25
+ const { searchParams } = new URL(request.url);
26
+ const code = searchParams.get('code');
27
+ const incomingState = searchParams.get('state');
28
+
29
+ const cookieStore = cookies();
30
+ const storedState = cookieStore.get(STATE_COOKIE)?.value;
31
+
32
+ cookieStore.delete(STATE_COOKIE);
33
+
34
+ const origin = request.nextUrl.origin;
35
+
36
+ if (!code || !incomingState || !storedState || incomingState !== storedState) {
37
+ const script = `
38
+ window.opener && window.opener.postMessage({
39
+ type: 'HF_OAUTH_ERROR',
40
+ payload: { message: 'Invalid or expired OAuth state.' }
41
+ }, '${origin}');
42
+ window.close();
43
+ `;
44
+ return htmlResponse(script.trim());
45
+ }
46
+
47
+ const redirectUri = process.env.HF_OAUTH_REDIRECT_URI || process.env.NEXT_PUBLIC_HF_OAUTH_REDIRECT_URI || `${origin}/api/auth/hf/callback`;
48
+
49
+ try {
50
+ const tokenResponse = await fetch(TOKEN_ENDPOINT, {
51
+ method: 'POST',
52
+ headers: {
53
+ 'Content-Type': 'application/x-www-form-urlencoded',
54
+ },
55
+ body: new URLSearchParams({
56
+ grant_type: 'authorization_code',
57
+ code,
58
+ redirect_uri: redirectUri,
59
+ client_id: clientId,
60
+ client_secret: clientSecret,
61
+ }),
62
+ });
63
+
64
+ if (!tokenResponse.ok) {
65
+ const errorPayload = await tokenResponse.json().catch(() => ({}));
66
+ throw new Error(errorPayload?.error_description || 'Failed to exchange code for token');
67
+ }
68
+
69
+ const tokenData = await tokenResponse.json();
70
+ const accessToken = tokenData?.access_token;
71
+ if (!accessToken) {
72
+ throw new Error('Access token missing in response');
73
+ }
74
+
75
+ const userResponse = await fetch(USERINFO_ENDPOINT, {
76
+ headers: {
77
+ Authorization: `Bearer ${accessToken}`,
78
+ },
79
+ });
80
+
81
+ if (!userResponse.ok) {
82
+ throw new Error('Failed to fetch user info');
83
+ }
84
+
85
+ const profile = await userResponse.json();
86
+ const namespace = profile?.preferred_username || profile?.name || 'user';
87
+
88
+ const script = `
89
+ window.opener && window.opener.postMessage({
90
+ type: 'HF_OAUTH_SUCCESS',
91
+ payload: {
92
+ token: ${JSON.stringify(accessToken)},
93
+ namespace: ${JSON.stringify(namespace)},
94
+ }
95
+ }, '${origin}');
96
+ window.close();
97
+ `;
98
+
99
+ return htmlResponse(script.trim());
100
+ } catch (error: any) {
101
+ const message = error?.message || 'OAuth flow failed';
102
+ const script = `
103
+ window.opener && window.opener.postMessage({
104
+ type: 'HF_OAUTH_ERROR',
105
+ payload: { message: ${JSON.stringify(message)} }
106
+ }, '${origin}');
107
+ window.close();
108
+ `;
109
+
110
+ return htmlResponse(script.trim());
111
+ }
112
+ }
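The callback never redirects the main window; it returns a small HTML page that posts an HF_OAUTH_SUCCESS or HF_OAUTH_ERROR message to window.opener and then closes the popup. Below is a minimal sketch of the listener a client page might register; the function name and onToken callback are illustrative assumptions (the actual consumer is presumably src/components/AuthWrapper.tsx, whose contents are not shown in this truncated view).

// Hypothetical listener for the message posted by the OAuth callback page.
function listenForHfOAuth(onToken: (token: string, namespace: string) => void) {
  window.addEventListener('message', (event: MessageEvent) => {
    // The callback posts with targetOrigin set to this app's origin, so only accept same-origin messages.
    if (event.origin !== window.location.origin) return;
    if (event.data?.type === 'HF_OAUTH_SUCCESS') {
      onToken(event.data.payload.token, event.data.payload.namespace);
    } else if (event.data?.type === 'HF_OAUTH_ERROR') {
      console.error('HF OAuth failed:', event.data.payload.message);
    }
  });
}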
src/app/api/auth/hf/login/route.ts ADDED
@@ -0,0 +1,36 @@
1
+ import { randomUUID } from 'crypto';
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+
4
+ const HF_AUTHORIZE_URL = 'https://huggingface.co/oauth/authorize';
5
+ const STATE_COOKIE = 'hf_oauth_state';
6
+
7
+ export async function GET(request: NextRequest) {
8
+ const clientId = process.env.HF_OAUTH_CLIENT_ID || process.env.NEXT_PUBLIC_HF_OAUTH_CLIENT_ID;
9
+ if (!clientId) {
10
+ return NextResponse.json({ error: 'OAuth client ID not configured' }, { status: 500 });
11
+ }
12
+
13
+ const state = randomUUID();
14
+ const origin = request.nextUrl.origin;
15
+ const redirectUri = process.env.HF_OAUTH_REDIRECT_URI || process.env.NEXT_PUBLIC_HF_OAUTH_REDIRECT_URI || `${origin}/api/auth/hf/callback`;
16
+
17
+ const authorizeUrl = new URL(HF_AUTHORIZE_URL);
18
+ authorizeUrl.searchParams.set('response_type', 'code');
19
+ authorizeUrl.searchParams.set('client_id', clientId);
20
+ authorizeUrl.searchParams.set('redirect_uri', redirectUri);
21
+ authorizeUrl.searchParams.set('scope', 'openid profile read-repos');
22
+ authorizeUrl.searchParams.set('state', state);
23
+
24
+ const response = NextResponse.redirect(authorizeUrl.toString(), { status: 302 });
25
+ response.cookies.set({
26
+ name: STATE_COOKIE,
27
+ value: state,
28
+ httpOnly: true,
29
+ sameSite: 'lax',
30
+ secure: process.env.NODE_ENV === 'production',
31
+ maxAge: 60 * 5,
32
+ path: '/',
33
+ });
34
+
35
+ return response;
36
+ }
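This route responds with a 302 to the Hugging Face authorize URL and pins the random state in a five-minute httpOnly cookie that the callback above checks. A client would typically open it in a popup rather than navigating the main window; the helper below is an assumed sketch, not code from this commit.

// Hypothetical popup opener for the login route added above.
function openHfLoginPopup(): Window | null {
  const width = 600;
  const height = 700;
  const left = window.screenX + (window.outerWidth - width) / 2;
  const top = window.screenY + (window.outerHeight - height) / 2;
  return window.open('/api/auth/hf/login', 'hf-oauth', `width=${width},height=${height},left=${left},top=${top}`);
}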
src/app/api/auth/hf/validate/route.ts ADDED
@@ -0,0 +1,22 @@
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { whoAmI } from '@huggingface/hub';
3
+
4
+ export async function POST(request: NextRequest) {
5
+ try {
6
+ const body = await request.json().catch(() => ({}));
7
+ const token = (body?.token || '').trim();
8
+
9
+ if (!token) {
10
+ return NextResponse.json({ error: 'Token is required' }, { status: 400 });
11
+ }
12
+
13
+ const info = await whoAmI({ accessToken: token });
14
+ return NextResponse.json({
15
+ name: info?.name || info?.username || 'user',
16
+ email: info?.email || null,
17
+ orgs: info?.orgs || [],
18
+ });
19
+ } catch (error: any) {
20
+ return NextResponse.json({ error: error?.message || 'Invalid token' }, { status: 401 });
21
+ }
22
+ }
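The validate route simply proxies whoAmI from @huggingface/hub, returning 200 with basic profile fields for a good token and 401 otherwise. A small client helper under that assumption (the helper name is made up):

// Hypothetical helper: resolves to the account name for a valid token, or null.
async function validateHfToken(token: string): Promise<string | null> {
  const res = await fetch('/api/auth/hf/validate', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ token }),
  });
  if (!res.ok) return null;
  const data = await res.json();
  return data.name ?? null;
}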
src/app/api/auth/route.ts ADDED
@@ -0,0 +1,6 @@
1
+ import { NextResponse } from 'next/server';
2
+
3
+ export async function GET() {
4
+ // if this gets hit, auth has already been verified
5
+ return NextResponse.json({ isAuthenticated: true });
6
+ }
src/app/api/caption/get/route.ts ADDED
@@ -0,0 +1,46 @@
1
+ /* eslint-disable */
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+ import fs from 'fs';
4
+ import path from 'path';
5
+ import { getDatasetsRoot } from '@/server/settings';
6
+
7
+ export async function POST(request: NextRequest) {
8
+
9
+ const body = await request.json();
10
+ const { imgPath } = body;
11
+ console.log('Received POST request for caption:', imgPath);
12
+ try {
13
+ // Use the path from the request body as-is
14
+ const filepath = imgPath;
15
+ console.log('Image path:', filepath);
16
+
17
+ // caption name is the filepath without extension but with .txt
18
+ const captionPath = filepath.replace(/\.[^/.]+$/, '') + '.txt';
19
+
20
+ // Get allowed directories
21
+ const allowedDir = await getDatasetsRoot();
22
+
23
+ // Security check: Ensure path is in allowed directory
24
+ const isAllowed = filepath.startsWith(allowedDir) && !filepath.includes('..');
25
+
26
+ if (!isAllowed) {
27
+ console.warn(`Access denied: ${filepath} not in ${allowedDir}`);
28
+ return new NextResponse('Access denied', { status: 403 });
29
+ }
30
+
31
+ // Check if file exists
32
+ if (!fs.existsSync(captionPath)) {
33
+ // send back blank string if caption file does not exist
34
+ return new NextResponse('');
35
+ }
36
+
37
+ // Read caption file
38
+ const caption = fs.readFileSync(captionPath, 'utf-8');
39
+
40
+ // Return caption
41
+ return new NextResponse(caption);
42
+ } catch (error) {
43
+ console.error('Error getting caption:', error);
44
+ return new NextResponse('Error getting caption', { status: 500 });
45
+ }
46
+ }
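Unlike most routes in this commit, the caption endpoint returns plain text rather than JSON, and an empty body means no sidecar .txt file exists yet. A usage sketch under that assumption (helper name is illustrative):

// Hypothetical caption fetcher; an empty string means the caption file does not exist.
async function fetchCaption(imgPath: string): Promise<string> {
  const res = await fetch('/api/caption/get', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ imgPath }),
  });
  if (!res.ok) throw new Error(`Caption request failed: ${res.status}`);
  return res.text();
}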
src/app/api/datasets/create/route.tsx ADDED
@@ -0,0 +1,25 @@
1
+ import { NextResponse } from 'next/server';
2
+ import fs from 'fs';
3
+ import path from 'path';
4
+ import { getDatasetsRoot } from '@/server/settings';
5
+
6
+ export async function POST(request: Request) {
7
+ try {
8
+ const body = await request.json();
9
+ let { name } = body;
10
+ // clean name by making lower case, removing special characters, and replacing spaces with underscores
11
+ name = name.toLowerCase().replace(/[^a-z0-9]+/g, '_');
12
+
13
+ let datasetsPath = await getDatasetsRoot();
14
+ let datasetPath = path.join(datasetsPath, name);
15
+
16
+ // if the folder doesn't exist, create it
17
+ if (!fs.existsSync(datasetPath)) {
18
+ fs.mkdirSync(datasetPath);
19
+ }
20
+
21
+ return NextResponse.json({ success: true, name: name, path: datasetPath });
22
+ } catch (error) {
23
+ return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
24
+ }
25
+ }
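Because the server slugifies the requested name (lower-case, non-alphanumeric runs replaced with underscores), callers should keep the name returned in the response rather than the one they sent. A hedged usage sketch:

// Hypothetical helper; the returned name is the slugified form actually used on disk.
async function createDataset(name: string): Promise<{ success: boolean; name: string; path: string }> {
  const res = await fetch('/api/datasets/create', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ name }),
  });
  if (!res.ok) throw new Error('Failed to create dataset');
  return res.json();
}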
src/app/api/datasets/delete/route.tsx ADDED
@@ -0,0 +1,24 @@
1
+ import { NextResponse } from 'next/server';
2
+ import fs from 'fs';
3
+ import path from 'path';
4
+ import { getDatasetsRoot } from '@/server/settings';
5
+
6
+ export async function POST(request: Request) {
7
+ try {
8
+ const body = await request.json();
9
+ const { name } = body;
10
+ let datasetsPath = await getDatasetsRoot();
11
+ let datasetPath = path.join(datasetsPath, name);
12
+
13
+ // if the folder doesn't exist, ignore
14
+ if (!fs.existsSync(datasetPath)) {
15
+ return NextResponse.json({ success: true });
16
+ }
17
+
18
+ // delete it and return success
19
+ fs.rmSync(datasetPath, { recursive: true, force: true });
20
+ return NextResponse.json({ success: true });
21
+ } catch (error) {
22
+ return NextResponse.json({ error: 'Failed to delete dataset' }, { status: 500 });
23
+ }
24
+ }
src/app/api/datasets/list/route.ts ADDED
@@ -0,0 +1,25 @@
1
+ import { NextResponse } from 'next/server';
2
+ import fs from 'fs';
3
+ import { getDatasetsRoot } from '@/server/settings';
4
+
5
+ export async function GET() {
6
+ try {
7
+ let datasetsPath = await getDatasetsRoot();
8
+
9
+ // if the folder doesn't exist, create it
10
+ if (!fs.existsSync(datasetsPath)) {
11
+ fs.mkdirSync(datasetsPath);
12
+ }
13
+
14
+ // find all the folders in the datasets folder
15
+ let folders = fs
16
+ .readdirSync(datasetsPath, { withFileTypes: true })
17
+ .filter(dirent => dirent.isDirectory())
18
+ .filter(dirent => !dirent.name.startsWith('.'))
19
+ .map(dirent => dirent.name);
20
+
21
+ return NextResponse.json(folders);
22
+ } catch (error) {
23
+ return NextResponse.json({ error: 'Failed to fetch datasets' }, { status: 500 });
24
+ }
25
+ }
src/app/api/datasets/listImages/route.ts ADDED
@@ -0,0 +1,61 @@
1
+ import { NextResponse } from 'next/server';
2
+ import fs from 'fs';
3
+ import path from 'path';
4
+ import { getDatasetsRoot } from '@/server/settings';
5
+
6
+ export async function POST(request: Request) {
7
+ const datasetsPath = await getDatasetsRoot();
8
+ const body = await request.json();
9
+ const { datasetName } = body;
10
+ const datasetFolder = path.join(datasetsPath, datasetName);
11
+
12
+ try {
13
+ // Check if folder exists
14
+ if (!fs.existsSync(datasetFolder)) {
15
+ return NextResponse.json({ error: `Folder '${datasetName}' not found` }, { status: 404 });
16
+ }
17
+
18
+ // Find all images recursively
19
+ const imageFiles = findImagesRecursively(datasetFolder);
20
+
21
+ // Format response
22
+ const result = imageFiles.map(imgPath => ({
23
+ img_path: imgPath,
24
+ }));
25
+
26
+ return NextResponse.json({ images: result });
27
+ } catch (error) {
28
+ console.error('Error finding images:', error);
29
+ return NextResponse.json({ error: 'Failed to process request' }, { status: 500 });
30
+ }
31
+ }
32
+
33
+ /**
34
+ * Recursively finds all image files in a directory and its subdirectories
35
+ * @param dir Directory to search
36
+ * @returns Array of absolute paths to image files
37
+ */
38
+ function findImagesRecursively(dir: string): string[] {
39
+ const imageExtensions = ['.png', '.jpg', '.jpeg', '.webp', '.mp4', '.avi', '.mov', '.mkv', '.wmv', '.m4v', '.flv'];
40
+ let results: string[] = [];
41
+
42
+ const items = fs.readdirSync(dir);
43
+
44
+ for (const item of items) {
45
+ const itemPath = path.join(dir, item);
46
+ const stat = fs.statSync(itemPath);
47
+
48
+ if (stat.isDirectory() && item !== '_controls' && !item.startsWith('.')) {
49
+ // If it's a directory, recursively search it
50
+ results = results.concat(findImagesRecursively(itemPath));
51
+ } else {
52
+ // If it's a file, check if it's an image
53
+ const ext = path.extname(itemPath).toLowerCase();
54
+ if (imageExtensions.includes(ext)) {
55
+ results.push(itemPath);
56
+ }
57
+ }
58
+ }
59
+
60
+ return results;
61
+ }
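The response wraps each match as { img_path }, and despite the route name the extension list also covers common video formats. A client-side sketch with an assumed helper name:

// Hypothetical helper returning absolute paths for every image/video in a dataset.
async function listDatasetMedia(datasetName: string): Promise<string[]> {
  const res = await fetch('/api/datasets/listImages', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ datasetName }),
  });
  if (!res.ok) throw new Error('Failed to list dataset media');
  const data = await res.json();
  return data.images.map((img: { img_path: string }) => img.img_path);
}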
src/app/api/datasets/upload/route.ts ADDED
@@ -0,0 +1,57 @@
1
+ // src/app/api/datasets/upload/route.ts
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+ import { writeFile, mkdir } from 'fs/promises';
4
+ import { join } from 'path';
5
+ import { getDatasetsRoot } from '@/server/settings';
6
+
7
+ export async function POST(request: NextRequest) {
8
+ try {
9
+ const datasetsPath = await getDatasetsRoot();
10
+ if (!datasetsPath) {
11
+ return NextResponse.json({ error: 'Datasets path not found' }, { status: 500 });
12
+ }
13
+ const formData = await request.formData();
14
+ const files = formData.getAll('files');
15
+ const datasetName = formData.get('datasetName') as string;
16
+
17
+ if (!files || files.length === 0) {
18
+ return NextResponse.json({ error: 'No files provided' }, { status: 400 });
19
+ }
20
+
21
+ // Create upload directory if it doesn't exist
22
+ const uploadDir = join(datasetsPath, datasetName);
23
+ await mkdir(uploadDir, { recursive: true });
24
+
25
+ const savedFiles: string[] = [];
26
+
27
+ // Process files sequentially to avoid overwhelming the system
28
+ for (let i = 0; i < files.length; i++) {
29
+ const file = files[i] as any;
30
+ const bytes = await file.arrayBuffer();
31
+ const buffer = Buffer.from(bytes);
32
+
33
+ // Clean filename and ensure it's unique
34
+ const fileName = file.name.replace(/[^a-zA-Z0-9.-]/g, '_');
35
+ const filePath = join(uploadDir, fileName);
36
+
37
+ await writeFile(filePath, buffer);
38
+ savedFiles.push(fileName);
39
+ }
40
+
41
+ return NextResponse.json({
42
+ message: 'Files uploaded successfully',
43
+ files: savedFiles,
44
+ });
45
+ } catch (error) {
46
+ console.error('Upload error:', error);
47
+ return NextResponse.json({ error: 'Error uploading files' }, { status: 500 });
48
+ }
49
+ }
50
+
51
+ // Increase payload size limit (default is 4mb)
52
+ export const config = {
53
+ api: {
54
+ bodyParser: false,
55
+ responseLimit: '50mb',
56
+ },
57
+ };
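The route reads the multipart fields 'files' and 'datasetName' and sanitizes each filename before writing it to the dataset folder. A minimal browser-side uploader matching those field names (the helper itself is an assumption, not part of the commit):

// Hypothetical uploader; field names must match what the route reads from FormData.
async function uploadToDataset(datasetName: string, files: File[]): Promise<string[]> {
  const formData = new FormData();
  formData.append('datasetName', datasetName);
  for (const file of files) {
    formData.append('files', file);
  }
  const res = await fetch('/api/datasets/upload', { method: 'POST', body: formData });
  if (!res.ok) throw new Error('Upload failed');
  const data = await res.json();
  return data.files;
}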
src/app/api/files/[...filePath]/route.ts ADDED
@@ -0,0 +1,116 @@
1
+ /* eslint-disable */
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+ import fs from 'fs';
4
+ import path from 'path';
5
+ import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
6
+
7
+ export async function GET(request: NextRequest, { params }: { params: { filePath: string } }) {
8
+ const { filePath } = await params;
9
+ try {
10
+ // Decode the path
11
+ const decodedFilePath = decodeURIComponent(filePath);
12
+
13
+ // Get allowed directories
14
+ const datasetRoot = await getDatasetsRoot();
15
+ const trainingRoot = await getTrainingFolder();
16
+ const allowedDirs = [datasetRoot, trainingRoot];
17
+
18
+ // Security check: Ensure path is in allowed directory
19
+ const isAllowed =
20
+ allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
21
+
22
+ if (!isAllowed) {
23
+ console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`);
24
+ return new NextResponse('Access denied', { status: 403 });
25
+ }
26
+
27
+ // Check if file exists
28
+ if (!fs.existsSync(decodedFilePath)) {
29
+ console.warn(`File not found: ${decodedFilePath}`);
30
+ return new NextResponse('File not found', { status: 404 });
31
+ }
32
+
33
+ // Get file info
34
+ const stat = fs.statSync(decodedFilePath);
35
+ if (!stat.isFile()) {
36
+ return new NextResponse('Not a file', { status: 400 });
37
+ }
38
+
39
+ // Get filename for Content-Disposition
40
+ const filename = path.basename(decodedFilePath);
41
+
42
+ // Determine content type
43
+ const ext = path.extname(decodedFilePath).toLowerCase();
44
+ const contentTypeMap: { [key: string]: string } = {
45
+ '.jpg': 'image/jpeg',
46
+ '.jpeg': 'image/jpeg',
47
+ '.png': 'image/png',
48
+ '.gif': 'image/gif',
49
+ '.webp': 'image/webp',
50
+ '.svg': 'image/svg+xml',
51
+ '.bmp': 'image/bmp',
52
+ '.safetensors': 'application/octet-stream',
53
+ '.zip': 'application/zip',
54
+ // Videos
55
+ '.mp4': 'video/mp4',
56
+ '.avi': 'video/x-msvideo',
57
+ '.mov': 'video/quicktime',
58
+ '.mkv': 'video/x-matroska',
59
+ '.wmv': 'video/x-ms-wmv',
60
+ '.m4v': 'video/x-m4v',
61
+ '.flv': 'video/x-flv'
62
+ };
63
+
64
+ const contentType = contentTypeMap[ext] || 'application/octet-stream';
65
+
66
+ // Get range header for partial content support
67
+ const range = request.headers.get('range');
68
+
69
+ // Common headers for better download handling
70
+ const commonHeaders = {
71
+ 'Content-Type': contentType,
72
+ 'Accept-Ranges': 'bytes',
73
+ 'Cache-Control': 'public, max-age=86400',
74
+ 'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
75
+ 'X-Content-Type-Options': 'nosniff',
76
+ };
77
+
78
+ if (range) {
79
+ // Parse range header
80
+ const parts = range.replace(/bytes=/, '').split('-');
81
+ const start = parseInt(parts[0], 10);
82
+ const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks
83
+ const chunkSize = end - start + 1;
84
+
85
+ const fileStream = fs.createReadStream(decodedFilePath, {
86
+ start,
87
+ end,
88
+ highWaterMark: 64 * 1024, // 64KB buffer
89
+ });
90
+
91
+ return new NextResponse(fileStream as any, {
92
+ status: 206,
93
+ headers: {
94
+ ...commonHeaders,
95
+ 'Content-Range': `bytes ${start}-${end}/${stat.size}`,
96
+ 'Content-Length': String(chunkSize),
97
+ },
98
+ });
99
+ } else {
100
+ // For a full file download, stream the entire file
101
+ const fileStream = fs.createReadStream(decodedFilePath, {
102
+ highWaterMark: 64 * 1024, // 64KB buffer
103
+ });
104
+
105
+ return new NextResponse(fileStream as any, {
106
+ headers: {
107
+ ...commonHeaders,
108
+ 'Content-Length': String(stat.size),
109
+ },
110
+ });
111
+ }
112
+ } catch (error) {
113
+ console.error('Error serving file:', error);
114
+ return new NextResponse('Internal Server Error', { status: 500 });
115
+ }
116
+ }
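Since the route honors Range headers (falling back to roughly 10 MB chunks when no end is given), large artifacts such as .safetensors files can be fetched piecewise. A hedged example requesting only the first 10 MB; the helper name and chunk size are illustrative:

// Hypothetical partial download; the path must be URL-encoded and live under an allowed root.
async function fetchFirstChunk(absolutePath: string): Promise<Blob> {
  const res = await fetch(`/api/files/${encodeURIComponent(absolutePath)}`, {
    headers: { Range: 'bytes=0-10485759' }, // first 10 MB, served with status 206
  });
  if (!res.ok) throw new Error(`Download failed: ${res.status}`);
  return res.blob();
}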
src/app/api/gpu/route.ts ADDED
@@ -0,0 +1,121 @@
1
+ import { NextResponse } from 'next/server';
2
+ import { exec } from 'child_process';
3
+ import { promisify } from 'util';
4
+ import os from 'os';
5
+
6
+ const execAsync = promisify(exec);
7
+
8
+ export async function GET() {
9
+ try {
10
+ // Get platform
11
+ const platform = os.platform();
12
+ const isWindows = platform === 'win32';
13
+
14
+ // Check if nvidia-smi is available
15
+ const hasNvidiaSmi = await checkNvidiaSmi(isWindows);
16
+
17
+ if (!hasNvidiaSmi) {
18
+ return NextResponse.json({
19
+ hasNvidiaSmi: false,
20
+ gpus: [],
21
+ error: 'nvidia-smi not found or not accessible',
22
+ });
23
+ }
24
+
25
+ // Get GPU stats
26
+ const gpuStats = await getGpuStats(isWindows);
27
+
28
+ return NextResponse.json({
29
+ hasNvidiaSmi: true,
30
+ gpus: gpuStats,
31
+ });
32
+ } catch (error) {
33
+ console.error('Error fetching NVIDIA GPU stats:', error);
34
+ return NextResponse.json(
35
+ {
36
+ hasNvidiaSmi: false,
37
+ gpus: [],
38
+ error: `Failed to fetch GPU stats: ${error instanceof Error ? error.message : String(error)}`,
39
+ },
40
+ { status: 500 },
41
+ );
42
+ }
43
+ }
44
+
45
+ async function checkNvidiaSmi(isWindows: boolean): Promise<boolean> {
46
+ try {
47
+ if (isWindows) {
48
+ // Check if nvidia-smi is available on Windows
49
+ // It's typically located in C:\Program Files\NVIDIA Corporation\NVSMI\nvidia-smi.exe
50
+ // but we'll just try to run it directly as it may be in PATH
51
+ await execAsync('nvidia-smi -L');
52
+ } else {
53
+ // Linux/macOS check
54
+ await execAsync('which nvidia-smi');
55
+ }
56
+ return true;
57
+ } catch (error) {
58
+ return false;
59
+ }
60
+ }
61
+
62
+ async function getGpuStats(isWindows: boolean) {
63
+ // Command is the same for both platforms, but the path might be different
64
+ const command =
65
+ 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
66
+
67
+ // Execute command
68
+ const { stdout } = await execAsync(command);
69
+
70
+ // Parse CSV output
71
+ const gpus = stdout
72
+ .trim()
73
+ .split('\n')
74
+ .map(line => {
75
+ const [
76
+ index,
77
+ name,
78
+ driverVersion,
79
+ temperature,
80
+ gpuUtil,
81
+ memoryUtil,
82
+ memoryTotal,
83
+ memoryFree,
84
+ memoryUsed,
85
+ powerDraw,
86
+ powerLimit,
87
+ clockGraphics,
88
+ clockMemory,
89
+ fanSpeed,
90
+ ] = line.split(', ').map(item => item.trim());
91
+
92
+ return {
93
+ index: parseInt(index),
94
+ name,
95
+ driverVersion,
96
+ temperature: parseInt(temperature),
97
+ utilization: {
98
+ gpu: parseInt(gpuUtil),
99
+ memory: parseInt(memoryUtil),
100
+ },
101
+ memory: {
102
+ total: parseInt(memoryTotal),
103
+ free: parseInt(memoryFree),
104
+ used: parseInt(memoryUsed),
105
+ },
106
+ power: {
107
+ draw: parseFloat(powerDraw),
108
+ limit: parseFloat(powerLimit),
109
+ },
110
+ clocks: {
111
+ graphics: parseInt(clockGraphics),
112
+ memory: parseInt(clockMemory),
113
+ },
114
+ fan: {
115
+ speed: parseInt(fanSpeed) || 0, // Some GPUs might not report fan speed, default to 0
116
+ },
117
+ };
118
+ });
119
+
120
+ return gpus;
121
+ }
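A dashboard would typically poll this endpoint and degrade gracefully when hasNvidiaSmi is false. The poller below is an assumed sketch mirroring the JSON shape built above:

// Hypothetical GPU stats poller mirroring the route's response shape.
async function pollGpuStats(onUpdate: (gpus: unknown[]) => void, intervalMs = 5000) {
  const res = await fetch('/api/gpu');
  const data = await res.json();
  if (data.hasNvidiaSmi) {
    onUpdate(data.gpus);
  }
  setTimeout(() => pollGpuStats(onUpdate, intervalMs), intervalMs);
}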
src/app/api/hf-hub/route.ts ADDED
@@ -0,0 +1,165 @@
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { whoAmI, createRepo, uploadFiles, datasetInfo } from '@huggingface/hub';
3
+ import { readdir, stat } from 'fs/promises';
4
+ import path from 'path';
5
+
6
+ export async function POST(request: NextRequest) {
7
+ try {
8
+ const body = await request.json();
9
+ const { action, token, namespace, datasetName, datasetPath, datasetId } = body;
10
+
11
+ if (!token) {
12
+ return NextResponse.json({ error: 'HF token is required' }, { status: 400 });
13
+ }
14
+
15
+ switch (action) {
16
+ case 'whoami':
17
+ try {
18
+ const user = await whoAmI({ accessToken: token });
19
+ return NextResponse.json({ user });
20
+ } catch (error) {
21
+ return NextResponse.json({ error: 'Invalid token or network error' }, { status: 401 });
22
+ }
23
+
24
+ case 'createDataset':
25
+ try {
26
+ if (!namespace || !datasetName) {
27
+ return NextResponse.json({ error: 'Namespace and dataset name required' }, { status: 400 });
28
+ }
29
+
30
+ const repoId = `datasets/${namespace}/${datasetName}`;
31
+
32
+ // Create repository
33
+ await createRepo({
34
+ repo: repoId,
35
+ accessToken: token,
36
+ private: false,
37
+ });
38
+
39
+ return NextResponse.json({ success: true, repoId });
40
+ } catch (error: any) {
41
+ if (error.message?.includes('already exists')) {
42
+ return NextResponse.json({ success: true, repoId: `${namespace}/${datasetName}`, exists: true });
43
+ }
44
+ return NextResponse.json({ error: error.message || 'Failed to create dataset' }, { status: 500 });
45
+ }
46
+
47
+ case 'uploadDataset':
48
+ try {
49
+ if (!namespace || !datasetName || !datasetPath) {
50
+ return NextResponse.json({ error: 'Missing required parameters' }, { status: 400 });
51
+ }
52
+
53
+ const repoId = `datasets/${namespace}/${datasetName}`;
54
+
55
+ // Check if directory exists
56
+ try {
57
+ await stat(datasetPath);
58
+ } catch {
59
+ return NextResponse.json({ error: 'Dataset path does not exist' }, { status: 400 });
60
+ }
61
+
62
+ // Read files from directory and upload them
63
+ const files = await readdir(datasetPath);
64
+ const filesToUpload = [];
65
+
66
+ for (const fileName of files) {
67
+ const filePath = path.join(datasetPath, fileName);
68
+ const fileStats = await stat(filePath);
69
+
70
+ if (fileStats.isFile()) {
71
+ filesToUpload.push({
72
+ path: fileName,
73
+ content: new URL(`file://${filePath}`)
74
+ });
75
+ }
76
+ }
77
+
78
+ if (filesToUpload.length > 0) {
79
+ await uploadFiles({
80
+ repo: repoId,
81
+ accessToken: token,
82
+ files: filesToUpload,
83
+ });
84
+ }
85
+
86
+ return NextResponse.json({ success: true, repoId });
87
+ } catch (error: any) {
88
+ console.error('Upload error:', error);
89
+ return NextResponse.json({ error: error.message || 'Failed to upload dataset' }, { status: 500 });
90
+ }
91
+
92
+ case 'listFiles':
93
+ try {
94
+ if (!datasetPath) {
95
+ return NextResponse.json({ error: 'Dataset path required' }, { status: 400 });
96
+ }
97
+
98
+ const files = await readdir(datasetPath, { withFileTypes: true });
99
+ const imageExtensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp'];
100
+
101
+ const imageFiles = files
102
+ .filter(file => file.isFile())
103
+ .filter(file => imageExtensions.some(ext => file.name.toLowerCase().endsWith(ext)))
104
+ .map(file => ({
105
+ name: file.name,
106
+ path: path.join(datasetPath, file.name),
107
+ }));
108
+
109
+ const captionFiles = files
110
+ .filter(file => file.isFile())
111
+ .filter(file => file.name.endsWith('.txt'))
112
+ .map(file => ({
113
+ name: file.name,
114
+ path: path.join(datasetPath, file.name),
115
+ }));
116
+
117
+ return NextResponse.json({
118
+ images: imageFiles,
119
+ captions: captionFiles,
120
+ total: imageFiles.length
121
+ });
122
+ } catch (error: any) {
123
+ return NextResponse.json({ error: error.message || 'Failed to list files' }, { status: 500 });
124
+ }
125
+
126
+ case 'validateDataset':
127
+ try {
128
+ if (!datasetId) {
129
+ return NextResponse.json({ error: 'Dataset ID required' }, { status: 400 });
130
+ }
131
+
132
+ // Try to get dataset info to validate it exists and is accessible
133
+ const dataset = await datasetInfo({
134
+ name: datasetId,
135
+ accessToken: token,
136
+ });
137
+
138
+ return NextResponse.json({
139
+ exists: true,
140
+ dataset: {
141
+ id: dataset.id,
142
+ author: dataset.author,
143
+ downloads: dataset.downloads,
144
+ likes: dataset.likes,
145
+ private: dataset.private,
146
+ }
147
+ });
148
+ } catch (error: any) {
149
+ if (error.message?.includes('404') || error.message?.includes('not found')) {
150
+ return NextResponse.json({ exists: false }, { status: 200 });
151
+ }
152
+ if (error.message?.includes('401') || error.message?.includes('403')) {
153
+ return NextResponse.json({ error: 'Dataset not accessible with current token' }, { status: 403 });
154
+ }
155
+ return NextResponse.json({ error: error.message || 'Failed to validate dataset' }, { status: 500 });
156
+ }
157
+
158
+ default:
159
+ return NextResponse.json({ error: 'Invalid action' }, { status: 400 });
160
+ }
161
+ } catch (error: any) {
162
+ console.error('HF Hub API error:', error);
163
+ return NextResponse.json({ error: error.message || 'Internal server error' }, { status: 500 });
164
+ }
165
+ }
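All Hub operations funnel through this single POST endpoint keyed by an action field. A thin wrapper makes the dispatch explicit; only the action names in the switch above are real, and everything else in this sketch is an illustrative assumption:

// Hypothetical wrapper around the action-dispatch endpoint.
type HfHubAction = 'whoami' | 'createDataset' | 'uploadDataset' | 'listFiles' | 'validateDataset';

async function callHfHub<T>(action: HfHubAction, params: Record<string, unknown>): Promise<T> {
  const res = await fetch('/api/hf-hub', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ action, ...params }),
  });
  const data = await res.json();
  if (!res.ok) throw new Error(data.error || `hf-hub action '${action}' failed`);
  return data as T;
}

// e.g. await callHfHub('createDataset', { token, namespace: 'my-user', datasetName: 'my-lora-data' });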
src/app/api/hf-jobs/route.ts ADDED
@@ -0,0 +1,761 @@
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { spawn } from 'child_process';
3
+ import { writeFile } from 'fs/promises';
4
+ import path from 'path';
5
+ import { tmpdir } from 'os';
6
+
7
+ export async function POST(request: NextRequest) {
8
+ try {
9
+ const body = await request.json();
10
+ const { action, token, hardware, namespace, jobConfig, datasetRepo } = body;
11
+
12
+ switch (action) {
13
+ case 'checkStatus':
14
+ try {
15
+ if (!token || !jobConfig?.hf_job_id) {
16
+ return NextResponse.json({ error: 'Token and job ID required' }, { status: 400 });
17
+ }
18
+
19
+ const jobStatus = await checkHFJobStatus(token, jobConfig.hf_job_id);
20
+ return NextResponse.json({ status: jobStatus });
21
+ } catch (error: any) {
22
+ console.error('Job status check error:', error);
23
+ return NextResponse.json({ error: error.message }, { status: 500 });
24
+ }
25
+
26
+ case 'generateScript':
27
+ try {
28
+ const uvScript = generateUVScript({
29
+ jobConfig,
30
+ datasetRepo,
31
+ namespace,
32
+ token: token || 'YOUR_HF_TOKEN',
33
+ });
34
+
35
+ return NextResponse.json({
36
+ script: uvScript,
37
+ filename: `train_${jobConfig.config.name.replace(/[^a-zA-Z0-9]/g, '_')}.py`
38
+ });
39
+ } catch (error: any) {
40
+ return NextResponse.json({ error: error.message }, { status: 500 });
41
+ }
42
+
43
+ case 'submitJob':
44
+ try {
45
+ if (!token || !hardware) {
46
+ return NextResponse.json({ error: 'Token and hardware required' }, { status: 400 });
47
+ }
48
+
49
+ // Generate UV script
50
+ const uvScript = generateUVScript({
51
+ jobConfig,
52
+ datasetRepo,
53
+ namespace,
54
+ token,
55
+ });
56
+
57
+ // Write script to temporary file
58
+ const scriptPath = path.join(tmpdir(), `train_${Date.now()}.py`);
59
+ await writeFile(scriptPath, uvScript);
60
+
61
+ // Submit HF job using uv run
62
+ const jobId = await submitHFJobUV(token, hardware, scriptPath);
63
+
64
+ return NextResponse.json({
65
+ success: true,
66
+ jobId,
67
+ message: `Job submitted successfully with ID: ${jobId}`
68
+ });
69
+ } catch (error: any) {
70
+ console.error('Job submission error:', error);
71
+ return NextResponse.json({ error: error.message }, { status: 500 });
72
+ }
73
+
74
+ default:
75
+ return NextResponse.json({ error: 'Invalid action' }, { status: 400 });
76
+ }
77
+ } catch (error: any) {
78
+ console.error('HF Jobs API error:', error);
79
+ return NextResponse.json({ error: error.message }, { status: 500 });
80
+ }
81
+ }
82
+
83
+ function generateUVScript({ jobConfig, datasetRepo, namespace, token }: {
84
+ jobConfig: any;
85
+ datasetRepo: string;
86
+ namespace: string;
87
+ token: string;
88
+ }) {
89
+ const config = jobConfig.config;
90
+ const process = config.process[0];
91
+
92
+ return `# /// script
93
+ # dependencies = [
94
+ # "torch>=2.0.0",
95
+ # "torchvision",
96
+ # "torchao==0.10.0",
97
+ # "safetensors",
98
+ # "diffusers @ git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63",
99
+ # "transformers==4.52.4",
100
+ # "lycoris-lora==1.8.3",
101
+ # "flatten_json",
102
+ # "pyyaml",
103
+ # "oyaml",
104
+ # "tensorboard",
105
+ # "kornia",
106
+ # "invisible-watermark",
107
+ # "einops",
108
+ # "accelerate",
109
+ # "toml",
110
+ # "albumentations==1.4.15",
111
+ # "albucore==0.0.16",
112
+ # "pydantic",
113
+ # "omegaconf",
114
+ # "k-diffusion",
115
+ # "open_clip_torch",
116
+ # "timm",
117
+ # "prodigyopt",
118
+ # "controlnet_aux==0.0.10",
119
+ # "python-dotenv",
120
+ # "bitsandbytes",
121
+ # "hf_transfer",
122
+ # "lpips",
123
+ # "pytorch_fid",
124
+ # "optimum-quanto==0.2.4",
125
+ # "sentencepiece",
126
+ # "huggingface_hub",
127
+ # "peft",
128
+ # "python-slugify",
129
+ # "opencv-python-headless",
130
+ # "pytorch-wavelets==1.3.0",
131
+ # "matplotlib==3.10.1",
132
+ # "setuptools==69.5.1",
133
+ # "datasets==4.0.0",
134
+ # "pyarrow==20.0.0",
135
+ # "pillow",
136
+ # "ftfy",
137
+ # ]
138
+ # ///
139
+
140
+ import os
141
+ import sys
142
+ import subprocess
143
+ import argparse
144
+ import oyaml as yaml
145
+ from datasets import load_dataset
146
+ from huggingface_hub import HfApi, create_repo, upload_folder, snapshot_download
147
+ import tempfile
148
+ import shutil
149
+ import glob
150
+ from PIL import Image
151
+
152
+ def setup_ai_toolkit():
153
+ """Clone and setup ai-toolkit repository"""
154
+ repo_dir = "ai-toolkit"
155
+ if not os.path.exists(repo_dir):
156
+ print("Cloning ai-toolkit repository...")
157
+ subprocess.run(
158
+ ["git", "clone", "https://github.com/ostris/ai-toolkit.git", repo_dir],
159
+ check=True
160
+ )
161
+ sys.path.insert(0, os.path.abspath(repo_dir))
162
+ return repo_dir
163
+
164
+ def download_dataset(dataset_repo: str, local_path: str):
165
+ """Download dataset from HF Hub as files"""
166
+ print(f"Downloading dataset from {dataset_repo}...")
167
+
168
+ # Create local dataset directory
169
+ os.makedirs(local_path, exist_ok=True)
170
+
171
+ # Use snapshot_download to get the dataset files directly
172
+ from huggingface_hub import snapshot_download
173
+
174
+ try:
175
+ # First try to download as a structured dataset
176
+ dataset = load_dataset(dataset_repo, split="train")
177
+
178
+ # Download images and captions from structured dataset
179
+ for i, item in enumerate(dataset):
180
+ # Save image
181
+ if "image" in item:
182
+ image_path = os.path.join(local_path, f"image_{i:06d}.jpg")
183
+ image = item["image"]
184
+
185
+ # Convert RGBA to RGB if necessary (for JPEG compatibility)
186
+ if image.mode == 'RGBA':
187
+ # Create a white background and paste the RGBA image on it
188
+ background = Image.new('RGB', image.size, (255, 255, 255))
189
+ background.paste(image, mask=image.split()[-1]) # Use alpha channel as mask
190
+ image = background
191
+ elif image.mode not in ['RGB', 'L']:
192
+ # Convert any other mode to RGB
193
+ image = image.convert('RGB')
194
+
195
+ image.save(image_path, 'JPEG')
196
+
197
+ # Save caption
198
+ if "text" in item:
199
+ caption_path = os.path.join(local_path, f"image_{i:06d}.txt")
200
+ with open(caption_path, "w", encoding="utf-8") as f:
201
+ f.write(item["text"])
202
+
203
+ print(f"Downloaded {len(dataset)} items to {local_path}")
204
+
205
+ except Exception as e:
206
+ print(f"Failed to load as structured dataset: {e}")
207
+ print("Attempting to download raw files...")
208
+
209
+ # Download the dataset repository as files
210
+ temp_repo_path = snapshot_download(repo_id=dataset_repo, repo_type="dataset")
211
+
212
+ # Copy all image and text files to the local path
213
+ import glob
214
+ import shutil
215
+
216
+ print(f"Downloaded repo to: {temp_repo_path}")
217
+ print(f"Contents: {os.listdir(temp_repo_path)}")
218
+
219
+ # Find all image files
220
+ image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.bmp', '*.JPG', '*.JPEG', '*.PNG']
221
+ image_files = []
222
+ for ext in image_extensions:
223
+ pattern = os.path.join(temp_repo_path, "**", ext)
224
+ found_files = glob.glob(pattern, recursive=True)
225
+ image_files.extend(found_files)
226
+ print(f"Pattern {pattern} found {len(found_files)} files")
227
+
228
+ # Find all text files
229
+ text_files = glob.glob(os.path.join(temp_repo_path, "**", "*.txt"), recursive=True)
230
+
231
+ print(f"Found {len(image_files)} image files and {len(text_files)} text files")
232
+
233
+ # Copy image files
234
+ for i, img_file in enumerate(image_files):
235
+ dest_path = os.path.join(local_path, f"image_{i:06d}.jpg")
236
+
237
+ # Load and convert image if needed
238
+ try:
239
+ with Image.open(img_file) as image:
240
+ if image.mode == 'RGBA':
241
+ background = Image.new('RGB', image.size, (255, 255, 255))
242
+ background.paste(image, mask=image.split()[-1])
243
+ image = background
244
+ elif image.mode not in ['RGB', 'L']:
245
+ image = image.convert('RGB')
246
+
247
+ image.save(dest_path, 'JPEG')
248
+ except Exception as img_error:
249
+ print(f"Error processing image {img_file}: {img_error}")
250
+ continue
251
+
252
+ # Copy text files (captions)
253
+ for i, txt_file in enumerate(text_files[:len(image_files)]): # Match number of images
254
+ dest_path = os.path.join(local_path, f"image_{i:06d}.txt")
255
+ try:
256
+ shutil.copy2(txt_file, dest_path)
257
+ except Exception as txt_error:
258
+ print(f"Error copying text file {txt_file}: {txt_error}")
259
+ continue
260
+
261
+ print(f"Downloaded {len(image_files)} images and {len(text_files)} captions to {local_path}")
262
+
263
+ def create_config(dataset_path: str, output_path: str):
264
+ """Create training configuration"""
265
+ import json
266
+
267
+ # Load config from JSON string and fix boolean/null values for Python
268
+ config_str = """${JSON.stringify(jobConfig, null, 2)}"""
269
+ config_str = config_str.replace('true', 'True').replace('false', 'False').replace('null', 'None')
270
+ config = eval(config_str)
271
+
272
+ # Update paths for cloud environment
273
+ config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_path
274
+ config["config"]["process"][0]["training_folder"] = output_path
275
+
276
+ # Remove sqlite_db_path as it's not needed for cloud training
277
+ if "sqlite_db_path" in config["config"]["process"][0]:
278
+ del config["config"]["process"][0]["sqlite_db_path"]
279
+
280
+ # Also change trainer type from ui_trainer to standard trainer to avoid UI dependencies
281
+ if config["config"]["process"][0]["type"] == "ui_trainer":
282
+ config["config"]["process"][0]["type"] = "sd_trainer"
283
+
284
+ return config
285
+
286
+ def upload_results(output_path: str, model_name: str, namespace: str, token: str, config: dict):
287
+ """Upload trained model to HF Hub with README generation and proper file organization"""
288
+ import tempfile
289
+ import shutil
290
+ import glob
291
+ import re
292
+ import yaml
293
+ from datetime import datetime
294
+ from huggingface_hub import create_repo, upload_file, HfApi
295
+
296
+ try:
297
+ repo_id = f"{namespace}/{model_name}"
298
+
299
+ # Create repository
300
+ create_repo(repo_id=repo_id, token=token, exist_ok=True)
301
+
302
+ print(f"Uploading model to {repo_id}...")
303
+
304
+ # Create temporary directory for organized upload
305
+ with tempfile.TemporaryDirectory() as temp_upload_dir:
306
+ api = HfApi()
307
+
308
+ # 1. Find and upload model files to root directory
309
+ safetensors_files = glob.glob(os.path.join(output_path, "**", "*.safetensors"), recursive=True)
310
+ json_files = glob.glob(os.path.join(output_path, "**", "*.json"), recursive=True)
311
+ txt_files = glob.glob(os.path.join(output_path, "**", "*.txt"), recursive=True)
312
+
313
+ uploaded_files = []
314
+
315
+ # Upload .safetensors files to root
316
+ for file_path in safetensors_files:
317
+ filename = os.path.basename(file_path)
318
+ print(f"Uploading {filename} to repository root...")
319
+ api.upload_file(
320
+ path_or_fileobj=file_path,
321
+ path_in_repo=filename,
322
+ repo_id=repo_id,
323
+ token=token
324
+ )
325
+ uploaded_files.append(filename)
326
+
327
+ # Upload relevant JSON config files to root (skip metadata.json and other internal files)
328
+ config_files_uploaded = []
329
+ for file_path in json_files:
330
+ filename = os.path.basename(file_path)
331
+ # Only upload important config files, skip internal metadata
332
+ if any(keyword in filename.lower() for keyword in ['config', 'adapter', 'lora', 'model']):
333
+ print(f"Uploading {filename} to repository root...")
334
+ api.upload_file(
335
+ path_or_fileobj=file_path,
336
+ path_in_repo=filename,
337
+ repo_id=repo_id,
338
+ token=token
339
+ )
340
+ uploaded_files.append(filename)
341
+ config_files_uploaded.append(filename)
342
+
343
+ # 2. Handle sample images
344
+ samples_uploaded = []
345
+ samples_dir = os.path.join(output_path, "samples")
346
+ if os.path.isdir(samples_dir):
347
+ print("Uploading sample images...")
348
+ # Create samples directory in repo
349
+ for filename in os.listdir(samples_dir):
350
+ if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
351
+ file_path = os.path.join(samples_dir, filename)
352
+ repo_path = f"samples/{filename}"
353
+ api.upload_file(
354
+ path_or_fileobj=file_path,
355
+ path_in_repo=repo_path,
356
+ repo_id=repo_id,
357
+ token=token
358
+ )
359
+ samples_uploaded.append(repo_path)
360
+
361
+ # 3. Generate and upload README.md
362
+ readme_content = generate_model_card_readme(
363
+ repo_id=repo_id,
364
+ config=config,
365
+ model_name=model_name,
366
+ samples_dir=samples_dir if os.path.isdir(samples_dir) else None,
367
+ uploaded_files=uploaded_files
368
+ )
369
+
370
+ # Create README.md file and upload to root
371
+ readme_path = os.path.join(temp_upload_dir, "README.md")
372
+ with open(readme_path, "w", encoding="utf-8") as f:
373
+ f.write(readme_content)
374
+
375
+ print("Uploading README.md to repository root...")
376
+ api.upload_file(
377
+ path_or_fileobj=readme_path,
378
+ path_in_repo="README.md",
379
+ repo_id=repo_id,
380
+ token=token
381
+ )
382
+
383
+ print(f"Model uploaded successfully to https://huggingface.co/{repo_id}")
384
+ print(f"Files uploaded: {len(uploaded_files)} model files, {len(samples_uploaded)} samples, README.md")
385
+
386
+ except Exception as e:
387
+ print(f"Failed to upload model: {e}")
388
+ raise e
389
+
390
+ def generate_model_card_readme(repo_id: str, config: dict, model_name: str, samples_dir: str = None, uploaded_files: list = None) -> str:
391
+ """Generate README.md content for the model card based on AI Toolkit's implementation"""
392
+ import re
393
+ import yaml
394
+ import os
395
+
396
+ try:
397
+ # Extract configuration details
398
+ process_config = config.get("config", {}).get("process", [{}])[0]
399
+ model_config = process_config.get("model", {})
400
+ train_config = process_config.get("train", {})
401
+ sample_config = process_config.get("sample", {})
402
+
403
+ # Gather model info
404
+ base_model = model_config.get("name_or_path", "unknown")
405
+ trigger_word = process_config.get("trigger_word")
406
+ arch = model_config.get("arch", "")
407
+
408
+ # Determine license based on base model
409
+ if "FLUX.1-schnell" in base_model:
410
+ license_info = {"license": "apache-2.0"}
411
+ elif "FLUX.1-dev" in base_model:
412
+ license_info = {
413
+ "license": "other",
414
+ "license_name": "flux-1-dev-non-commercial-license",
415
+ "license_link": "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
416
+ }
417
+ else:
418
+ license_info = {"license": "creativeml-openrail-m"}
419
+
420
+ # Generate tags based on model architecture
421
+ tags = ["text-to-image"]
422
+
423
+ if "xl" in arch.lower():
424
+ tags.append("stable-diffusion-xl")
425
+ if "flux" in arch.lower():
426
+ tags.append("flux")
427
+ if "lumina" in arch.lower():
428
+ tags.append("lumina2")
429
+ if "sd3" in arch.lower() or "v3" in arch.lower():
430
+ tags.append("sd3")
431
+
432
+ # Add LoRA-specific tags
433
+ tags.extend(["lora", "diffusers", "template:sd-lora", "ai-toolkit"])
434
+
435
+ # Generate widgets from sample images and prompts
436
+ widgets = []
437
+ if samples_dir and os.path.isdir(samples_dir):
438
+ sample_prompts = sample_config.get("samples", [])
439
+ if not sample_prompts:
440
+ # Fallback to old format
441
+ sample_prompts = [{"prompt": p} for p in sample_config.get("prompts", [])]
442
+
443
+ # Get sample image files
444
+ sample_files = []
445
+ if os.path.isdir(samples_dir):
446
+ for filename in os.listdir(samples_dir):
447
+ if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
448
+ # Parse filename pattern: timestamp__steps_index.jpg
449
+ match = re.search(r"__(\d+)_(\d+)\.jpg$", filename)
450
+ if match:
451
+ steps, index = int(match.group(1)), int(match.group(2))
452
+ # Only use samples from final training step
453
+ final_steps = train_config.get("steps", 1000)
454
+ if steps == final_steps:
455
+ sample_files.append((index, f"samples/{filename}"))
456
+
457
+ # Sort by index and create widgets
458
+ sample_files.sort(key=lambda x: x[0])
459
+
460
+ for i, prompt_obj in enumerate(sample_prompts):
461
+ prompt = prompt_obj.get("prompt", "") if isinstance(prompt_obj, dict) else str(prompt_obj)
462
+ if i < len(sample_files):
463
+ _, image_path = sample_files[i]
464
+ widgets.append({
465
+ "text": prompt,
466
+ "output": {"url": image_path}
467
+ })
468
+
469
+ # Determine torch dtype based on model
470
+ dtype = "torch.bfloat16" if "flux" in arch.lower() else "torch.float16"
471
+
472
+ # Find the main safetensors file for usage example
473
+ main_safetensors = f"{model_name}.safetensors"
474
+ if uploaded_files:
475
+ safetensors_files = [f for f in uploaded_files if f.endswith('.safetensors')]
476
+ if safetensors_files:
477
+ main_safetensors = safetensors_files[0]
478
+
479
+ # Construct YAML frontmatter
480
+ frontmatter = {
481
+ "tags": tags,
482
+ "base_model": base_model,
483
+ **license_info
484
+ }
485
+
486
+ if widgets:
487
+ frontmatter["widget"] = widgets
488
+
489
+ if trigger_word:
490
+ frontmatter["instance_prompt"] = trigger_word
491
+
492
+ # Get first prompt for usage example
493
+ usage_prompt = trigger_word or "a beautiful landscape"
494
+ if widgets:
495
+ usage_prompt = widgets[0]["text"]
496
+ elif trigger_word:
497
+ usage_prompt = trigger_word
498
+
499
+ # Construct README content
500
+ trigger_section = f"You should use \`{trigger_word}\` to trigger the image generation." if trigger_word else "No trigger words defined."
501
+
502
+ # Build YAML frontmatter string
503
+ frontmatter_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True, sort_keys=False).strip()
504
+
505
+ readme_content = f"""---
506
+ {frontmatter_yaml}
507
+ ---
508
+
509
+ # {model_name}
510
+
511
+ Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
512
+
513
+ <Gallery />
514
+
515
+ ## Trigger words
516
+
517
+ {trigger_section}
518
+
519
+ ## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc.
520
+
521
+ Weights for this model are available in Safetensors format.
522
+
523
+ [Download]({repo_id}/tree/main) them in the Files & versions tab.
524
+
525
+ ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
526
+
527
+ \`\`\`py
528
+ from diffusers import AutoPipelineForText2Image
529
+ import torch
530
+
531
+ pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda')
532
+ pipeline.load_lora_weights('{repo_id}', weight_name='{main_safetensors}')
533
+ image = pipeline('{usage_prompt}').images[0]
534
+ image.save("my_image.png")
535
+ \`\`\`
536
+
537
+ For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
538
+
539
+ """
540
+ return readme_content
541
+
542
+ except Exception as e:
543
+ print(f"Error generating README: {e}")
544
+ # Fallback simple README
545
+ return f"""# {model_name}
546
+
547
+ Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
548
+
549
+ ## Download model
550
+
551
+ Weights for this model are available in Safetensors format.
552
+
553
+ [Download]({repo_id}/tree/main) them in the Files & versions tab.
554
+ """
555
+
556
+ def main():
557
+ # Setup environment - token comes from HF Jobs secrets
558
+ if "HF_TOKEN" not in os.environ:
559
+ raise ValueError("HF_TOKEN environment variable not set")
560
+
561
+ # Install system dependencies for headless operation
562
+ print("Installing system dependencies...")
563
+ try:
564
+ subprocess.run(["apt-get", "update"], check=True, capture_output=True)
565
+ subprocess.run([
566
+ "apt-get", "install", "-y",
567
+ "libgl1-mesa-glx",
568
+ "libglib2.0-0",
569
+ "libsm6",
570
+ "libxext6",
571
+ "libxrender-dev",
572
+ "libgomp1",
573
+ "ffmpeg"
574
+ ], check=True, capture_output=True)
575
+ print("System dependencies installed successfully")
576
+ except subprocess.CalledProcessError as e:
577
+ print(f"Failed to install system dependencies: {e}")
578
+ print("Continuing without system dependencies...")
579
+
580
+ # Setup ai-toolkit
581
+ toolkit_dir = setup_ai_toolkit()
582
+
583
+ # Create temporary directories
584
+ with tempfile.TemporaryDirectory() as temp_dir:
585
+ dataset_path = os.path.join(temp_dir, "dataset")
586
+ output_path = os.path.join(temp_dir, "output")
587
+
588
+ # Download dataset
589
    download_dataset("${datasetRepo}", dataset_path)

    # Create config
    config = create_config(dataset_path, output_path)
    config_path = os.path.join(temp_dir, "config.yaml")

    with open(config_path, "w") as f:
        yaml.dump(config, f, default_flow_style=False)

    # Run training
    print("Starting training...")
    os.chdir(toolkit_dir)

    subprocess.run([
        sys.executable, "run.py",
        config_path
    ], check=True)

    print("Training completed!")

    # Upload results
    model_name = f"${jobConfig.config.name}-lora"
    upload_results(output_path, model_name, "${namespace}", os.environ["HF_TOKEN"], config)

if __name__ == "__main__":
    main()
`;
}

async function submitHFJobUV(token: string, hardware: string, scriptPath: string): Promise<string> {
  return new Promise((resolve, reject) => {
    // Ensure token is available
    if (!token) {
      reject(new Error('HF_TOKEN is required'));
      return;
    }

    console.log('Setting up environment with HF_TOKEN for job submission');
    console.log(`Command: hf jobs uv run --flavor ${hardware} --timeout 5h --secrets HF_TOKEN --detach ${scriptPath}`);

    // Use hf jobs uv run command with timeout and detach to get job ID
    const childProcess = spawn('hf', [
      'jobs', 'uv', 'run',
      '--flavor', hardware,
      '--timeout', '5h',
      '--secrets', 'HF_TOKEN',
      '--detach',
      scriptPath
    ], {
      env: {
        ...process.env,
        HF_TOKEN: token
      }
    });

    let output = '';
    let error = '';

    childProcess.stdout.on('data', (data) => {
      const text = data.toString();
      output += text;
      console.log('HF Jobs stdout:', text);
    });

    childProcess.stderr.on('data', (data) => {
      const text = data.toString();
      error += text;
      console.log('HF Jobs stderr:', text);
    });

    childProcess.on('close', (code) => {
      console.log('HF Jobs process closed with code:', code);
      console.log('Full output:', output);
      console.log('Full error:', error);

      if (code === 0) {
        // With --detach flag, the output should be just the job ID
        const fullText = (output + ' ' + error).trim();

        // Updated patterns to handle variable-length hex job IDs (16-24+ characters)
        const jobIdPatterns = [
          /Job started with ID:\s*([a-f0-9]{16,})/i, // "Job started with ID: 68b26b73767540db9fc726ac"
          /job\s+([a-f0-9]{16,})/i, // "job 68b26b73767540db9fc726ac"
          /Job ID:\s*([a-f0-9]{16,})/i, // "Job ID: 68b26b73767540db9fc726ac"
          /created\s+job\s+([a-f0-9]{16,})/i, // "created job 68b26b73767540db9fc726ac"
          /submitted.*?job\s+([a-f0-9]{16,})/i, // "submitted ... job 68b26b73767540db9fc726ac"
          /https:\/\/huggingface\.co\/jobs\/[^\/]+\/([a-f0-9]{16,})/i, // URL pattern
          /([a-f0-9]{20,})/i, // Fallback: any 20+ char hex string
        ];

        let jobId = 'unknown';

        for (const pattern of jobIdPatterns) {
          const match = fullText.match(pattern);
          if (match && match[1] && match[1] !== 'started') {
            jobId = match[1];
            console.log(`Extracted job ID using pattern: ${pattern.toString()} -> ${jobId}`);
            break;
          }
        }

        resolve(jobId);
      } else {
        reject(new Error(error || output || 'Failed to submit job'));
      }
    });

    childProcess.on('error', (err) => {
      console.error('HF Jobs process error:', err);
      reject(new Error(`Process error: ${err.message}`));
    });
  });
}

async function checkHFJobStatus(token: string, jobId: string): Promise<any> {
  return new Promise((resolve, reject) => {
    console.log(`Checking HF Job status for: ${jobId}`);

    const childProcess = spawn('hf', [
      'jobs', 'inspect', jobId
    ], {
      env: {
        ...process.env,
        HF_TOKEN: token
      }
    });

    let output = '';
    let error = '';

    childProcess.stdout.on('data', (data) => {
      const text = data.toString();
      output += text;
    });

    childProcess.stderr.on('data', (data) => {
      const text = data.toString();
      error += text;
    });

    childProcess.on('close', (code) => {
      if (code === 0) {
        try {
          // Parse the JSON output from hf jobs inspect
          const jobInfo = JSON.parse(output);
          if (Array.isArray(jobInfo) && jobInfo.length > 0) {
            const job = jobInfo[0];
            resolve({
              id: job.id,
              status: job.status?.stage || 'UNKNOWN',
              message: job.status?.message,
              created_at: job.created_at,
              flavor: job.flavor,
              url: job.url,
            });
          } else {
            reject(new Error('Invalid job info response'));
          }
        } catch (parseError: any) {
          console.error('Failed to parse job status:', parseError, output);
          reject(new Error('Failed to parse job status'));
        }
      } else {
        reject(new Error(error || output || 'Failed to check job status'));
      }
    });

    childProcess.on('error', (err) => {
      console.error('HF Jobs inspect process error:', err);
      reject(new Error(`Process error: ${err.message}`));
    });
  });
}
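For orientation, a minimal sketch of how the two helpers above are meant to be combined: submit the generated UV script, then poll the returned job ID with the inspect helper. The flavor name, script path, poll interval, and the exact stage strings checked below are illustrative assumptions, not values taken from this file.

// Illustrative wiring of submitHFJobUV and checkHFJobStatus (all concrete values are placeholders).
async function submitAndWatch(token: string): Promise<void> {
  // Assumed flavor and script path, for the example only.
  const jobId = await submitHFJobUV(token, 'a10g-large', '/tmp/train_job.py');
  console.log('Submitted HF job:', jobId);

  // Poll every 30 seconds until the job leaves an active stage.
  // The stage names here are assumptions about what `hf jobs inspect` reports.
  let status = 'RUNNING';
  while (status === 'RUNNING' || status === 'PENDING') {
    await new Promise(resolve => setTimeout(resolve, 30_000));
    const info = await checkHFJobStatus(token, jobId);
    status = info.status;
    console.log(`Job ${jobId} is ${status}`);
  }
}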
src/app/api/img/[...imagePath]/route.ts ADDED
@@ -0,0 +1,78 @@
1
+ /* eslint-disable */
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+ import fs from 'fs';
4
+ import path from 'path';
5
+ import { getDatasetsRoot, getTrainingFolder, getDataRoot } from '@/server/settings';
6
+
7
+ export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
8
+ const { imagePath } = await params;
9
+ try {
10
+ // Decode the path
11
+ const filepath = decodeURIComponent(imagePath);
12
+
13
+ // Get allowed directories
14
+ const datasetRoot = await getDatasetsRoot();
15
+ const trainingRoot = await getTrainingFolder();
16
+ const dataRoot = await getDataRoot();
17
+
18
+ const allowedDirs = [datasetRoot, trainingRoot, dataRoot];
19
+
20
+ // Security check: Ensure path is in allowed directory
21
+ const isAllowed = allowedDirs.some(allowedDir => filepath.startsWith(allowedDir)) && !filepath.includes('..');
22
+
23
+ if (!isAllowed) {
24
+ console.warn(`Access denied: ${filepath} not in ${allowedDirs.join(', ')}`);
25
+ return new NextResponse('Access denied', { status: 403 });
26
+ }
27
+
28
+ // Check if file exists
29
+ if (!fs.existsSync(filepath)) {
30
+ console.warn(`File not found: ${filepath}`);
31
+ return new NextResponse('File not found', { status: 404 });
32
+ }
33
+
34
+ // Get file info
35
+ const stat = fs.statSync(filepath);
36
+ if (!stat.isFile()) {
37
+ return new NextResponse('Not a file', { status: 400 });
38
+ }
39
+
40
+ // Determine content type
41
+ const ext = path.extname(filepath).toLowerCase();
42
+ const contentTypeMap: { [key: string]: string } = {
43
+ // Images
44
+ '.jpg': 'image/jpeg',
45
+ '.jpeg': 'image/jpeg',
46
+ '.png': 'image/png',
47
+ '.gif': 'image/gif',
48
+ '.webp': 'image/webp',
49
+ '.svg': 'image/svg+xml',
50
+ '.bmp': 'image/bmp',
51
+ // Videos
52
+ '.mp4': 'video/mp4',
53
+ '.avi': 'video/x-msvideo',
54
+ '.mov': 'video/quicktime',
55
+ '.mkv': 'video/x-matroska',
56
+ '.wmv': 'video/x-ms-wmv',
57
+ '.m4v': 'video/x-m4v',
58
+ '.flv': 'video/x-flv'
59
+ };
60
+
61
+ const contentType = contentTypeMap[ext] || 'application/octet-stream';
62
+
63
+ // Read file as buffer
64
+ const fileBuffer = fs.readFileSync(filepath);
65
+
66
+ // Return file with appropriate headers
67
+ return new NextResponse(fileBuffer, {
68
+ headers: {
69
+ 'Content-Type': contentType,
70
+ 'Content-Length': String(stat.size),
71
+ 'Cache-Control': 'public, max-age=86400',
72
+ },
73
+ });
74
+ } catch (error) {
75
+ console.error('Error serving image:', error);
76
+ return new NextResponse('Internal Server Error', { status: 500 });
77
+ }
78
+ }
src/app/api/img/caption/route.ts ADDED
@@ -0,0 +1,29 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import { getDatasetsRoot } from '@/server/settings';

export async function POST(request: Request) {
  try {
    const body = await request.json();
    const { imgPath, caption } = body;
    const datasetsPath = await getDatasetsRoot();
    // make sure the image path is inside the datasets root
    if (!imgPath.startsWith(datasetsPath)) {
      return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
    }

    // if the image doesn't exist, return a 404
    if (!fs.existsSync(imgPath)) {
      return NextResponse.json({ error: 'Image does not exist' }, { status: 404 });
    }

    // derive the caption path from the image path
    const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
    // save caption to file
    fs.writeFileSync(captionPath, caption);

    return NextResponse.json({ success: true });
  } catch (error) {
    return NextResponse.json({ error: 'Failed to save caption' }, { status: 500 });
  }
}
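A minimal client-side sketch of how this endpoint is exercised; the image path below is hypothetical and must resolve under the configured datasets root for the request to pass the check above.

// Hypothetical call to POST /api/img/caption (example path only).
await fetch('/api/img/caption', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    imgPath: '/workspace/datasets/my_dataset/0001.jpg', // must start with the datasets root
    caption: 'a photo of a red bicycle leaning against a brick wall',
  }),
});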
src/app/api/img/delete/route.ts ADDED
@@ -0,0 +1,34 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import { getDatasetsRoot } from '@/server/settings';

export async function POST(request: Request) {
  try {
    const body = await request.json();
    const { imgPath } = body;
    const datasetsPath = await getDatasetsRoot();
    // make sure the image path is inside the datasets root
    if (!imgPath.startsWith(datasetsPath)) {
      return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
    }

    // if the image doesn't exist, treat the delete as a success
    if (!fs.existsSync(imgPath)) {
      return NextResponse.json({ success: true });
    }

    // delete it and return success
    fs.unlinkSync(imgPath);

    // delete the matching caption file if it exists
    const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
    if (fs.existsSync(captionPath)) {
      fs.unlinkSync(captionPath);
    }

    return NextResponse.json({ success: true });
  } catch (error) {
    return NextResponse.json({ error: 'Failed to delete image' }, { status: 500 });
  }
}
src/app/api/img/upload/route.ts ADDED
@@ -0,0 +1,58 @@
// src/app/api/img/upload/route.ts
import { NextRequest, NextResponse } from 'next/server';
import { writeFile, mkdir } from 'fs/promises';
import { join } from 'path';
import { getDataRoot } from '@/server/settings';
import { v4 as uuidv4 } from 'uuid';

export async function POST(request: NextRequest) {
  try {
    const dataRoot = await getDataRoot();
    if (!dataRoot) {
      return NextResponse.json({ error: 'Data root path not found' }, { status: 500 });
    }
    const imgRoot = join(dataRoot, 'images');

    const formData = await request.formData();
    const files = formData.getAll('files');

    if (!files || files.length === 0) {
      return NextResponse.json({ error: 'No files provided' }, { status: 400 });
    }

    // create the images folder recursively if it doesn't exist
    await mkdir(imgRoot, { recursive: true });
    const savedFiles = await Promise.all(
      files.map(async (file: any) => {
        const bytes = await file.arrayBuffer();
        const buffer = Buffer.from(bytes);

        const extension = file.name.split('.').pop() || 'jpg';

        // Use a UUID so every uploaded file gets a unique name
        const fileName = `${uuidv4()}`;
        const filePath = join(imgRoot, `${fileName}.${extension}`);

        await writeFile(filePath, buffer);
        return filePath;
      }),
    );

    return NextResponse.json({
      message: 'Files uploaded successfully',
      files: savedFiles,
    });
  } catch (error) {
    console.error('Upload error:', error);
    return NextResponse.json({ error: 'Error uploading files' }, { status: 500 });
  }
}

// Increase payload size limit (default is 4mb)
export const config = {
  api: {
    bodyParser: false,
    responseLimit: '50mb',
  },
};
src/app/api/jobs/[jobID]/delete/route.ts ADDED
@@ -0,0 +1,32 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import { getTrainingFolder } from '@/server/settings';
import path from 'path';
import fs from 'fs';

const prisma = new PrismaClient();

export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
  const { jobID } = await params;

  const job = await prisma.job.findUnique({
    where: { id: jobID },
  });

  if (!job) {
    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
  }

  const trainingRoot = await getTrainingFolder();
  const trainingFolder = path.join(trainingRoot, job.name);

  if (fs.existsSync(trainingFolder)) {
    fs.rmdirSync(trainingFolder, { recursive: true });
  }

  await prisma.job.delete({
    where: { id: jobID },
  });

  return NextResponse.json(job);
}
src/app/api/jobs/[jobID]/files/route.ts ADDED
@@ -0,0 +1,48 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import path from 'path';
import fs from 'fs';
import { getTrainingFolder } from '@/server/settings';

const prisma = new PrismaClient();

export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
  const { jobID } = await params;

  const job = await prisma.job.findUnique({
    where: { id: jobID },
  });

  if (!job) {
    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
  }

  const trainingFolder = await getTrainingFolder();
  const jobFolder = path.join(trainingFolder, job.name);

  if (!fs.existsSync(jobFolder)) {
    return NextResponse.json({ files: [] });
  }

  // find all safetensors files in the job folder
  let files = fs
    .readdirSync(jobFolder)
    .filter(file => {
      return file.endsWith('.safetensors');
    })
    .map(file => {
      return path.join(jobFolder, file);
    })
    .sort();

  // get the file size for each file
  const fileObjects = files.map(file => {
    const stats = fs.statSync(file);
    return {
      path: file,
      size: stats.size,
    };
  });

  return NextResponse.json({ files: fileObjects });
}
src/app/api/jobs/[jobID]/log/route.ts ADDED
@@ -0,0 +1,35 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import path from 'path';
import fs from 'fs';
import { getTrainingFolder } from '@/server/settings';

const prisma = new PrismaClient();

export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
  const { jobID } = await params;

  const job = await prisma.job.findUnique({
    where: { id: jobID },
  });

  if (!job) {
    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
  }

  const trainingFolder = await getTrainingFolder();
  const jobFolder = path.join(trainingFolder, job.name);
  const logPath = path.join(jobFolder, 'log.txt');

  if (!fs.existsSync(logPath)) {
    return NextResponse.json({ log: '' });
  }
  let log = '';
  try {
    log = fs.readFileSync(logPath, 'utf-8');
  } catch (error) {
    console.error('Error reading log file:', error);
    log = 'Error reading log file';
  }
  return NextResponse.json({ log: log });
}
src/app/api/jobs/[jobID]/samples/route.ts ADDED
@@ -0,0 +1,40 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import path from 'path';
import fs from 'fs';
import { getTrainingFolder } from '@/server/settings';

const prisma = new PrismaClient();

export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
  const { jobID } = await params;

  const job = await prisma.job.findUnique({
    where: { id: jobID },
  });

  if (!job) {
    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
  }

  // resolve the training folder for this job
  const trainingFolder = await getTrainingFolder();

  const samplesFolder = path.join(trainingFolder, job.name, 'samples');
  if (!fs.existsSync(samplesFolder)) {
    return NextResponse.json({ samples: [] });
  }

  // find all image (png, jpg, jpeg, webp) files in the samples folder
  const samples = fs
    .readdirSync(samplesFolder)
    .filter(file => {
      return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp');
    })
    .map(file => {
      return path.join(samplesFolder, file);
    })
    .sort();

  return NextResponse.json({ samples });
}
src/app/api/jobs/[jobID]/start/route.ts ADDED
@@ -0,0 +1,215 @@
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { PrismaClient } from '@prisma/client';
3
+ import { TOOLKIT_ROOT } from '@/paths';
4
+ import { spawn } from 'child_process';
5
+ import path from 'path';
6
+ import fs from 'fs';
7
+ import os from 'os';
8
+ import { getTrainingFolder, getHFToken } from '@/server/settings';
9
+ const isWindows = process.platform === 'win32';
10
+
11
+ const prisma = new PrismaClient();
12
+
13
+ export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
14
+ const { jobID } = await params;
15
+
16
+ const job = await prisma.job.findUnique({
17
+ where: { id: jobID },
18
+ });
19
+
20
+ if (!job) {
21
+ return NextResponse.json({ error: 'Job not found' }, { status: 404 });
22
+ }
23
+
24
+ // update job status to 'running'
25
+ await prisma.job.update({
26
+ where: { id: jobID },
27
+ data: {
28
+ status: 'running',
29
+ stop: false,
30
+ info: 'Starting job...',
31
+ },
32
+ });
33
+
34
+ // setup the training
35
+ const trainingRoot = await getTrainingFolder();
36
+
37
+ const trainingFolder = path.join(trainingRoot, job.name);
38
+ if (!fs.existsSync(trainingFolder)) {
39
+ fs.mkdirSync(trainingFolder, { recursive: true });
40
+ }
41
+
42
+ // make the config file
43
+ const configPath = path.join(trainingFolder, '.job_config.json');
44
+
45
+ //log to path
46
+ const logPath = path.join(trainingFolder, 'log.txt');
47
+
48
+ try {
49
+ // if the log path exists, move it to a folder called logs and rename it {num}_log.txt, looking for the highest num
50
+ // if the log path does not exist, create it
51
+ if (fs.existsSync(logPath)) {
52
+ const logsFolder = path.join(trainingFolder, 'logs');
53
+ if (!fs.existsSync(logsFolder)) {
54
+ fs.mkdirSync(logsFolder, { recursive: true });
55
+ }
56
+
57
+ let num = 0;
58
+ while (fs.existsSync(path.join(logsFolder, `${num}_log.txt`))) {
59
+ num++;
60
+ }
61
+
62
+ fs.renameSync(logPath, path.join(logsFolder, `${num}_log.txt`));
63
+ }
64
+ } catch (e) {
65
+ console.error('Error moving log file:', e);
66
+ }
67
+
68
+ // update the config dataset path
69
+ const jobConfig = JSON.parse(job.job_config);
70
+ jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db');
71
+
72
+ // write the config file
73
+ fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2));
74
+
75
+ let pythonPath = 'python';
76
+ // use .venv or venv if it exists
77
+ if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
78
+ if (isWindows) {
79
+ pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe');
80
+ } else {
81
+ pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
82
+ }
83
+ } else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
84
+ if (isWindows) {
85
+ pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe');
86
+ } else {
87
+ pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
88
+ }
89
+ }
90
+
91
+ const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');
92
+ if (!fs.existsSync(runFilePath)) {
93
+ return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
94
+ }
95
+
96
+ const additionalEnv: any = {
97
+ AITK_JOB_ID: jobID,
98
+ CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`,
99
+ IS_AI_TOOLKIT_UI: '1'
100
+ };
101
+
102
+ // HF_TOKEN
103
+ const hfToken = await getHFToken();
104
+ if (hfToken && hfToken.trim() !== '') {
105
+ additionalEnv.HF_TOKEN = hfToken;
106
+ }
107
+
108
+ // Add the --log argument to the command
109
+ const args = [runFilePath, configPath, '--log', logPath];
110
+
111
+ try {
112
+ let subprocess;
113
+
114
+ if (isWindows) {
115
+ // For Windows, use 'cmd.exe' to open a new command window
116
+ subprocess = spawn('cmd.exe', ['/c', 'start', 'cmd.exe', '/k', pythonPath, ...args], {
117
+ env: {
118
+ ...process.env,
119
+ ...additionalEnv,
120
+ },
121
+ cwd: TOOLKIT_ROOT,
122
+ windowsHide: false,
123
+ });
124
+ } else {
125
+ // For non-Windows platforms
126
+ subprocess = spawn(pythonPath, args, {
127
+ detached: true,
128
+ stdio: ['ignore', 'pipe', 'pipe'], // Changed from 'ignore' to capture output
129
+ env: {
130
+ ...process.env,
131
+ ...additionalEnv,
132
+ },
133
+ cwd: TOOLKIT_ROOT,
134
+ });
135
+ }
136
+
137
+ // Start monitoring in the background without blocking the response
138
+ const monitorProcess = async () => {
139
+ const startTime = Date.now();
140
+ let errorOutput = '';
141
+ let stdoutput = '';
142
+
143
+ if (subprocess.stderr) {
144
+ subprocess.stderr.on('data', data => {
145
+ errorOutput += data.toString();
146
+ });
147
+ subprocess.stdout.on('data', data => {
148
+ stdoutput += data.toString();
149
+ // truncate to only get the last 500 characters
150
+ if (stdoutput.length > 500) {
151
+ stdoutput = stdoutput.substring(stdoutput.length - 500);
152
+ }
153
+ });
154
+ }
155
+
156
+ subprocess.on('exit', async code => {
157
+ const currentTime = Date.now();
158
+ const duration = (currentTime - startTime) / 1000;
159
+ console.log(`Job ${jobID} exited with code ${code} after ${duration} seconds.`);
160
+ // wait for 5 seconds to give it time to stop itself. If it still has a status of running in the db, update it to stopped
161
+ await new Promise(resolve => setTimeout(resolve, 5000));
162
+ const updatedJob = await prisma.job.findUnique({
163
+ where: { id: jobID },
164
+ });
165
+ if (updatedJob?.status === 'running') {
166
+ let errorString = errorOutput;
167
+ if (errorString.trim() === '') {
168
+ errorString = stdoutput;
169
+ }
170
+ await prisma.job.update({
171
+ where: { id: jobID },
172
+ data: {
173
+ status: 'error',
174
+ info: `Error launching job: ${errorString.substring(0, 500)}`,
175
+ },
176
+ });
177
+ }
178
+ });
179
+
180
+ // Wait 30 seconds before releasing the process
181
+ await new Promise(resolve => setTimeout(resolve, 30000));
182
+ // Detach the process for non-Windows systems
183
+ if (!isWindows && subprocess.unref) {
184
+ subprocess.unref();
185
+ }
186
+ };
187
+
188
+ // Start the monitoring without awaiting it
189
+ monitorProcess().catch(err => {
190
+ console.error(`Error in process monitoring for job ${jobID}:`, err);
191
+ });
192
+
193
+ // Return the response immediately
194
+ return NextResponse.json(job);
195
+ } catch (error: any) {
196
+ // Handle any exceptions during process launch
197
+ console.error('Error launching process:', error);
198
+
199
+ await prisma.job.update({
200
+ where: { id: jobID },
201
+ data: {
202
+ status: 'error',
203
+ info: `Error launching job: ${error?.message || 'Unknown error'}`,
204
+ },
205
+ });
206
+
207
+ return NextResponse.json(
208
+ {
209
+ error: 'Failed to launch job process',
210
+ details: error?.message || 'Unknown error',
211
+ },
212
+ { status: 500 },
213
+ );
214
+ }
215
+ }
src/app/api/jobs/[jobID]/stop/route.ts ADDED
@@ -0,0 +1,23 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';

const prisma = new PrismaClient();

export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
  const { jobID } = await params;

  const job = await prisma.job.findUnique({
    where: { id: jobID },
  });

  // flag the job to stop and set its info message
  await prisma.job.update({
    where: { id: jobID },
    data: {
      stop: true,
      info: 'Stopping job...',
    },
  });

  return NextResponse.json(job);
}
src/app/api/jobs/route.ts ADDED
@@ -0,0 +1,67 @@
import { NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';

const prisma = new PrismaClient();

export async function GET(request: Request) {
  const { searchParams } = new URL(request.url);
  const id = searchParams.get('id');

  try {
    if (id) {
      const job = await prisma.job.findUnique({
        where: { id },
      });
      return NextResponse.json(job);
    }

    const jobs = await prisma.job.findMany({
      orderBy: { created_at: 'desc' },
    });
    return NextResponse.json({ jobs: jobs });
  } catch (error) {
    console.error(error);
    return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 });
  }
}

export async function POST(request: Request) {
  try {
    const body = await request.json();
    const { id, name, job_config, gpu_ids } = body;

    // Ensure gpu_ids is never null/undefined - provide default value
    const safeGpuIds = gpu_ids || '0';

    if (id) {
      // Update existing training
      const training = await prisma.job.update({
        where: { id },
        data: {
          name,
          gpu_ids: safeGpuIds,
          job_config: JSON.stringify(job_config),
        },
      });
      return NextResponse.json(training);
    } else {
      // Create new training
      const training = await prisma.job.create({
        data: {
          name,
          gpu_ids: safeGpuIds,
          job_config: JSON.stringify(job_config),
        },
      });
      return NextResponse.json(training);
    }
  } catch (error: any) {
    if (error.code === 'P2002') {
      // Handle unique constraint violation, 409=Conflict
      return NextResponse.json({ error: 'Job name already exists' }, { status: 409 });
    }
    console.error(error);
    // Handle other errors
    return NextResponse.json({ error: 'Failed to save training data' }, { status: 500 });
  }
}
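As a usage note, creating or updating a job through this route is a plain JSON POST; the field values below are placeholders and `job_config` stands in for whatever object the new-job UI assembles.

// Hypothetical POST to /api/jobs; omit `id` to create, include it to update.
await fetch('/api/jobs', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    name: 'my_first_lora', // must be unique (Prisma P2002 maps to HTTP 409)
    gpu_ids: '0', // defaults to '0' server-side when omitted
    job_config: { config: { name: 'my_first_lora', process: [] } }, // placeholder shape
  }),
});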
src/app/api/settings/route.ts ADDED
@@ -0,0 +1,59 @@
import { NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths';
import { flushCache } from '@/server/settings';

const prisma = new PrismaClient();

export async function GET() {
  try {
    const settings = await prisma.settings.findMany();
    const settingsObject = settings.reduce((acc: any, setting) => {
      acc[setting.key] = setting.value;
      return acc;
    }, {});
    // if TRAINING_FOLDER is not set, use default
    if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
      settingsObject.TRAINING_FOLDER = defaultTrainFolder;
    }
    // if DATASETS_FOLDER is not set, use default
    if (!settingsObject.DATASETS_FOLDER || settingsObject.DATASETS_FOLDER === '') {
      settingsObject.DATASETS_FOLDER = defaultDatasetsFolder;
    }
    return NextResponse.json(settingsObject);
  } catch (error) {
    return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 });
  }
}

export async function POST(request: Request) {
  try {
    const body = await request.json();
    const { HF_TOKEN, TRAINING_FOLDER, DATASETS_FOLDER } = body;

    // Upsert all three settings
    await Promise.all([
      prisma.settings.upsert({
        where: { key: 'HF_TOKEN' },
        update: { value: HF_TOKEN },
        create: { key: 'HF_TOKEN', value: HF_TOKEN },
      }),
      prisma.settings.upsert({
        where: { key: 'TRAINING_FOLDER' },
        update: { value: TRAINING_FOLDER },
        create: { key: 'TRAINING_FOLDER', value: TRAINING_FOLDER },
      }),
      prisma.settings.upsert({
        where: { key: 'DATASETS_FOLDER' },
        update: { value: DATASETS_FOLDER },
        create: { key: 'DATASETS_FOLDER', value: DATASETS_FOLDER },
      }),
    ]);

    flushCache();

    return NextResponse.json({ success: true });
  } catch (error) {
    return NextResponse.json({ error: 'Failed to update settings' }, { status: 500 });
  }
}
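A matching client call might look like the sketch below; the token and folder paths are placeholders, and the endpoint simply upserts the three keys handled above before flushing the settings cache.

// Hypothetical POST to /api/settings (all values are placeholders).
await fetch('/api/settings', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    HF_TOKEN: 'hf_xxxxxxxxxxxxxxxx',
    TRAINING_FOLDER: '/workspace/training',
    DATASETS_FOLDER: '/workspace/datasets',
  }),
});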
src/app/api/zip/route.ts ADDED
@@ -0,0 +1,78 @@
1
+ /* eslint-disable */
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+ import fs from 'fs';
4
+ import fsp from 'fs/promises';
5
+ import path from 'path';
6
+ import archiver from 'archiver';
7
+ import { getTrainingFolder } from '@/server/settings';
8
+
9
+ export const runtime = 'nodejs'; // ensure Node APIs are available
10
+ export const dynamic = 'force-dynamic'; // long-running, non-cached
11
+
12
+ type PostBody = {
13
+ zipTarget: 'samples'; //only samples for now
14
+ jobName: string;
15
+ };
16
+
17
+ async function resolveSafe(p: string) {
18
+ // resolve symlinks + normalize
19
+ return await fsp.realpath(p);
20
+ }
21
+
22
+ export async function POST(request: NextRequest) {
23
+ try {
24
+ const body = (await request.json()) as PostBody;
25
+ if (!body || !body.jobName) {
26
+ return NextResponse.json({ error: 'jobName is required' }, { status: 400 });
27
+ }
28
+
29
+ const trainingRoot = await resolveSafe(await getTrainingFolder());
30
+ const folderPath = await resolveSafe(path.join(trainingRoot, body.jobName, 'samples'));
31
+ const outputPath = path.resolve(trainingRoot, body.jobName, 'samples.zip');
32
+
33
+ // Must be a directory
34
+ let stat: fs.Stats;
35
+ try {
36
+ stat = await fsp.stat(folderPath);
37
+ } catch {
38
+ return new NextResponse('Folder not found', { status: 404 });
39
+ }
40
+ if (!stat.isDirectory()) {
41
+ return new NextResponse('Not a directory', { status: 400 });
42
+ }
43
+
44
+ // delete current one if it exists
45
+ if (fs.existsSync(outputPath)) {
46
+ await fsp.unlink(outputPath);
47
+ }
48
+
49
+ // Create write stream & archive
50
+ await new Promise<void>((resolve, reject) => {
51
+ const output = fs.createWriteStream(outputPath);
52
+ const archive = archiver('zip', { zlib: { level: 9 } });
53
+
54
+ output.on('close', () => resolve());
55
+ output.on('error', reject);
56
+ archive.on('error', reject);
57
+
58
+ archive.pipe(output);
59
+
60
+ // Add the directory contents (place them under the folder's base name in the zip)
61
+ const rootName = path.basename(folderPath);
62
+ archive.directory(folderPath, rootName);
63
+
64
+ archive.finalize().catch(reject);
65
+ });
66
+
67
+ // Return the absolute path so your existing /api/files/[...filePath] can serve it
68
+ // Example download URL (client-side): `/api/files/${encodeURIComponent(resolvedOutPath)}`
69
+ return NextResponse.json({
70
+ ok: true,
71
+ zipPath: outputPath,
72
+ fileName: path.basename(outputPath),
73
+ });
74
+ } catch (err) {
75
+ console.error('Zip error:', err);
76
+ return new NextResponse('Internal Server Error', { status: 500 });
77
+ }
78
+ }
src/app/apple-icon.png ADDED
src/app/dashboard/page.tsx ADDED
@@ -0,0 +1,85 @@
1
+ 'use client';
2
+
3
+ import JobsTable from '@/components/JobsTable';
4
+ import { TopBar, MainContent } from '@/components/layout';
5
+ import Link from 'next/link';
6
+ import { useAuth } from '@/contexts/AuthContext';
7
+ import HFLoginButton from '@/components/HFLoginButton';
8
+
9
+ export default function Dashboard() {
10
+ const { status: authStatus, namespace } = useAuth();
11
+ const isAuthenticated = authStatus === 'authenticated';
12
+
13
+ return (
14
+ <>
15
+ <TopBar>
16
+ <div>
17
+ <h1 className="text-lg">Dashboard</h1>
18
+ </div>
19
+ <div className="flex-1" />
20
+ </TopBar>
21
+ <MainContent>
22
+ <div className="border border-gray-800 rounded-xl bg-gray-900 p-6 flex flex-col gap-4">
23
+ <div>
24
+ <h2 className="text-xl font-semibold text-gray-100">
25
+ {isAuthenticated ? `Welcome back, ${namespace || 'creator'}!` : 'Welcome to Ostris AI Toolkit'}
26
+ </h2>
27
+ <p className="text-sm text-gray-400 mt-2">
28
+ {isAuthenticated
29
+ ? 'You are signed in with Hugging Face and can manage jobs, datasets, and submissions.'
30
+ : 'Authenticate with Hugging Face or add a personal access token to create jobs, upload datasets, and launch training.'}
31
+ </p>
32
+ </div>
33
+ {isAuthenticated ? (
34
+ <div className="flex flex-wrap items-center gap-3 text-sm">
35
+ <Link
36
+ href="/jobs/new"
37
+ className="px-4 py-2 rounded-md bg-blue-600 hover:bg-blue-500 text-white transition-colors"
38
+ >
39
+ Create a Training Job
40
+ </Link>
41
+ <Link
42
+ href="/datasets"
43
+ className="px-4 py-2 rounded-md bg-gray-800 hover:bg-gray-700 text-gray-200 transition-colors"
44
+ >
45
+ Manage Datasets
46
+ </Link>
47
+ <Link
48
+ href="/settings"
49
+ className="px-4 py-2 rounded-md border border-gray-700 text-gray-300 hover:border-gray-600 transition-colors"
50
+ >
51
+ Settings
52
+ </Link>
53
+ </div>
54
+ ) : (
55
+ <div className="flex flex-wrap items-center gap-3 text-sm">
56
+ <HFLoginButton size="md" />
57
+ <Link
58
+ href="/settings"
59
+ className="text-xs text-blue-400 hover:text-blue-300"
60
+ >
61
+ Or manage tokens in Settings
62
+ </Link>
63
+ </div>
64
+ )}
65
+ </div>
66
+
67
+ <div className="w-full mt-6">
68
+ <div className="flex justify-between items-center mb-2">
69
+ <h1 className="text-md">Active Jobs</h1>
70
+ <div className="text-xs text-gray-500">
71
+ <Link href="/jobs">View All</Link>
72
+ </div>
73
+ </div>
74
+ {isAuthenticated ? (
75
+ <JobsTable onlyActive />
76
+ ) : (
77
+ <div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm">
78
+ Sign in with Hugging Face or add an access token in Settings to view and manage jobs.
79
+ </div>
80
+ )}
81
+ </div>
82
+ </MainContent>
83
+ </>
84
+ );
85
+ }
src/app/datasets/[datasetName]/page.tsx ADDED
@@ -0,0 +1,190 @@
1
+ 'use client';
2
+
3
+ import { useEffect, useState, use, useMemo } from 'react';
4
+ import { LuImageOff, LuLoader, LuBan } from 'react-icons/lu';
5
+ import { FaChevronLeft } from 'react-icons/fa';
6
+ import DatasetImageCard from '@/components/DatasetImageCard';
7
+ import { Button } from '@headlessui/react';
8
+ import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal';
9
+ import { TopBar, MainContent } from '@/components/layout';
10
+ import { apiClient } from '@/utils/api';
11
+ import FullscreenDropOverlay from '@/components/FullscreenDropOverlay';
12
+ import { useRouter } from 'next/navigation';
13
+ import { usingBrowserDb } from '@/utils/env';
14
+ import { hasUserDataset } from '@/utils/storage/datasetStorage';
15
+ import { useAuth } from '@/contexts/AuthContext';
16
+ import HFLoginButton from '@/components/HFLoginButton';
17
+ import Link from 'next/link';
18
+
19
+ export default function DatasetPage({ params }: { params: { datasetName: string } }) {
20
+ const [imgList, setImgList] = useState<{ img_path: string }[]>([]);
21
+ const usableParams = use(params as any) as { datasetName: string };
22
+ const datasetName = usableParams.datasetName;
23
+ const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
24
+ const router = useRouter();
25
+ const { status: authStatus } = useAuth();
26
+ const isAuthenticated = authStatus === 'authenticated';
27
+ const hasDatasetEntry = !usingBrowserDb || hasUserDataset(datasetName);
28
+ const allowAccess = hasDatasetEntry && isAuthenticated;
29
+
30
+ const refreshImageList = (dbName: string) => {
31
+ setStatus('loading');
32
+ console.log('Fetching images for dataset:', dbName);
33
+ apiClient
34
+ .post('/api/datasets/listImages', { datasetName: dbName })
35
+ .then((res: any) => {
36
+ const data = res.data;
37
+ console.log('Images:', data.images);
38
+ // sort
39
+ data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path));
40
+ setImgList(data.images);
41
+ setStatus('success');
42
+ })
43
+ .catch(error => {
44
+ console.error('Error fetching images:', error);
45
+ setStatus('error');
46
+ });
47
+ };
48
+ useEffect(() => {
49
+ if (!datasetName) {
50
+ return;
51
+ }
52
+
53
+ if (!isAuthenticated) {
54
+ return;
55
+ }
56
+
57
+ if (!hasDatasetEntry) {
58
+ setImgList([]);
59
+ setStatus('error');
60
+ router.replace('/datasets');
61
+ return;
62
+ }
63
+
64
+ refreshImageList(datasetName);
65
+ }, [datasetName, hasDatasetEntry, isAuthenticated, router]);
66
+
67
+ if (!allowAccess) {
68
+ return (
69
+ <>
70
+ <TopBar>
71
+ <div>
72
+ <Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
73
+ <FaChevronLeft />
74
+ </Button>
75
+ </div>
76
+ <div>
77
+ <h1 className="text-lg">Dataset: {datasetName}</h1>
78
+ </div>
79
+ <div className="flex-1"></div>
80
+ </TopBar>
81
+ <MainContent>
82
+ <div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
83
+ <p>You need to sign in with Hugging Face or provide a valid token to view this dataset.</p>
84
+ <div className="flex items-center gap-3">
85
+ <HFLoginButton size="sm" />
86
+ <Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
87
+ Manage authentication in Settings
88
+ </Link>
89
+ </div>
90
+ </div>
91
+ </MainContent>
92
+ </>
93
+ );
94
+ }
95
+
96
+ const PageInfoContent = useMemo(() => {
97
+ let icon = null;
98
+ let text = '';
99
+ let subtitle = '';
100
+ let showIt = false;
101
+ let bgColor = '';
102
+ let textColor = '';
103
+ let iconColor = '';
104
+
105
+ if (status == 'loading') {
106
+ icon = <LuLoader className="animate-spin w-8 h-8" />;
107
+ text = 'Loading Images';
108
+ subtitle = 'Please wait while we fetch your dataset images...';
109
+ showIt = true;
110
+ bgColor = 'bg-gray-50 dark:bg-gray-800/50';
111
+ textColor = 'text-gray-900 dark:text-gray-100';
112
+ iconColor = 'text-gray-500 dark:text-gray-400';
113
+ }
114
+ if (status == 'error') {
115
+ icon = <LuBan className="w-8 h-8" />;
116
+ text = 'Error Loading Images';
117
+ subtitle = 'There was a problem fetching the images. Please try refreshing the page.';
118
+ showIt = true;
119
+ bgColor = 'bg-red-50 dark:bg-red-950/20';
120
+ textColor = 'text-red-900 dark:text-red-100';
121
+ iconColor = 'text-red-600 dark:text-red-400';
122
+ }
123
+ if (status == 'success' && imgList.length === 0) {
124
+ icon = <LuImageOff className="w-8 h-8" />;
125
+ text = 'No Images Found';
126
+ subtitle = 'This dataset is empty. Click "Add Images" to get started.';
127
+ showIt = true;
128
+ bgColor = 'bg-gray-50 dark:bg-gray-800/50';
129
+ textColor = 'text-gray-900 dark:text-gray-100';
130
+ iconColor = 'text-gray-500 dark:text-gray-400';
131
+ }
132
+
133
+ if (!showIt) return null;
134
+
135
+ return (
136
+ <div
137
+ className={`mt-10 flex flex-col items-center justify-center py-16 px-8 rounded-xl border-2 border-gray-700 border-dashed ${bgColor} ${textColor} mx-auto max-w-md text-center`}
138
+ >
139
+ <div className={`${iconColor} mb-4`}>{icon}</div>
140
+ <h3 className="text-lg font-semibold mb-2">{text}</h3>
141
+ <p className="text-sm opacity-75 leading-relaxed">{subtitle}</p>
142
+ </div>
143
+ );
144
+ }, [status, imgList.length]);
145
+
146
+ return (
147
+ <>
148
+ {/* Fixed top bar */}
149
+ <TopBar>
150
+ <div>
151
+ <Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
152
+ <FaChevronLeft />
153
+ </Button>
154
+ </div>
155
+ <div>
156
+ <h1 className="text-lg">Dataset: {datasetName}</h1>
157
+ </div>
158
+ <div className="flex-1"></div>
159
+ <div>
160
+ <Button
161
+ className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md"
162
+ onClick={() => openImagesModal(datasetName, () => refreshImageList(datasetName))}
163
+ >
164
+ Add Images
165
+ </Button>
166
+ </div>
167
+ </TopBar>
168
+ <MainContent>
169
+ {PageInfoContent}
170
+ {status === 'success' && imgList.length > 0 && (
171
+ <div className="grid grid-cols-1 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 gap-4">
172
+ {imgList.map(img => (
173
+ <DatasetImageCard
174
+ key={img.img_path}
175
+ alt="image"
176
+ imageUrl={img.img_path}
177
+ onDelete={() => refreshImageList(datasetName)}
178
+ />
179
+ ))}
180
+ </div>
181
+ )}
182
+ </MainContent>
183
+ <AddImagesModal />
184
+ <FullscreenDropOverlay
185
+ datasetName={datasetName}
186
+ onComplete={() => refreshImageList(datasetName)}
187
+ />
188
+ </>
189
+ );
190
+ }
src/app/datasets/page.tsx ADDED
@@ -0,0 +1,217 @@
1
+ 'use client';
2
+
3
+ import { useState } from 'react';
4
+ import { Modal } from '@/components/Modal';
5
+ import Link from 'next/link';
6
+ import { TextInput } from '@/components/formInputs';
7
+ import useDatasetList from '@/hooks/useDatasetList';
8
+ import { Button } from '@headlessui/react';
9
+ import { FaRegTrashAlt } from 'react-icons/fa';
10
+ import { openConfirm } from '@/components/ConfirmModal';
11
+ import { TopBar, MainContent } from '@/components/layout';
12
+ import UniversalTable, { TableColumn } from '@/components/UniversalTable';
13
+ import { apiClient } from '@/utils/api';
14
+ import { useRouter } from 'next/navigation';
15
+ import { usingBrowserDb } from '@/utils/env';
16
+ import { addUserDataset, removeUserDataset } from '@/utils/storage/datasetStorage';
17
+ import { useAuth } from '@/contexts/AuthContext';
18
+ import HFLoginButton from '@/components/HFLoginButton';
19
+
20
+ export default function Datasets() {
21
+ const router = useRouter();
22
+ const { datasets, status, refreshDatasets } = useDatasetList();
23
+ const [newDatasetName, setNewDatasetName] = useState('');
24
+ const [isNewDatasetModalOpen, setIsNewDatasetModalOpen] = useState(false);
25
+ const { status: authStatus } = useAuth();
26
+ const isAuthenticated = authStatus === 'authenticated';
27
+
28
+ // Transform datasets array into rows with objects
29
+ const tableRows = datasets.map(dataset => ({
30
+ name: dataset,
31
+ actions: dataset, // Pass full dataset name for actions
32
+ }));
33
+
34
+ const columns: TableColumn[] = [
35
+ {
36
+ title: 'Dataset Name',
37
+ key: 'name',
38
+ render: row => (
39
+ <Link href={`/datasets/${row.name}`} className="text-gray-200 hover:text-gray-100">
40
+ {row.name}
41
+ </Link>
42
+ ),
43
+ },
44
+ {
45
+ title: 'Actions',
46
+ key: 'actions',
47
+ className: 'w-20 text-right',
48
+ render: row => (
49
+ <button
50
+ className="text-gray-200 hover:bg-red-600 p-2 rounded-full transition-colors"
51
+ onClick={() => handleDeleteDataset(row.name)}
52
+ >
53
+ <FaRegTrashAlt />
54
+ </button>
55
+ ),
56
+ },
57
+ ];
58
+
59
+ const handleDeleteDataset = (datasetName: string) => {
60
+ openConfirm({
61
+ title: 'Delete Dataset',
62
+ message: `Are you sure you want to delete the dataset "${datasetName}"? This action cannot be undone.`,
63
+ type: 'warning',
64
+ confirmText: 'Delete',
65
+ onConfirm: () => {
66
+ apiClient
67
+ .post('/api/datasets/delete', { name: datasetName })
68
+ .then(() => {
69
+ console.log('Dataset deleted:', datasetName);
70
+ if (usingBrowserDb) {
71
+ removeUserDataset(datasetName);
72
+ }
73
+ refreshDatasets();
74
+ })
75
+ .catch(error => {
76
+ console.error('Error deleting dataset:', error);
77
+ });
78
+ },
79
+ });
80
+ };
81
+
82
+ const handleCreateDataset = async (e: React.FormEvent) => {
83
+ e.preventDefault();
84
+ if (!isAuthenticated) {
85
+ return;
86
+ }
87
+ try {
88
+ const data = await apiClient.post('/api/datasets/create', { name: newDatasetName }).then(res => res.data);
89
+ console.log('New dataset created:', data);
90
+ if (usingBrowserDb && data?.name) {
91
+ addUserDataset(data.name, data?.path || '');
92
+ }
93
+ refreshDatasets();
94
+ setNewDatasetName('');
95
+ setIsNewDatasetModalOpen(false);
96
+ } catch (error) {
97
+ console.error('Error creating new dataset:', error);
98
+ }
99
+ };
100
+
101
+ const openNewDatasetModal = () => {
102
+ if (!isAuthenticated) {
103
+ return;
104
+ }
105
+ openConfirm({
106
+ title: 'New Dataset',
107
+ message: 'Enter the name of the new dataset:',
108
+ type: 'info',
109
+ confirmText: 'Create',
110
+ inputTitle: 'Dataset Name',
111
+ onConfirm: async (name?: string) => {
112
+ if (!name) {
113
+ console.error('Dataset name is required.');
114
+ return;
115
+ }
116
+ if (!isAuthenticated) {
117
+ return;
118
+ }
119
+ try {
120
+ const data = await apiClient.post('/api/datasets/create', { name }).then(res => res.data);
121
+ console.log('New dataset created:', data);
122
+ if (usingBrowserDb && data?.name) {
123
+ addUserDataset(data.name, data?.path || '');
124
+ }
125
+ if (data.name) {
126
+ router.push(`/datasets/${data.name}`);
127
+ } else {
128
+ refreshDatasets();
129
+ }
130
+ } catch (error) {
131
+ console.error('Error creating new dataset:', error);
132
+ }
133
+ },
134
+ });
135
+ };
136
+
137
+ return (
138
+ <>
139
+ <TopBar>
140
+ <div>
141
+ <h1 className="text-2xl font-semibold text-gray-100">Datasets</h1>
142
+ </div>
143
+ <div className="flex-1"></div>
144
+ <div>
145
+ {isAuthenticated ? (
146
+ <Button
147
+ className="text-gray-200 bg-slate-600 px-4 py-2 rounded-md hover:bg-slate-500 transition-colors"
148
+ onClick={() => openNewDatasetModal()}
149
+ >
150
+ New Dataset
151
+ </Button>
152
+ ) : (
153
+ <span className="text-gray-600 bg-gray-900 px-3 py-1 rounded-md border border-gray-800">
154
+ Sign in to add datasets
155
+ </span>
156
+ )}
157
+ </div>
158
+ </TopBar>
159
+
160
+ <MainContent>
161
+ {isAuthenticated ? (
162
+ <UniversalTable
163
+ columns={columns}
164
+ rows={tableRows}
165
+ isLoading={status === 'loading'}
166
+ onRefresh={refreshDatasets}
167
+ />
168
+ ) : (
169
+ <div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
170
+ <p>Sign in with Hugging Face or add an access token to manage datasets.</p>
171
+ <div className="flex items-center gap-3">
172
+ <HFLoginButton size="sm" />
173
+ <Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
174
+ Manage authentication in Settings
175
+ </Link>
176
+ </div>
177
+ </div>
178
+ )}
179
+ </MainContent>
180
+
181
+ <Modal
182
+ isOpen={isNewDatasetModalOpen}
183
+ onClose={() => setIsNewDatasetModalOpen(false)}
184
+ title="New Dataset"
185
+ size="md"
186
+ >
187
+ <div className="space-y-4 text-gray-200">
188
+ <form onSubmit={handleCreateDataset}>
189
+ <div className="text-sm text-gray-400">
190
+ This will create a new folder with the name below in your dataset folder.
191
+ </div>
192
+ <div className="mt-4">
193
+ <TextInput label="Dataset Name" value={newDatasetName} onChange={value => setNewDatasetName(value)} />
194
+ </div>
195
+
196
+ <div className="mt-6 flex justify-end space-x-3">
197
+ <button
198
+ type="button"
199
+ className="rounded-md bg-gray-700 px-4 py-2 text-gray-200 hover:bg-gray-600 focus:outline-none focus:ring-2 focus:ring-gray-500"
200
+ onClick={() => setIsNewDatasetModalOpen(false)}
201
+ >
202
+ Cancel
203
+ </button>
204
+ <button
205
+ type="submit"
206
+ className="rounded-md bg-blue-600 px-4 py-2 text-white hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 disabled:opacity-50 disabled:cursor-not-allowed"
207
+ disabled={!isAuthenticated}
208
+ >
209
+ Confirm
210
+ </button>
211
+ </div>
212
+ </form>
213
+ </div>
214
+ </Modal>
215
+ </>
216
+ );
217
+ }
src/app/favicon.ico ADDED
src/app/globals.css ADDED
@@ -0,0 +1,72 @@
1
+ @tailwind base;
2
+ @tailwind components;
3
+ @tailwind utilities;
4
+
5
+ :root {
6
+ --background: #ffffff;
7
+ --foreground: #171717;
8
+ }
9
+
10
+ @media (prefers-color-scheme: dark) {
11
+ :root {
12
+ --background: #0a0a0a;
13
+ --foreground: #ededed;
14
+ }
15
+ }
16
+
17
+ body {
18
+ color: var(--foreground);
19
+ background: var(--background);
20
+ font-family: Arial, Helvetica, sans-serif;
21
+ }
22
+
23
+ @layer components {
24
+ /* control */
25
+ .aitk-react-select-container .aitk-react-select__control {
26
+ @apply flex w-full h-8 min-h-0 px-0 text-sm bg-gray-800 border border-gray-700 rounded-sm hover:border-gray-600 items-center;
27
+ }
28
+
29
+ /* selected label */
30
+ .aitk-react-select-container .aitk-react-select__single-value {
31
+ @apply flex-1 min-w-0 truncate text-sm text-neutral-200;
32
+ }
33
+
34
+ /* invisible input (keeps focus & typing, never wraps) */
35
+ .aitk-react-select-container .aitk-react-select__input-container {
36
+ @apply text-neutral-200;
37
+ }
38
+
39
+ /* focus */
40
+ .aitk-react-select-container .aitk-react-select__control--is-focused {
41
+ @apply ring-2 ring-gray-600 border-transparent hover:border-transparent shadow-none;
42
+ }
43
+
44
+ /* menu */
45
+ .aitk-react-select-container .aitk-react-select__menu {
46
+ @apply bg-gray-800 border border-gray-700;
47
+ }
48
+
49
+ /* options */
50
+ .aitk-react-select-container .aitk-react-select__option {
51
+ @apply text-sm text-neutral-200 bg-gray-800 hover:bg-gray-700;
52
+ }
53
+
54
+ /* indicator separator */
55
+ .aitk-react-select-container .aitk-react-select__indicator-separator {
56
+ @apply bg-gray-600;
57
+ }
58
+
59
+ /* indicators */
60
+ .aitk-react-select-container .aitk-react-select__indicators,
61
+ .aitk-react-select-container .aitk-react-select__indicator {
62
+ @apply py-0 flex items-center;
63
+ }
64
+
65
+ /* placeholder */
66
+ .aitk-react-select-container .aitk-react-select__placeholder {
67
+ @apply text-sm text-neutral-200;
68
+ }
69
+ }
70
+
71
+
72
+
src/app/icon.png ADDED
src/app/icon.svg ADDED
src/app/jobs/[jobID]/page.tsx ADDED
@@ -0,0 +1,147 @@
1
+ 'use client';
2
+
3
+ import { useState, use } from 'react';
4
+ import { FaChevronLeft } from 'react-icons/fa';
5
+ import { Button } from '@headlessui/react';
6
+ import { TopBar, MainContent } from '@/components/layout';
7
+ import useJob from '@/hooks/useJob';
8
+ import SampleImages, {SampleImagesMenu} from '@/components/SampleImages';
9
+ import JobOverview from '@/components/JobOverview';
10
+ import { redirect } from 'next/navigation';
11
+ import { useAuth } from '@/contexts/AuthContext';
12
+ import HFLoginButton from '@/components/HFLoginButton';
13
+ import Link from 'next/link';
14
+ import JobActionBar from '@/components/JobActionBar';
15
+ import JobConfigViewer from '@/components/JobConfigViewer';
16
+ import { JobRecord } from '@/types';
17
+
18
+ type PageKey = 'overview' | 'samples' | 'config';
19
+
20
+ interface Page {
21
+ name: string;
22
+ value: PageKey;
23
+ component: React.ComponentType<{ job: JobRecord }>;
24
+ menuItem?: React.ComponentType<{ job?: JobRecord | null }> | null;
25
+ mainCss?: string;
26
+ }
27
+
28
+ const pages: Page[] = [
29
+ {
30
+ name: 'Overview',
31
+ value: 'overview',
32
+ component: JobOverview,
33
+ mainCss: 'pt-24',
34
+ },
35
+ {
36
+ name: 'Samples',
37
+ value: 'samples',
38
+ component: SampleImages,
39
+ menuItem: SampleImagesMenu,
40
+ mainCss: 'pt-24',
41
+ },
42
+ {
43
+ name: 'Config File',
44
+ value: 'config',
45
+ component: JobConfigViewer,
46
+ mainCss: 'pt-[80px] px-0 pb-0',
47
+ },
48
+ ];
49
+
50
+ export default function JobPage({ params }: { params: { jobID: string } }) {
51
+ const usableParams = use(params as any) as { jobID: string };
52
+ const jobID = usableParams.jobID;
53
+ const { job, status, refreshJob } = useJob(jobID, 5000);
54
+ const [pageKey, setPageKey] = useState<PageKey>('overview');
55
+ const { status: authStatus } = useAuth();
56
+ const isAuthenticated = authStatus === 'authenticated';
57
+
58
+ const page = pages.find(p => p.value === pageKey);
59
+
60
+ if (!isAuthenticated) {
61
+ return (
62
+ <>
63
+ <TopBar>
64
+ <div>
65
+ <Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => redirect('/jobs')}>
66
+ <FaChevronLeft />
67
+ </Button>
68
+ </div>
69
+ <div>
70
+ <h1 className="text-lg">Job Details</h1>
71
+ </div>
72
+ <div className="flex-1"></div>
73
+ </TopBar>
74
+ <MainContent>
75
+ <div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
76
+ <p>Sign in with Hugging Face or add an access token to view job details.</p>
77
+ <div className="flex items-center gap-3">
78
+ <HFLoginButton size="sm" />
79
+ <Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
80
+ Manage authentication in Settings
81
+ </Link>
82
+ </div>
83
+ </div>
84
+ </MainContent>
85
+ </>
86
+ );
87
+ }
88
+
89
+ return (
90
+ <>
91
+ {/* Fixed top bar */}
92
+ <TopBar>
93
+ <div>
94
+ <Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => redirect('/jobs')}>
95
+ <FaChevronLeft />
96
+ </Button>
97
+ </div>
98
+ <div>
99
+ <h1 className="text-lg">Job: {job?.name}</h1>
100
+ </div>
101
+ <div className="flex-1"></div>
102
+ {job && (
103
+ <JobActionBar
104
+ job={job}
105
+ onRefresh={refreshJob}
106
+ hideView
107
+ afterDelete={() => {
108
+ redirect('/jobs');
109
+ }}
110
+ />
111
+ )}
112
+ </TopBar>
113
+ <MainContent className={pages.find(page => page.value === pageKey)?.mainCss}>
114
+ {status === 'loading' && job == null && <p>Loading...</p>}
115
+ {status === 'error' && job == null && <p>Error fetching job</p>}
116
+ {job && (
117
+ <>
118
+ {pages.map(page => {
119
+ const Component = page.component;
120
+ return page.value === pageKey ? <Component key={page.value} job={job} /> : null;
121
+ })}
122
+ </>
123
+ )}
124
+ </MainContent>
125
+ <div className="bg-gray-800 absolute top-12 left-0 w-full h-8 flex items-center px-2 text-sm">
126
+ {pages.map(page => (
127
+ <Button
128
+ key={page.value}
129
+ onClick={() => setPageKey(page.value)}
130
+ className={`px-4 py-1 h-8 ${page.value === pageKey ? 'bg-gray-300 dark:bg-gray-700' : ''}`}
131
+ >
132
+ {page.name}
133
+ </Button>
134
+ ))}
135
+ {
136
+ page?.menuItem && (
137
+ <>
138
+ <div className='flex-grow'>
139
+ </div>
140
+ <page.menuItem job={job} />
141
+ </>
142
+ )
143
+ }
144
+ </div>
145
+ </>
146
+ );
147
+ }
src/app/jobs/new/AdvancedJob.tsx ADDED
@@ -0,0 +1,146 @@
1
+ 'use client';
2
+ import { useEffect, useState, useRef } from 'react';
3
+ import { JobConfig } from '@/types';
4
+ import YAML from 'yaml';
5
+ import Editor, { OnMount } from '@monaco-editor/react';
6
+ import type { editor } from 'monaco-editor';
7
+ import { SettingsData } from '@/types';
8
+ import { migrateJobConfig } from './jobConfig';
9
+
10
+ type Props = {
11
+ jobConfig: JobConfig;
12
+ setJobConfig: (value: any, key?: string) => void;
13
+ status: 'idle' | 'saving' | 'success' | 'error';
14
+ handleSubmit: (event: React.FormEvent<HTMLFormElement>) => void;
15
+ runId: string | null;
16
+ gpuIDs: string | null;
17
+ setGpuIDs: (value: string | null) => void;
18
+ gpuList: any;
19
+ datasetOptions: any;
20
+ settings: SettingsData;
21
+ };
22
+
23
+ const isDev = process.env.NODE_ENV === 'development';
24
+
25
+ const yamlConfig: YAML.DocumentOptions &
26
+ YAML.SchemaOptions &
27
+ YAML.ParseOptions &
28
+ YAML.CreateNodeOptions &
29
+ YAML.ToStringOptions = {
30
+ indent: 2,
31
+ lineWidth: 999999999999,
32
+ defaultStringType: 'QUOTE_DOUBLE',
33
+ defaultKeyType: 'PLAIN',
34
+ directives: true,
35
+ };
36
+
37
+ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props) {
38
+ const [editorValue, setEditorValue] = useState<string>('');
39
+ const lastJobConfigUpdateStringRef = useRef('');
40
+ const editorRef = useRef<editor.IStandaloneCodeEditor | null>(null);
41
+
42
+ // Track if the editor has been mounted
43
+ const isEditorMounted = useRef(false);
44
+
45
+ // Handler for editor mounting
46
+ const handleEditorDidMount: OnMount = editor => {
47
+ editorRef.current = editor;
48
+ isEditorMounted.current = true;
49
+
50
+ // Initial content setup
51
+ try {
52
+ const yamlContent = YAML.stringify(jobConfig, yamlConfig);
53
+ setEditorValue(yamlContent);
54
+ lastJobConfigUpdateStringRef.current = JSON.stringify(jobConfig);
55
+ } catch (e) {
56
+ console.warn(e);
57
+ }
58
+ };
59
+
60
+ useEffect(() => {
61
+ const lastUpdate = lastJobConfigUpdateStringRef.current;
62
+ const currentUpdate = JSON.stringify(jobConfig);
63
+
64
+ // Skip if no changes or editor not yet mounted
65
+ if (lastUpdate === currentUpdate || !isEditorMounted.current) {
66
+ return;
67
+ }
68
+
69
+ try {
70
+ // Preserve cursor position and selection
71
+ const editor = editorRef.current;
72
+ if (editor) {
73
+ // Save current editor state
74
+ const position = editor.getPosition();
75
+ const selection = editor.getSelection();
76
+ const scrollTop = editor.getScrollTop();
77
+
78
+ // Update content
79
+ const yamlContent = YAML.stringify(jobConfig, yamlConfig);
80
+
81
+ // Only update if the content is actually different
82
+ if (yamlContent !== editor.getValue()) {
83
+ // Set value directly on the editor model instead of using React state
84
+ editor.getModel()?.setValue(yamlContent);
85
+
86
+ // Restore cursor position and selection
87
+ if (position) editor.setPosition(position);
88
+ if (selection) editor.setSelection(selection);
89
+ editor.setScrollTop(scrollTop);
90
+ }
91
+
92
+ lastJobConfigUpdateStringRef.current = currentUpdate;
93
+ }
94
+ } catch (e) {
95
+ console.warn(e);
96
+ }
97
+ }, [jobConfig]);
98
+
99
+ const handleChange = (value: string | undefined) => {
100
+ if (value === undefined) return;
101
+
102
+ try {
103
+ const parsed = YAML.parse(value);
104
+ // Don't update jobConfig if the change came from the editor itself
105
+ // to avoid a circular update loop
106
+ if (JSON.stringify(parsed) !== lastJobConfigUpdateStringRef.current) {
107
+ lastJobConfigUpdateStringRef.current = JSON.stringify(parsed);
108
+
109
+ // We have to ensure certain things are always set
110
+ try {
111
+ parsed.config.process[0].type = 'ui_trainer';
112
+ parsed.config.process[0].sqlite_db_path = './aitk_db.db';
113
+ parsed.config.process[0].training_folder = settings.TRAINING_FOLDER;
114
+ parsed.config.process[0].device = 'cuda';
115
+ parsed.config.process[0].performance_log_every = 10;
116
+ } catch (e) {
117
+ console.warn(e);
118
+ }
119
+ migrateJobConfig(parsed);
120
+ setJobConfig(parsed);
121
+ }
122
+ } catch (e) {
123
+ // Don't update on parsing errors
124
+ console.warn(e);
125
+ }
126
+ };
127
+
128
+ return (
129
+ <>
130
+ <Editor
131
+ height="100%"
132
+ width="100%"
133
+ defaultLanguage="yaml"
134
+ value={editorValue}
135
+ theme="vs-dark"
136
+ onChange={handleChange}
137
+ onMount={handleEditorDidMount}
138
+ options={{
139
+ minimap: { enabled: true },
140
+ scrollBeyondLastLine: false,
141
+ automaticLayout: true,
142
+ }}
143
+ />
144
+ </>
145
+ );
146
+ }
src/app/jobs/new/SimpleJob.tsx ADDED
@@ -0,0 +1,973 @@
1
+ 'use client';
2
+ import { useMemo, useState } from 'react';
3
+ import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options';
4
+ import { defaultDatasetConfig } from './jobConfig';
5
+ import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
6
+ import { objectCopy } from '@/utils/basic';
7
+ import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
8
+ import Card from '@/components/Card';
9
+ import { X } from 'lucide-react';
10
+ import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
11
+ import { FlipHorizontal2, FlipVertical2 } from 'lucide-react';
12
+ import HFJobsWorkflow from '@/components/HFJobsWorkflow';
13
+
14
+ type Props = {
15
+ jobConfig: JobConfig;
16
+ setJobConfig: (value: any, key: string) => void;
17
+ status: 'idle' | 'saving' | 'success' | 'error';
18
+ handleSubmit: (event: React.FormEvent<HTMLFormElement>) => void;
19
+ runId: string | null;
20
+ gpuIDs: string | null;
21
+ setGpuIDs: (value: string | null) => void;
22
+ gpuList: any;
23
+ datasetOptions: any;
24
+ trainingBackend?: 'local' | 'hf-jobs';
25
+ setTrainingBackend?: (backend: 'local' | 'hf-jobs') => void;
26
+ hfJobSubmitted?: boolean;
27
+ onHFJobComplete?: (jobId: string, localJobId?: string) => void;
28
+ forceHFBackend?: boolean;
29
+ };
30
+
31
+ const isDev = process.env.NODE_ENV === 'development';
32
+
33
+ export default function SimpleJob({
34
+ jobConfig,
35
+ setJobConfig,
36
+ handleSubmit,
37
+ status,
38
+ runId,
39
+ gpuIDs,
40
+ setGpuIDs,
41
+ gpuList,
42
+ datasetOptions,
43
+ trainingBackend: parentTrainingBackend,
44
+ setTrainingBackend: parentSetTrainingBackend,
45
+ hfJobSubmitted,
46
+ onHFJobComplete,
47
+ forceHFBackend = false,
48
+ }: Props) {
49
+ const [localTrainingBackend, setLocalTrainingBackend] = useState(forceHFBackend ? 'hf-jobs' : 'local');
50
+ const trainingBackend = parentTrainingBackend || localTrainingBackend;
51
+ const setTrainingBackend = forceHFBackend
52
+ ? (_: 'local' | 'hf-jobs') => undefined
53
+ : parentSetTrainingBackend || setLocalTrainingBackend;
54
+ const backendOptions = forceHFBackend
55
+ ? [{ value: 'hf-jobs', label: 'HF Jobs (Cloud)' }]
56
+ : [
57
+ { value: 'local', label: 'Local GPU' },
58
+ { value: 'hf-jobs', label: 'HF Jobs (Cloud)' },
59
+ ];
60
+ const modelArch = useMemo(() => {
61
+ return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
62
+ }, [jobConfig.config.process[0].model.arch]);
63
+
64
+ const isVideoModel = !!(modelArch?.group === 'video');
65
+
66
+ const numTopCards = useMemo(() => {
67
+ let count = 4; // job settings, model config, target config, save config
68
+ if (modelArch?.additionalSections?.includes('model.multistage')) {
69
+ count += 1; // add multistage card
70
+ }
71
+ if (!modelArch?.disableSections?.includes('model.quantize')) {
72
+ count += 1; // add quantization card
73
+ }
74
+ return count;
75
+
76
+ }, [modelArch]);
77
+
78
+ let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
79
+
80
+ if (numTopCards == 5) {
81
+ topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
82
+ }
83
+ if (numTopCards == 6) {
84
+ topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
85
+ }
86
+
87
+ const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
88
+ const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0;
89
+ if (!hasARA) {
90
+ return quantizationOptions;
91
+ }
92
+ let newQuantizationOptions = [
93
+ {
94
+ label: 'Standard',
95
+ options: [quantizationOptions[0], quantizationOptions[1]],
96
+ },
97
+ ];
98
+
99
+ // add ARAs if they exist for the model
100
+ let ARAs: SelectOption[] = [];
101
+ if (modelArch.accuracyRecoveryAdapters) {
102
+ for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) {
103
+ ARAs.push({ value, label });
104
+ }
105
+ }
106
+ if (ARAs.length > 0) {
107
+ newQuantizationOptions.push({
108
+ label: 'Accuracy Recovery Adapters',
109
+ options: ARAs,
110
+ });
111
+ }
112
+
113
+ let additionalQuantizationOptions: SelectOption[] = [];
114
+ // add the quantization options if they are not already included
115
+ for (let i = 2; i < quantizationOptions.length; i++) {
116
+ const option = quantizationOptions[i];
117
+ additionalQuantizationOptions.push(option);
118
+ }
119
+ if (additionalQuantizationOptions.length > 0) {
120
+ newQuantizationOptions.push({
121
+ label: 'Additional Quantization Options',
122
+ options: additionalQuantizationOptions,
123
+ });
124
+ }
125
+ return newQuantizationOptions;
126
+ }, [modelArch]);
127
+
128
+ return (
129
+ <>
130
+ <form onSubmit={handleSubmit} className="space-y-8">
131
+ <div className={topBarClass}>
132
+ <Card title="Job">
133
+ <TextInput
134
+ label="Training Name"
135
+ value={jobConfig.config.name}
136
+ docKey="config.name"
137
+ onChange={value => setJobConfig(value, 'config.name')}
138
+ placeholder="Enter training name"
139
+ disabled={runId !== null}
140
+ required
141
+ />
142
+ <SelectInput
143
+ label="Training Backend"
144
+ value={trainingBackend}
145
+ onChange={(value) => {
146
+ setTrainingBackend(value);
147
+ }}
148
+ options={backendOptions}
149
+ disabled={forceHFBackend}
150
+ />
151
+ {trainingBackend === 'local' && (
152
+ <SelectInput
153
+ label="GPU ID"
154
+ value={`${gpuIDs}`}
155
+ docKey="gpuids"
156
+ onChange={value => setGpuIDs(value)}
157
+ options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
158
+ />
159
+ )}
160
+ <TextInput
161
+ label="Trigger Word"
162
+ value={jobConfig.config.process[0].trigger_word || ''}
163
+ docKey="config.process[0].trigger_word"
164
+ onChange={(value: string | null) => {
165
+ if (value?.trim() === '') {
166
+ value = null;
167
+ }
168
+ setJobConfig(value, 'config.process[0].trigger_word');
169
+ }}
170
+ placeholder=""
171
+ required
172
+ />
173
+ {trainingBackend === 'hf-jobs' && (
174
+ <div className={`mt-4 p-3 rounded ${
175
+ hfJobSubmitted
176
+ ? 'bg-green-900/20 border border-green-700'
177
+ : 'bg-yellow-900/20 border border-yellow-700'
178
+ }`}>
179
+ <p className={`text-sm ${
180
+ hfJobSubmitted ? 'text-green-400' : 'text-yellow-400'
181
+ }`}>
182
+ {hfJobSubmitted
183
+ ? '✓ HF Job already submitted! You can modify settings and resubmit if needed.'
184
+ : '⏳ HF Job ready for submission. Submit to the cloud below.'
185
+ }
186
+ </p>
187
+ </div>
188
+ )}
189
+ </Card>
190
+
191
+ {/* Model Configuration Section */}
192
+ <Card title="Model">
193
+ <SelectInput
194
+ label="Model Architecture"
195
+ value={jobConfig.config.process[0].model.arch}
196
+ onChange={value => {
197
+ const currentArch = modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch);
198
+ if (!currentArch || currentArch.name === value) {
199
+ return;
200
+ }
201
+ // update the defaults when a model is selected
202
+ const newArch = modelArchs.find(model => model.name === value);
203
+
204
+ // update vram setting
205
+ if (!newArch?.additionalSections?.includes('model.low_vram')) {
206
+ setJobConfig(false, 'config.process[0].model.low_vram');
207
+ }
208
+
209
+ // revert defaults from previous model
210
+ for (const key in currentArch.defaults) {
211
+ setJobConfig(currentArch.defaults[key][1], key);
212
+ }
213
+
214
+ if (newArch?.defaults) {
215
+ for (const key in newArch.defaults) {
216
+ setJobConfig(newArch.defaults[key][0], key);
217
+ }
218
+ }
219
+ // set new model
220
+ setJobConfig(value, 'config.process[0].model.arch');
221
+
222
+ // update datasets
223
+ const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
224
+ const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false;
225
+ const controls = newArch?.controls ?? [];
226
+ const datasets = jobConfig.config.process[0].datasets.map(dataset => {
227
+ const newDataset = objectCopy(dataset);
228
+ newDataset.controls = controls;
229
+ if (!hasControlPath) {
230
+ newDataset.control_path = null; // reset control path if not applicable
231
+ }
232
+ if (!hasNumFrames) {
233
+ newDataset.num_frames = 1; // reset num_frames if not applicable
234
+ }
235
+ return newDataset;
236
+ });
237
+ setJobConfig(datasets, 'config.process[0].datasets');
238
+
239
+ // update samples
240
+ const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false;
241
+ const samples = jobConfig.config.process[0].sample.samples.map(sample => {
242
+ const newSample = objectCopy(sample);
243
+ if (!hasSampleCtrlImg) {
244
+ delete newSample.ctrl_img; // remove ctrl_img if not applicable
245
+ }
246
+ return newSample;
247
+ });
248
+ setJobConfig(samples, 'config.process[0].sample.samples');
249
+ }}
250
+ options={groupedModelOptions}
251
+ />
252
+ <TextInput
253
+ label="Name or Path"
254
+ value={jobConfig.config.process[0].model.name_or_path}
255
+ docKey="config.process[0].model.name_or_path"
256
+ onChange={(value: string | null) => {
257
+ if (value?.trim() === '') {
258
+ value = null;
259
+ }
260
+ setJobConfig(value, 'config.process[0].model.name_or_path');
261
+ }}
262
+ placeholder=""
263
+ required
264
+ />
265
+ {modelArch?.additionalSections?.includes('model.low_vram') && (
266
+ <FormGroup label="Options">
267
+ <Checkbox
268
+ label="Low VRAM"
269
+ checked={jobConfig.config.process[0].model.low_vram}
270
+ onChange={value => setJobConfig(value, 'config.process[0].model.low_vram')}
271
+ />
272
+ </FormGroup>
273
+ )}
274
+ </Card>
275
+ {modelArch?.disableSections?.includes('model.quantize') ? null : (
276
+ <Card title="Quantization">
277
+ <SelectInput
278
+ label="Transformer"
279
+ value={jobConfig.config.process[0].model.quantize ? jobConfig.config.process[0].model.qtype : ''}
280
+ onChange={value => {
281
+ if (value === '') {
282
+ setJobConfig(false, 'config.process[0].model.quantize');
283
+ value = defaultQtype;
284
+ } else {
285
+ setJobConfig(true, 'config.process[0].model.quantize');
286
+ }
287
+ setJobConfig(value, 'config.process[0].model.qtype');
288
+ }}
289
+ options={transformerQuantizationOptions}
290
+ />
291
+ <SelectInput
292
+ label="Text Encoder"
293
+ value={jobConfig.config.process[0].model.quantize_te ? jobConfig.config.process[0].model.qtype_te : ''}
294
+ onChange={value => {
295
+ if (value === '') {
296
+ setJobConfig(false, 'config.process[0].model.quantize_te');
297
+ value = defaultQtype;
298
+ } else {
299
+ setJobConfig(true, 'config.process[0].model.quantize_te');
300
+ }
301
+ setJobConfig(value, 'config.process[0].model.qtype_te');
302
+ }}
303
+ options={quantizationOptions}
304
+ />
305
+ </Card>
306
+ )}
307
+ {modelArch?.additionalSections?.includes('model.multistage') && (
308
+ <Card title="Multistage">
309
+ <FormGroup label="Stages to Train" docKey={'model.multistage'}>
310
+ <Checkbox
311
+ label="High Noise"
312
+ checked={jobConfig.config.process[0].model.model_kwargs?.train_high_noise || false}
313
+ onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')}
314
+ />
315
+ <Checkbox
316
+ label="Low Noise"
317
+ checked={jobConfig.config.process[0].model.model_kwargs?.train_low_noise || false}
318
+ onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')}
319
+ />
320
+ </FormGroup>
321
+ <NumberInput
322
+ label="Switch Every"
323
+ value={jobConfig.config.process[0].train.switch_boundary_every}
324
+ onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
325
+ placeholder="eg. 1"
326
+ docKey={'train.switch_boundary_every'}
327
+ min={1}
328
+ required
329
+ />
330
+ </Card>
331
+ )}
332
+ <Card title="Target">
333
+ <SelectInput
334
+ label="Target Type"
335
+ value={jobConfig.config.process[0].network?.type ?? 'lora'}
336
+ onChange={value => setJobConfig(value, 'config.process[0].network.type')}
337
+ options={[
338
+ { value: 'lora', label: 'LoRA' },
339
+ { value: 'lokr', label: 'LoKr' },
340
+ ]}
341
+ />
342
+ {jobConfig.config.process[0].network?.type == 'lokr' && (
343
+ <SelectInput
344
+ label="LoKr Factor"
345
+ value={`${jobConfig.config.process[0].network?.lokr_factor ?? -1}`}
346
+ onChange={value => setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
347
+ options={[
348
+ { value: '-1', label: 'Auto' },
349
+ { value: '4', label: '4' },
350
+ { value: '8', label: '8' },
351
+ { value: '16', label: '16' },
352
+ { value: '32', label: '32' },
353
+ ]}
354
+ />
355
+ )}
356
+ {jobConfig.config.process[0].network?.type == 'lora' && (
357
+ <>
358
+ <NumberInput
359
+ label="Linear Rank"
360
+ value={jobConfig.config.process[0].network.linear}
361
+ onChange={value => {
362
+ console.log('onChange', value);
363
+ setJobConfig(value, 'config.process[0].network.linear');
364
+ setJobConfig(value, 'config.process[0].network.linear_alpha');
365
+ }}
366
+ placeholder="eg. 16"
367
+ min={0}
368
+ max={1024}
369
+ required
370
+ />
371
+ {modelArch?.disableSections?.includes('network.conv') ? null : (
372
+ <NumberInput
373
+ label="Conv Rank"
374
+ value={jobConfig.config.process[0].network.conv}
375
+ onChange={value => {
376
+ console.log('onChange', value);
377
+ setJobConfig(value, 'config.process[0].network.conv');
378
+ setJobConfig(value, 'config.process[0].network.conv_alpha');
379
+ }}
380
+ placeholder="eg. 16"
381
+ min={0}
382
+ max={1024}
383
+ />
384
+ )}
385
+ </>
386
+ )}
387
+ </Card>
388
+ <Card title="Save">
389
+ <SelectInput
390
+ label="Data Type"
391
+ value={jobConfig.config.process[0].save.dtype}
392
+ onChange={value => setJobConfig(value, 'config.process[0].save.dtype')}
393
+ options={[
394
+ { value: 'bf16', label: 'BF16' },
395
+ { value: 'fp16', label: 'FP16' },
396
+ { value: 'fp32', label: 'FP32' },
397
+ ]}
398
+ />
399
+ <NumberInput
400
+ label="Save Every"
401
+ value={jobConfig.config.process[0].save.save_every}
402
+ onChange={value => setJobConfig(value, 'config.process[0].save.save_every')}
403
+ placeholder="eg. 250"
404
+ min={1}
405
+ required
406
+ />
407
+ <NumberInput
408
+ label="Max Step Saves to Keep"
409
+ value={jobConfig.config.process[0].save.max_step_saves_to_keep}
410
+ onChange={value => setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')}
411
+ placeholder="eg. 4"
412
+ min={1}
413
+ required
414
+ />
415
+ </Card>
416
+ </div>
417
+ <div>
418
+ <Card title="Training">
419
+ <div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
420
+ <div>
421
+ <NumberInput
422
+ label="Batch Size"
423
+ value={jobConfig.config.process[0].train.batch_size}
424
+ onChange={value => setJobConfig(value, 'config.process[0].train.batch_size')}
425
+ placeholder="eg. 4"
426
+ min={1}
427
+ required
428
+ />
429
+ <NumberInput
430
+ label="Gradient Accumulation"
431
+ className="pt-2"
432
+ value={jobConfig.config.process[0].train.gradient_accumulation}
433
+ onChange={value => setJobConfig(value, 'config.process[0].train.gradient_accumulation')}
434
+ placeholder="eg. 1"
435
+ min={1}
436
+ required
437
+ />
438
+ <NumberInput
439
+ label="Steps"
440
+ className="pt-2"
441
+ value={jobConfig.config.process[0].train.steps}
442
+ onChange={value => setJobConfig(value, 'config.process[0].train.steps')}
443
+ placeholder="eg. 2000"
444
+ min={1}
445
+ required
446
+ />
447
+ </div>
448
+ <div>
449
+ <SelectInput
450
+ label="Optimizer"
451
+ value={jobConfig.config.process[0].train.optimizer}
452
+ onChange={value => setJobConfig(value, 'config.process[0].train.optimizer')}
453
+ options={[
454
+ { value: 'adamw8bit', label: 'AdamW8Bit' },
455
+ { value: 'adafactor', label: 'Adafactor' },
456
+ ]}
457
+ />
458
+ <NumberInput
459
+ label="Learning Rate"
460
+ className="pt-2"
461
+ value={jobConfig.config.process[0].train.lr}
462
+ onChange={value => setJobConfig(value, 'config.process[0].train.lr')}
463
+ placeholder="eg. 0.0001"
464
+ min={0}
465
+ required
466
+ />
467
+ <NumberInput
468
+ label="Weight Decay"
469
+ className="pt-2"
470
+ value={jobConfig.config.process[0].train.optimizer_params.weight_decay}
471
+ onChange={value => setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')}
472
+ placeholder="eg. 0.0001"
473
+ min={0}
474
+ required
475
+ />
476
+ </div>
477
+ <div>
478
+ {modelArch?.disableSections?.includes('train.timestep_type') ? null : (
479
+ <SelectInput
480
+ label="Timestep Type"
481
+ value={jobConfig.config.process[0].train.timestep_type}
482
+ disabled={modelArch?.disableSections?.includes('train.timestep_type') || false}
483
+ onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
484
+ options={[
485
+ { value: 'sigmoid', label: 'Sigmoid' },
486
+ { value: 'linear', label: 'Linear' },
487
+ { value: 'shift', label: 'Shift' },
488
+ { value: 'weighted', label: 'Weighted' },
489
+ ]}
490
+ />
491
+ )}
492
+ <SelectInput
493
+ label="Timestep Bias"
494
+ className="pt-2"
495
+ value={jobConfig.config.process[0].train.content_or_style}
496
+ onChange={value => setJobConfig(value, 'config.process[0].train.content_or_style')}
497
+ options={[
498
+ { value: 'balanced', label: 'Balanced' },
499
+ { value: 'content', label: 'High Noise' },
500
+ { value: 'style', label: 'Low Noise' },
501
+ ]}
502
+ />
503
+ <SelectInput
504
+ label="Noise Scheduler"
505
+ className="pt-2"
506
+ value={jobConfig.config.process[0].train.noise_scheduler}
507
+ onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
508
+ options={[
509
+ { value: 'flowmatch', label: 'FlowMatch' },
510
+ { value: 'ddpm', label: 'DDPM' },
511
+ ]}
512
+ />
513
+ </div>
514
+ <div>
515
+ <FormGroup label="EMA (Exponential Moving Average)">
516
+ <Checkbox
517
+ label="Use EMA"
518
+ className="pt-1"
519
+ checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
520
+ onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
521
+ />
522
+ </FormGroup>
523
+ {jobConfig.config.process[0].train.ema_config?.use_ema && (
524
+ <NumberInput
525
+ label="EMA Decay"
526
+ className="pt-2"
527
+ value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
528
+ onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.ema_decay')}
529
+ placeholder="eg. 0.99"
530
+ min={0}
531
+ />
532
+ )}
533
+
534
+ <FormGroup label="Text Encoder Optimizations" className="pt-2">
535
+ <Checkbox
536
+ label="Unload TE"
537
+ checked={jobConfig.config.process[0].train.unload_text_encoder || false}
538
+ docKey={'train.unload_text_encoder'}
539
+ onChange={value => {
540
+ setJobConfig(value, 'config.process[0].train.unload_text_encoder');
541
+ if (value) {
542
+ setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
543
+ }
544
+ }}
545
+ />
546
+ <Checkbox
547
+ label="Cache Text Embeddings"
548
+ checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
549
+ docKey={'train.cache_text_embeddings'}
550
+ onChange={value => {
551
+ setJobConfig(value, 'config.process[0].train.cache_text_embeddings');
552
+ if (value) {
553
+ setJobConfig(false, 'config.process[0].train.unload_text_encoder');
554
+ }
555
+ }}
556
+ />
557
+ </FormGroup>
558
+ </div>
559
+ <div>
560
+ <FormGroup label="Regularization">
561
+ <Checkbox
562
+ label="Differtial Output Preservation"
563
+ className="pt-1"
564
+ checked={jobConfig.config.process[0].train.diff_output_preservation || false}
565
+ onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
566
+ />
567
+ </FormGroup>
568
+ {jobConfig.config.process[0].train.diff_output_preservation && (
569
+ <>
570
+ <NumberInput
571
+ label="DOP Loss Multiplier"
572
+ className="pt-2"
573
+ value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
574
+ onChange={value =>
575
+ setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
576
+ }
577
+ placeholder="eg. 1.0"
578
+ min={0}
579
+ />
580
+ <TextInput
581
+ label="DOP Preservation Class"
582
+ className="pt-2"
583
+ value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
584
+ onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
585
+ placeholder="eg. woman"
586
+ />
587
+ </>
588
+ )}
589
+ </div>
590
+ </div>
591
+ </Card>
592
+ </div>
593
+ <div>
594
+ <Card title="Datasets">
595
+ <>
596
+ {jobConfig.config.process[0].datasets.map((dataset, i) => (
597
+ <div key={i} className="p-4 rounded-lg bg-gray-800 relative">
598
+ <button
599
+ type="button"
600
+ onClick={() =>
601
+ setJobConfig(
602
+ jobConfig.config.process[0].datasets.filter((_, index) => index !== i),
603
+ 'config.process[0].datasets',
604
+ )
605
+ }
606
+ className="absolute top-2 right-2 bg-red-800 hover:bg-red-700 rounded-full p-1 text-sm transition-colors"
607
+ >
608
+ <X />
609
+ </button>
610
+ <h2 className="text-lg font-bold mb-4">Dataset {i + 1}</h2>
611
+ <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
612
+ <div>
613
+ <SelectInput
614
+ label="Dataset"
615
+ value={dataset.folder_path}
616
+ onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
617
+ options={datasetOptions}
618
+ />
619
+ {modelArch?.additionalSections?.includes('datasets.control_path') && (
620
+ <SelectInput
621
+ label="Control Dataset"
622
+ docKey="datasets.control_path"
623
+ value={dataset.control_path ?? ''}
624
+ className="pt-2"
625
+ onChange={value =>
626
+ setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`)
627
+ }
628
+ options={[{ value: '', label: <>&nbsp;</> }, ...datasetOptions]}
629
+ />
630
+ )}
631
+ <NumberInput
632
+ label="LoRA Weight"
633
+ value={dataset.network_weight}
634
+ className="pt-2"
635
+ onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)}
636
+ placeholder="eg. 1.0"
637
+ />
638
+ </div>
639
+ <div>
640
+ <TextInput
641
+ label="Default Caption"
642
+ value={dataset.default_caption}
643
+ onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)}
644
+ placeholder="eg. A photo of a cat"
645
+ />
646
+ <NumberInput
647
+ label="Caption Dropout Rate"
648
+ className="pt-2"
649
+ value={dataset.caption_dropout_rate}
650
+ onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)}
651
+ placeholder="eg. 0.05"
652
+ min={0}
653
+ required
654
+ />
655
+ {modelArch?.additionalSections?.includes('datasets.num_frames') && (
656
+ <NumberInput
657
+ label="Num Frames"
658
+ className="pt-2"
659
+ docKey="datasets.num_frames"
660
+ value={dataset.num_frames}
661
+ onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].num_frames`)}
662
+ placeholder="eg. 41"
663
+ min={1}
664
+ required
665
+ />
666
+ )}
667
+ </div>
668
+ <div>
669
+ <FormGroup label="Settings" className="">
670
+ <Checkbox
671
+ label="Cache Latents"
672
+ checked={dataset.cache_latents_to_disk || false}
673
+ onChange={value =>
674
+ setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`)
675
+ }
676
+ />
677
+ <Checkbox
678
+ label="Is Regularization"
679
+ checked={dataset.is_reg || false}
680
+ onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
681
+ />
682
+ {modelArch?.additionalSections?.includes('datasets.do_i2v') && (
683
+ <Checkbox
684
+ label="Do I2V"
685
+ checked={dataset.do_i2v || false}
686
+ onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
687
+ docKey="datasets.do_i2v"
688
+ />
689
+ )}
690
+ </FormGroup>
691
+ <FormGroup label="Flipping" docKey={'datasets.flip'} className="mt-2">
692
+ <Checkbox
693
+ label={<>Flip X <FlipHorizontal2 className="inline-block w-4 h-4 ml-1" /></>}
694
+ checked={dataset.flip_x || false}
695
+ onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)}
696
+ />
697
+ <Checkbox
698
+ label={<>Flip Y <FlipVertical2 className="inline-block w-4 h-4 ml-1" /></>}
699
+ checked={dataset.flip_y || false}
700
+ onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)}
701
+ />
702
+ </FormGroup>
703
+ </div>
704
+ <div>
705
+ <FormGroup label="Resolutions" className="pt-2">
706
+ <div className="grid grid-cols-2 gap-2">
707
+ {[
708
+ [256, 512, 768],
709
+ [1024, 1280, 1536],
710
+ ].map(resGroup => (
711
+ <div key={resGroup[0]} className="space-y-2">
712
+ {resGroup.map(res => (
713
+ <Checkbox
714
+ key={res}
715
+ label={res.toString()}
716
+ checked={dataset.resolution.includes(res)}
717
+ onChange={value => {
718
+ const resolutions = dataset.resolution.includes(res)
719
+ ? dataset.resolution.filter(r => r !== res)
720
+ : [...dataset.resolution, res];
721
+ setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
722
+ }}
723
+ />
724
+ ))}
725
+ </div>
726
+ ))}
727
+ </div>
728
+ </FormGroup>
729
+ </div>
730
+ </div>
731
+ </div>
732
+ ))}
733
+ <button
734
+ type="button"
735
+ onClick={() => {
736
+ const newDataset = objectCopy(defaultDatasetConfig);
737
+ // automatically add the controls for a new dataset
738
+ const controls = modelArch?.controls ?? [];
739
+ newDataset.controls = controls;
740
+ setJobConfig([...jobConfig.config.process[0].datasets, newDataset], 'config.process[0].datasets');
741
+ }}
742
+ className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
743
+ >
744
+ Add Dataset
745
+ </button>
746
+ </>
747
+ </Card>
748
+ </div>
749
+ <div>
750
+ <Card title="Sample">
751
+ <div
752
+ className={
753
+ isVideoModel
754
+ ? 'grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6'
755
+ : 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6'
756
+ }
757
+ >
758
+ <div>
759
+ <NumberInput
760
+ label="Sample Every"
761
+ value={jobConfig.config.process[0].sample.sample_every}
762
+ onChange={value => setJobConfig(value, 'config.process[0].sample.sample_every')}
763
+ placeholder="eg. 250"
764
+ min={1}
765
+ required
766
+ />
767
+ <SelectInput
768
+ label="Sampler"
769
+ className="pt-2"
770
+ value={jobConfig.config.process[0].sample.sampler}
771
+ onChange={value => setJobConfig(value, 'config.process[0].sample.sampler')}
772
+ options={[
773
+ { value: 'flowmatch', label: 'FlowMatch' },
774
+ { value: 'ddpm', label: 'DDPM' },
775
+ ]}
776
+ />
777
+ <NumberInput
778
+ label="Guidance Scale"
779
+ value={jobConfig.config.process[0].sample.guidance_scale}
780
+ onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
781
+ placeholder="eg. 1.0"
782
+ className="pt-2"
783
+ min={0}
784
+ required
785
+ />
786
+ <NumberInput
787
+ label="Sample Steps"
788
+ value={jobConfig.config.process[0].sample.sample_steps}
789
+ onChange={value => setJobConfig(value, 'config.process[0].sample.sample_steps')}
790
+ placeholder="eg. 1"
791
+ className="pt-2"
792
+ min={1}
793
+ required
794
+ />
795
+ </div>
796
+ <div>
797
+ <NumberInput
798
+ label="Width"
799
+ value={jobConfig.config.process[0].sample.width}
800
+ onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
801
+ placeholder="eg. 1024"
802
+ min={0}
803
+ required
804
+ />
805
+ <NumberInput
806
+ label="Height"
807
+ value={jobConfig.config.process[0].sample.height}
808
+ onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
809
+ placeholder="eg. 1024"
810
+ className="pt-2"
811
+ min={0}
812
+ required
813
+ />
814
+ {isVideoModel && (
815
+ <div>
816
+ <NumberInput
817
+ label="Num Frames"
818
+ value={jobConfig.config.process[0].sample.num_frames}
819
+ onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
820
+ placeholder="eg. 0"
821
+ className="pt-2"
822
+ min={0}
823
+ required
824
+ />
825
+ <NumberInput
826
+ label="FPS"
827
+ value={jobConfig.config.process[0].sample.fps}
828
+ onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
829
+ placeholder="eg. 0"
830
+ className="pt-2"
831
+ min={0}
832
+ required
833
+ />
834
+ </div>
835
+ )}
836
+ </div>
837
+
838
+ <div>
839
+ <NumberInput
840
+ label="Seed"
841
+ value={jobConfig.config.process[0].sample.seed}
842
+ onChange={value => setJobConfig(value, 'config.process[0].sample.seed')}
843
+ placeholder="eg. 0"
844
+ min={0}
845
+ required
846
+ />
847
+ <Checkbox
848
+ label="Walk Seed"
849
+ className="pt-4 pl-2"
850
+ checked={jobConfig.config.process[0].sample.walk_seed}
851
+ onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
852
+ />
853
+ </div>
854
+ <div>
855
+ <FormGroup label="Advanced Sampling" className="pt-2">
856
+ <div>
857
+ <Checkbox
858
+ label="Skip First Sample"
859
+ className="pt-4"
860
+ checked={jobConfig.config.process[0].train.skip_first_sample || false}
861
+ onChange={value => setJobConfig(value, 'config.process[0].train.skip_first_sample')}
862
+ />
863
+ </div>
864
+ <div>
865
+ <Checkbox
866
+ label="Disable Sampling"
867
+ className="pt-1"
868
+ checked={jobConfig.config.process[0].train.disable_sampling || false}
869
+ onChange={value => setJobConfig(value, 'config.process[0].train.disable_sampling')}
870
+ />
871
+ </div>
872
+ </FormGroup>
873
+ </div>
874
+ </div>
875
+ <FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.samples.length})`} className="pt-2">
876
+ <div></div>
877
+ </FormGroup>
878
+ {jobConfig.config.process[0].sample.samples.map((sample, i) => (
879
+ <div key={i} className="rounded-lg pl-4 pr-1 mb-4 bg-gray-950">
880
+ <div className="flex items-center space-x-2">
881
+ <div className="flex-1">
882
+ <div className="flex">
883
+ <div className="flex-1">
884
+ <TextInput
885
+ label={`Prompt`}
886
+ value={sample.prompt}
887
+ onChange={value => setJobConfig(value, `config.process[0].sample.samples[${i}].prompt`)}
888
+ placeholder="Enter prompt"
889
+ required
890
+ />
891
+ </div>
892
+
893
+ {modelArch?.additionalSections?.includes('sample.ctrl_img') && (
894
+ <div
895
+ className="h-14 w-14 mt-2 ml-4 border border-gray-500 flex items-center justify-center rounded cursor-pointer hover:bg-gray-700 transition-colors"
896
+ style={{
897
+ backgroundImage: sample.ctrl_img
898
+ ? `url(${`/api/img/${encodeURIComponent(sample.ctrl_img)}`})`
899
+ : 'none',
900
+ backgroundSize: 'cover',
901
+ backgroundPosition: 'center',
902
+ marginBottom: '-1rem',
903
+ }}
904
+ onClick={() => {
905
+ openAddImageModal(imagePath => {
906
+ console.log('Selected image path:', imagePath);
907
+ if (!imagePath) return;
908
+ setJobConfig(imagePath, `config.process[0].sample.samples[${i}].ctrl_img`);
909
+ });
910
+ }}
911
+ >
912
+ {!sample.ctrl_img && (
913
+ <div className="text-gray-400 text-xs text-center font-bold">Add Control Image</div>
914
+ )}
915
+ </div>
916
+ )}
917
+ </div>
918
+ <div className="pb-4"></div>
919
+ </div>
920
+ <div>
921
+ <button
922
+ type="button"
923
+ onClick={() =>
924
+ setJobConfig(
925
+ jobConfig.config.process[0].sample.samples.filter((_, index) => index !== i),
926
+ 'config.process[0].sample.samples',
927
+ )
928
+ }
929
+ className="rounded-full p-1 text-sm"
930
+ >
931
+ <X />
932
+ </button>
933
+ </div>
934
+ </div>
935
+ </div>
936
+ ))}
937
+ <button
938
+ type="button"
939
+ onClick={() =>
940
+ setJobConfig(
941
+ [...jobConfig.config.process[0].sample.samples, { prompt: '' }],
942
+ 'config.process[0].sample.samples',
943
+ )
944
+ }
945
+ className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
946
+ >
947
+ Add Prompt
948
+ </button>
949
+ </Card>
950
+ </div>
951
+
952
+ {status === 'success' && <p className="text-green-500 text-center">Training saved successfully!</p>}
953
+ {status === 'error' && <p className="text-red-500 text-center">Error saving training. Please try again.</p>}
954
+ </form>
955
+
956
+ {trainingBackend === 'hf-jobs' && (
957
+ <div className="mt-8">
958
+ <HFJobsWorkflow
959
+ jobConfig={jobConfig}
960
+ onComplete={(jobId, localJobId) => {
961
+ console.log('HF Job submitted:', jobId, 'Local job ID:', localJobId);
962
+ if (onHFJobComplete) {
963
+ onHFJobComplete(jobId, localJobId);
964
+ }
965
+ }}
966
+ />
967
+ </div>
968
+ )}
969
+
970
+ <AddSingleImageModal />
971
+ </>
972
+ );
973
+ }
src/app/jobs/new/jobConfig.ts ADDED
@@ -0,0 +1,167 @@
1
+ import { JobConfig, DatasetConfig } from '@/types';
2
+
3
+ export const defaultDatasetConfig: DatasetConfig = {
4
+ folder_path: '/path/to/images/folder',
5
+ control_path: null,
6
+ mask_path: null,
7
+ mask_min_value: 0.1,
8
+ default_caption: '',
9
+ caption_ext: 'txt',
10
+ caption_dropout_rate: 0.05,
11
+ cache_latents_to_disk: false,
12
+ is_reg: false,
13
+ network_weight: 1,
14
+ resolution: [512, 768, 1024],
15
+ controls: [],
16
+ shrink_video_to_frames: true,
17
+ num_frames: 1,
18
+ do_i2v: true,
19
+ flip_x: false,
20
+ flip_y: false,
21
+ };
22
+
23
+ export const defaultJobConfig: JobConfig = {
24
+ job: 'extension',
25
+ config: {
26
+ name: 'my_first_lora_v1',
27
+ process: [
28
+ {
29
+ type: 'ui_trainer',
30
+ training_folder: 'output',
31
+ sqlite_db_path: './aitk_db.db',
32
+ device: 'cuda',
33
+ trigger_word: null,
34
+ performance_log_every: 10,
35
+ network: {
36
+ type: 'lora',
37
+ linear: 32,
38
+ linear_alpha: 32,
39
+ conv: 16,
40
+ conv_alpha: 16,
41
+ lokr_full_rank: true,
42
+ lokr_factor: -1,
43
+ network_kwargs: {
44
+ ignore_if_contains: [],
45
+ },
46
+ },
47
+ save: {
48
+ dtype: 'bf16',
49
+ save_every: 250,
50
+ max_step_saves_to_keep: 4,
51
+ save_format: 'diffusers',
52
+ push_to_hub: false,
53
+ },
54
+ datasets: [defaultDatasetConfig],
55
+ train: {
56
+ batch_size: 1,
57
+ bypass_guidance_embedding: true,
58
+ steps: 3000,
59
+ gradient_accumulation: 1,
60
+ train_unet: true,
61
+ train_text_encoder: false,
62
+ gradient_checkpointing: true,
63
+ noise_scheduler: 'flowmatch',
64
+ optimizer: 'adamw8bit',
65
+ timestep_type: 'sigmoid',
66
+ content_or_style: 'balanced',
67
+ optimizer_params: {
68
+ weight_decay: 1e-4,
69
+ },
70
+ unload_text_encoder: false,
71
+ cache_text_embeddings: false,
72
+ lr: 0.0001,
73
+ ema_config: {
74
+ use_ema: false,
75
+ ema_decay: 0.99,
76
+ },
77
+ skip_first_sample: false,
78
+ disable_sampling: false,
79
+ dtype: 'bf16',
80
+ diff_output_preservation: false,
81
+ diff_output_preservation_multiplier: 1.0,
82
+ diff_output_preservation_class: 'person',
83
+ switch_boundary_every: 1,
84
+ },
85
+ model: {
86
+ name_or_path: 'ostris/Flex.1-alpha',
87
+ quantize: true,
88
+ qtype: 'qfloat8',
89
+ quantize_te: true,
90
+ qtype_te: 'qfloat8',
91
+ arch: 'flex1',
92
+ low_vram: false,
93
+ model_kwargs: {},
94
+ },
95
+ sample: {
96
+ sampler: 'flowmatch',
97
+ sample_every: 250,
98
+ width: 1024,
99
+ height: 1024,
100
+ samples: [
101
+ {
102
+ prompt: 'woman with red hair, playing chess at the park, bomb going off in the background'
103
+ },
104
+ {
105
+ prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe',
106
+ },
107
+ {
108
+ prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, laser lights, holding a martini',
109
+ },
110
+ {
111
+ prompt: 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
112
+ },
113
+ {
114
+ prompt: 'a bear building a log cabin in the snow covered mountains',
115
+ },
116
+ {
117
+ prompt: 'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',
118
+ },
119
+ {
120
+ prompt: 'hipster man with a beard, building a chair, in a wood shop',
121
+ },
122
+ {
123
+ prompt: 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
124
+ },
125
+ {
126
+ prompt: "a man holding a sign that says, 'this is a sign'",
127
+ },
128
+ {
129
+ prompt: 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
130
+ },
131
+ ],
132
+ neg: '',
133
+ seed: 42,
134
+ walk_seed: true,
135
+ guidance_scale: 4,
136
+ sample_steps: 25,
137
+ num_frames: 1,
138
+ fps: 1,
139
+ },
140
+ },
141
+ ],
142
+ },
143
+ meta: {
144
+ name: '[name]',
145
+ version: '1.0',
146
+ },
147
+ };
148
+
149
+ export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
150
+ // upgrade prompt strings to samples
151
+ if (
152
+ jobConfig?.config?.process &&
153
+ jobConfig.config.process[0]?.sample &&
154
+ Array.isArray(jobConfig.config.process[0].sample.prompts) &&
155
+ jobConfig.config.process[0].sample.prompts.length > 0
156
+ ) {
157
+ let newSamples = [];
158
+ for (const prompt of jobConfig.config.process[0].sample.prompts) {
159
+ newSamples.push({
160
+ prompt: prompt,
161
+ });
162
+ }
163
+ jobConfig.config.process[0].sample.samples = newSamples;
164
+ delete jobConfig.config.process[0].sample.prompts;
165
+ }
166
+ return jobConfig;
167
+ };
src/app/jobs/new/options.ts ADDED
@@ -0,0 +1,441 @@
1
+ import { GroupedSelectOption, SelectOption } from '@/types';
2
+
3
+ type Control = 'depth' | 'line' | 'pose' | 'inpaint';
4
+
5
+ type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
6
+ type AdditionalSections =
7
+ | 'datasets.control_path'
8
+ | 'datasets.do_i2v'
9
+ | 'sample.ctrl_img'
10
+ | 'datasets.num_frames'
11
+ | 'model.multistage'
12
+ | 'model.low_vram';
13
+ type ModelGroup = 'image' | 'instruction' | 'video';
14
+
15
+ export interface ModelArch {
16
+ name: string;
17
+ label: string;
18
+ group: ModelGroup;
19
+ controls?: Control[];
20
+ isVideoModel?: boolean;
21
+ defaults?: { [key: string]: any };
22
+ disableSections?: DisableableSections[];
23
+ additionalSections?: AdditionalSections[];
24
+ accuracyRecoveryAdapters?: { [key: string]: string };
25
+ }
26
+
27
+ const defaultNameOrPath = '';
28
+
29
+ export const modelArchs: ModelArch[] = [
30
+ {
31
+ name: 'flux',
32
+ label: 'FLUX.1',
33
+ group: 'image',
34
+ defaults: {
35
+ // default updates when [selected, unselected] in the UI
36
+ 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-dev', defaultNameOrPath],
37
+ 'config.process[0].model.quantize': [true, false],
38
+ 'config.process[0].model.quantize_te': [true, false],
39
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
40
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
41
+ },
42
+ disableSections: ['network.conv'],
43
+ },
44
+ {
45
+ name: 'flux_kontext',
46
+ label: 'FLUX.1-Kontext-dev',
47
+ group: 'instruction',
48
+ defaults: {
49
+ // default updates when [selected, unselected] in the UI
50
+ 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
51
+ 'config.process[0].model.quantize': [true, false],
52
+ 'config.process[0].model.quantize_te': [true, false],
53
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
54
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
55
+ 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
56
+ },
57
+ disableSections: ['network.conv'],
58
+ additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
59
+ },
60
+ {
61
+ name: 'flex1',
62
+ label: 'Flex.1',
63
+ group: 'image',
64
+ defaults: {
65
+ // default updates when [selected, unselected] in the UI
66
+ 'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath],
67
+ 'config.process[0].model.quantize': [true, false],
68
+ 'config.process[0].model.quantize_te': [true, false],
69
+ 'config.process[0].train.bypass_guidance_embedding': [true, false],
70
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
71
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
72
+ },
73
+ disableSections: ['network.conv'],
74
+ },
75
+ {
76
+ name: 'flex2',
77
+ label: 'Flex.2',
78
+ group: 'image',
79
+ controls: ['depth', 'line', 'pose', 'inpaint'],
80
+ defaults: {
81
+ // default updates when [selected, unselected] in the UI
82
+ 'config.process[0].model.name_or_path': ['ostris/Flex.2-preview', defaultNameOrPath],
83
+ 'config.process[0].model.quantize': [true, false],
84
+ 'config.process[0].model.quantize_te': [true, false],
85
+ 'config.process[0].model.model_kwargs': [
86
+ {
87
+ invert_inpaint_mask_chance: 0.2,
88
+ inpaint_dropout: 0.5,
89
+ control_dropout: 0.5,
90
+ inpaint_random_chance: 0.2,
91
+ do_random_inpainting: true,
92
+ random_blur_mask: true,
93
+ random_dialate_mask: true,
94
+ },
95
+ {},
96
+ ],
97
+ 'config.process[0].train.bypass_guidance_embedding': [true, false],
98
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
99
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
100
+ },
101
+ disableSections: ['network.conv'],
102
+ },
103
+ {
104
+ name: 'chroma',
105
+ label: 'Chroma',
106
+ group: 'image',
107
+ defaults: {
108
+ // default updates when [selected, unselected] in the UI
109
+ 'config.process[0].model.name_or_path': ['lodestones/Chroma1-Base', defaultNameOrPath],
110
+ 'config.process[0].model.quantize': [true, false],
111
+ 'config.process[0].model.quantize_te': [true, false],
112
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
113
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
114
+ },
115
+ disableSections: ['network.conv'],
116
+ },
117
+ {
118
+ name: 'wan21:1b',
119
+ label: 'Wan 2.1 (1.3B)',
120
+ group: 'video',
121
+ isVideoModel: true,
122
+ defaults: {
123
+ // default updates when [selected, unselected] in the UI
124
+ 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-1.3B-Diffusers', defaultNameOrPath],
125
+ 'config.process[0].model.quantize': [false, false],
126
+ 'config.process[0].model.quantize_te': [true, false],
127
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
128
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
129
+ 'config.process[0].sample.num_frames': [41, 1],
130
+ 'config.process[0].sample.fps': [16, 1],
131
+ },
132
+ disableSections: ['network.conv'],
133
+ additionalSections: ['datasets.num_frames', 'model.low_vram'],
134
+ },
135
+ {
136
+ name: 'wan21_i2v:14b480p',
137
+ label: 'Wan 2.1 I2V (14B-480P)',
138
+ group: 'video',
139
+ isVideoModel: true,
140
+ defaults: {
141
+ // default updates when [selected, unselected] in the UI
142
+ 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-480P-Diffusers', defaultNameOrPath],
143
+ 'config.process[0].model.quantize': [true, false],
144
+ 'config.process[0].model.quantize_te': [true, false],
145
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
146
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
147
+ 'config.process[0].sample.num_frames': [41, 1],
148
+ 'config.process[0].sample.fps': [16, 1],
149
+ 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
150
+ },
151
+ disableSections: ['network.conv'],
152
+ additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
153
+ },
154
+ {
155
+ name: 'wan21_i2v:14b',
156
+ label: 'Wan 2.1 I2V (14B-720P)',
157
+ group: 'video',
158
+ isVideoModel: true,
159
+ defaults: {
160
+ // default updates when [selected, unselected] in the UI
161
+ 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-720P-Diffusers', defaultNameOrPath],
162
+ 'config.process[0].model.quantize': [true, false],
163
+ 'config.process[0].model.quantize_te': [true, false],
164
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
165
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
166
+ 'config.process[0].sample.num_frames': [41, 1],
167
+ 'config.process[0].sample.fps': [16, 1],
168
+ 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
169
+ },
170
+ disableSections: ['network.conv'],
171
+ additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
172
+ },
173
+ {
174
+ name: 'wan21:14b',
175
+ label: 'Wan 2.1 (14B)',
176
+ group: 'video',
177
+ isVideoModel: true,
178
+ defaults: {
179
+ // default updates when [selected, unselected] in the UI
180
+ 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-14B-Diffusers', defaultNameOrPath],
181
+ 'config.process[0].model.quantize': [true, false],
182
+ 'config.process[0].model.quantize_te': [true, false],
183
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
184
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
185
+ 'config.process[0].sample.num_frames': [41, 1],
186
+ 'config.process[0].sample.fps': [16, 1],
187
+ },
188
+ disableSections: ['network.conv'],
189
+ additionalSections: ['datasets.num_frames', 'model.low_vram'],
190
+ },
191
+ {
192
+ name: 'wan22_14b:t2v',
193
+ label: 'Wan 2.2 (14B)',
194
+ group: 'video',
195
+ isVideoModel: true,
196
+ defaults: {
197
+ // default updates when [selected, unselected] in the UI
198
+ 'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16', defaultNameOrPath],
199
+ 'config.process[0].model.quantize': [true, false],
200
+ 'config.process[0].model.quantize_te': [true, false],
201
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
202
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
203
+ 'config.process[0].sample.num_frames': [41, 1],
204
+ 'config.process[0].sample.fps': [16, 1],
205
+ 'config.process[0].model.low_vram': [true, false],
206
+ 'config.process[0].train.timestep_type': ['linear', 'sigmoid'],
207
+ 'config.process[0].model.model_kwargs': [
208
+ {
209
+ train_high_noise: true,
210
+ train_low_noise: true,
211
+ },
212
+ {},
213
+ ],
214
+ },
215
+ disableSections: ['network.conv'],
216
+ additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage'],
217
+ accuracyRecoveryAdapters: {
218
+ // '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
219
+ '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors',
220
+ },
221
+ },
222
+ {
223
+ name: 'wan22_14b_i2v',
224
+ label: 'Wan 2.2 I2V (14B)',
225
+ group: 'video',
226
+ isVideoModel: true,
227
+ defaults: {
228
+ // default updates when [selected, unselected] in the UI
229
+ 'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-I2V-A14B-Diffusers-bf16', defaultNameOrPath],
230
+ 'config.process[0].model.quantize': [true, false],
231
+ 'config.process[0].model.quantize_te': [true, false],
232
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
233
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
234
+ 'config.process[0].sample.num_frames': [41, 1],
235
+ 'config.process[0].sample.fps': [16, 1],
236
+ 'config.process[0].model.low_vram': [true, false],
237
+ 'config.process[0].train.timestep_type': ['linear', 'sigmoid'],
238
+ 'config.process[0].model.model_kwargs': [
239
+ {
240
+ train_high_noise: true,
241
+ train_low_noise: true,
242
+ },
243
+ {},
244
+ ],
245
+ },
246
+ disableSections: ['network.conv'],
247
+ additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage'],
248
+ accuracyRecoveryAdapters: {
249
+ '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors',
250
+ },
251
+ },
252
+ {
253
+ name: 'wan22_5b',
254
+ label: 'Wan 2.2 TI2V (5B)',
255
+ group: 'video',
256
+ isVideoModel: true,
257
+ defaults: {
258
+ // default updates when [selected, unselected] in the UI
259
+ 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.2-TI2V-5B-Diffusers', defaultNameOrPath],
260
+ 'config.process[0].model.quantize': [true, false],
261
+ 'config.process[0].model.quantize_te': [true, false],
262
+ 'config.process[0].model.low_vram': [true, false],
263
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
264
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
265
+ 'config.process[0].sample.num_frames': [121, 1],
266
+ 'config.process[0].sample.fps': [24, 1],
267
+ 'config.process[0].sample.width': [768, 1024],
268
+ 'config.process[0].sample.height': [768, 1024],
269
+ 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
270
+ },
271
+     disableSections: ['network.conv'],
+     additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'],
+   },
+   {
+     name: 'lumina2',
+     label: 'Lumina2',
+     group: 'image',
+     defaults: {
+       // default updates when [selected, unselected] in the UI
+       'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath],
+       'config.process[0].model.quantize': [false, false],
+       'config.process[0].model.quantize_te': [true, false],
+       'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+       'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+     },
+     disableSections: ['network.conv'],
+   },
+   {
+     name: 'qwen_image',
+     label: 'Qwen-Image',
+     group: 'image',
+     defaults: {
+       // default updates when [selected, unselected] in the UI
+       'config.process[0].model.name_or_path': ['Qwen/Qwen-Image', defaultNameOrPath],
+       'config.process[0].model.quantize': [true, false],
+       'config.process[0].model.quantize_te': [true, false],
+       'config.process[0].model.low_vram': [true, false],
+       'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+       'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+       'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
+       'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
+     },
+     disableSections: ['network.conv'],
+     additionalSections: ['model.low_vram'],
+     accuracyRecoveryAdapters: {
+       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
+     },
+   },
+   {
+     name: 'qwen_image_edit',
+     label: 'Qwen-Image-Edit',
+     group: 'instruction',
+     defaults: {
+       // default updates when [selected, unselected] in the UI
+       'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit', defaultNameOrPath],
+       'config.process[0].model.quantize': [true, false],
+       'config.process[0].model.quantize_te': [true, false],
+       'config.process[0].model.low_vram': [true, false],
+       'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+       'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+       'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
+       'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
+     },
+     disableSections: ['network.conv'],
+     additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'],
+     accuracyRecoveryAdapters: {
+       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
+     },
+   },
+   {
+     name: 'hidream',
+     label: 'HiDream',
+     group: 'image',
+     defaults: {
+       // default updates when [selected, unselected] in the UI
+       'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath],
+       'config.process[0].model.quantize': [true, false],
+       'config.process[0].model.quantize_te': [true, false],
+       'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+       'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+       'config.process[0].train.lr': [0.0002, 0.0001],
+       'config.process[0].train.timestep_type': ['shift', 'sigmoid'],
+       'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
+     },
+     disableSections: ['network.conv'],
+     additionalSections: ['model.low_vram'],
+   },
+   {
+     name: 'hidream_e1',
+     label: 'HiDream E1',
+     group: 'instruction',
+     defaults: {
+       // default updates when [selected, unselected] in the UI
+       'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-E1-1', defaultNameOrPath],
+       'config.process[0].model.quantize': [true, false],
+       'config.process[0].model.quantize_te': [true, false],
+       'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+       'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+       'config.process[0].train.lr': [0.0001, 0.0001],
+       'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
+       'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
+     },
+     disableSections: ['network.conv'],
+     additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'],
+   },
+   {
+     name: 'sdxl',
+     label: 'SDXL',
+     group: 'image',
+     defaults: {
+       // default updates when [selected, unselected] in the UI
+       'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
+       'config.process[0].model.quantize': [false, false],
+       'config.process[0].model.quantize_te': [false, false],
+       'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
+       'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
+       'config.process[0].sample.guidance_scale': [6, 4],
+     },
+     disableSections: ['model.quantize', 'train.timestep_type'],
+   },
+   {
+     name: 'sd15',
+     label: 'SD 1.5',
+     group: 'image',
+     defaults: {
+       // default updates when [selected, unselected] in the UI
+       'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
+       'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
+       'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
+       'config.process[0].sample.width': [512, 1024],
+       'config.process[0].sample.height': [512, 1024],
+       'config.process[0].sample.guidance_scale': [6, 4],
+     },
+     disableSections: ['model.quantize', 'train.timestep_type'],
+   },
+   {
+     name: 'omnigen2',
+     label: 'OmniGen2',
+     group: 'image',
+     defaults: {
+       // default updates when [selected, unselected] in the UI
+       'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath],
+       'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+       'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+       'config.process[0].model.quantize': [false, false],
+       'config.process[0].model.quantize_te': [true, false],
+     },
+     disableSections: ['network.conv'],
+     additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
+   },
+ ].sort((a, b) => {
+   // Sort by label, case-insensitive
+   return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });
+ }) as any;
+
+ export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc, arch) => {
+   const group = acc.find(g => g.label === arch.group);
+   if (group) {
+     group.options.push({ value: arch.name, label: arch.label });
+   } else {
+     acc.push({
+       label: arch.group,
+       options: [{ value: arch.name, label: arch.label }],
+     });
+   }
+   return acc;
+ }, [] as GroupedSelectOption[]);
+
+ export const quantizationOptions: SelectOption[] = [
+   { value: '', label: '- NONE -' },
+   { value: 'qfloat8', label: 'float8 (default)' },
+   { value: 'uint8', label: '8 bit' },
+   { value: 'uint7', label: '7 bit' },
+   { value: 'uint6', label: '6 bit' },
+   { value: 'uint5', label: '5 bit' },
+   { value: 'uint4', label: '4 bit' },
+   { value: 'uint3', label: '3 bit' },
+   { value: 'uint2', label: '2 bit' },
+ ];
+
+ export const defaultQtype = 'qfloat8';
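Each `defaults` map above pairs a config path with a `[selected, unselected]` tuple, as the inline comments note. As a rough illustration of how a UI could consume those tuples, here is a minimal TypeScript sketch; the `setJobConfig(value, path)` setter mirrors the nested setter used in `page.tsx` below, and `ModelArchLike` is a simplified stand-in for the real `ModelArch` shape, not a published type.

// Minimal sketch, not the implementation shipped in this commit.
// Assumes a nested setter such as setJobConfig(value, 'config.process[0].model.quantize').
type ModelArchLike = {
  name: string;
  defaults?: Record<string, [unknown, unknown]>; // [value when selected, value when unselected]
};

function applyArchDefaults(
  arch: ModelArchLike,
  isSelected: boolean,
  setJobConfig: (value: unknown, path: string) => void,
): void {
  for (const [path, pair] of Object.entries(arch.defaults ?? {})) {
    const [whenSelected, whenUnselected] = pair;
    // Index 0 applies when the architecture is chosen, index 1 when it is switched away.
    setJobConfig(isSelected ? whenSelected : whenUnselected, path);
  }
}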
src/app/jobs/new/page.tsx ADDED
@@ -0,0 +1,306 @@
+ 'use client';
+
+ import { useEffect, useState } from 'react';
+ import { useSearchParams, useRouter } from 'next/navigation';
+ import Link from 'next/link';
+ import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig';
+ import { JobConfig } from '@/types';
+ import { objectCopy } from '@/utils/basic';
+ import { useNestedState } from '@/utils/hooks';
+ import { SelectInput } from '@/components/formInputs';
+ import useSettings from '@/hooks/useSettings';
+ import useGPUInfo from '@/hooks/useGPUInfo';
+ import useDatasetList from '@/hooks/useDatasetList';
+ import path from 'path';
+ import { TopBar, MainContent } from '@/components/layout';
+ import { Button } from '@headlessui/react';
+ import { FaChevronLeft } from 'react-icons/fa';
+ import SimpleJob from './SimpleJob';
+ import AdvancedJob from './AdvancedJob';
+ import ErrorBoundary from '@/components/ErrorBoundary';
+ import { getJob, upsertJob } from '@/utils/storage/jobStorage';
+ import { usingBrowserDb } from '@/utils/env';
+ import { getUserDatasetPath, updateUserDatasetPath } from '@/utils/storage/datasetStorage';
+ import { apiClient } from '@/utils/api';
+ import { useAuth } from '@/contexts/AuthContext';
+ import HFLoginButton from '@/components/HFLoginButton';
+
+ const isDev = process.env.NODE_ENV === 'development';
+
+ export default function TrainingForm() {
+   const router = useRouter();
+   const searchParams = useSearchParams();
+   const runId = searchParams.get('id');
+   const { status: authStatus } = useAuth();
+   const isAuthenticated = authStatus === 'authenticated';
+   const [gpuIDs, setGpuIDs] = useState<string | null>(null);
+   const { settings, isSettingsLoaded } = useSettings();
+   const { gpuList, isGPUInfoLoaded } = useGPUInfo();
+   const { datasets, status: datasetFetchStatus } = useDatasetList();
+   const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]);
+   const [showAdvancedView, setShowAdvancedView] = useState(false);
+
+   const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
+   const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
+
+   // Track HF Jobs backend state
+   const [trainingBackend, setTrainingBackend] = useState<'local' | 'hf-jobs'>(
+     usingBrowserDb ? 'hf-jobs' : 'local',
+   );
+   const [hfJobSubmitted, setHfJobSubmitted] = useState(false);
+
+   useEffect(() => {
+     if (!isSettingsLoaded || !isAuthenticated) return;
+     if (datasetFetchStatus !== 'success') return;
+
+     let isMounted = true;
+
+     const buildDatasetOptions = async () => {
+       const options = await Promise.all(
+         datasets.map(async name => {
+           let datasetPath = settings.DATASETS_FOLDER ? path.join(settings.DATASETS_FOLDER, name) : '';
+
+           if (usingBrowserDb) {
+             const storedPath = getUserDatasetPath(name);
+             if (storedPath) {
+               datasetPath = storedPath;
+             } else {
+               try {
+                 const response = await apiClient
+                   .post('/api/datasets/create', { name })
+                   .then(res => res.data);
+                 if (response?.path) {
+                   datasetPath = response.path;
+                   updateUserDatasetPath(name, datasetPath);
+                 }
+               } catch (err) {
+                 console.error('Error resolving dataset path:', err);
+               }
+             }
+           }
+
+           if (!datasetPath) {
+             datasetPath = name;
+           }
+
+           return { value: datasetPath, label: name };
+         }),
+       );
+
+       if (!isMounted) {
+         return;
+       }
+
+       setDatasetOptions(options);
+       const defaultDatasetPath = defaultDatasetConfig.folder_path;
+
+       for (let i = 0; i < jobConfig.config.process[0].datasets.length; i++) {
+         const dataset = jobConfig.config.process[0].datasets[i];
+         if (dataset.folder_path === defaultDatasetPath) {
+           if (options.length > 0) {
+             setJobConfig(options[0].value, `config.process[0].datasets[${i}].folder_path`);
+           }
+         }
+       }
+     };
+
+     buildDatasetOptions();
+
+     return () => {
+       isMounted = false;
+     };
+   }, [datasets, settings, isSettingsLoaded, datasetFetchStatus]);
+
+   useEffect(() => {
+     if (runId) {
+       getJob(runId)
+         .then(data => {
+           if (!data) {
+             throw new Error('Job not found');
+           }
+           setGpuIDs(data.gpu_ids);
+           const parsedJobConfig = migrateJobConfig(JSON.parse(data.job_config));
+           setJobConfig(parsedJobConfig);
+
+           if (parsedJobConfig.is_hf_job) {
+             setTrainingBackend('hf-jobs');
+             setHfJobSubmitted(true);
+           }
+         })
+         .catch(error => console.error('Error fetching training:', error));
+     }
+   }, [runId]);
+
+   useEffect(() => {
+     if (isGPUInfoLoaded) {
+       if (gpuIDs === null && gpuList.length > 0) {
+         setGpuIDs(`${gpuList[0].index}`);
+       }
+     }
+   }, [gpuList, isGPUInfoLoaded]);
+
+   useEffect(() => {
+     if (isSettingsLoaded) {
+       setJobConfig(settings.TRAINING_FOLDER, 'config.process[0].training_folder');
+     }
+   }, [settings, isSettingsLoaded]);
+
+   useEffect(() => {
+     if (!isAuthenticated) {
+       setDatasetOptions([]);
+     }
+   }, [isAuthenticated]);
+
+   const saveJob = async () => {
+     if (!isAuthenticated) return;
+     if (status === 'saving') return;
+     setStatus('saving');
+
+     try {
+       const savedJob = await upsertJob({
+         id: runId || undefined,
+         name: jobConfig.config.name,
+         gpu_ids: gpuIDs,
+         job_config: {
+           ...jobConfig,
+           is_hf_job: trainingBackend === 'hf-jobs',
+           hf_job_submitted: hfJobSubmitted,
+           training_backend: trainingBackend,
+         },
+         status: trainingBackend === 'hf-jobs' ? (hfJobSubmitted ? 'submitted' : 'stopped') : undefined,
+       });
+
+       setStatus('success');
+       router.push(`/jobs/${savedJob.id}`);
+     } catch (error: any) {
+       console.log('Error saving training:', error);
+       if (error?.code === 'P2002') {
+         alert('Training name already exists. Please choose a different name.');
+       } else {
+         alert('Failed to save job. Please try again.');
+       }
+     } finally {
+       setTimeout(() => {
+         setStatus('idle');
+       }, 2000);
+     }
+   };
+
+   const handleSubmit = async (e: React.FormEvent) => {
+     e.preventDefault();
+     saveJob();
+   };
+
+   return (
+     <>
+       <TopBar>
+         <div>
+           <Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
+             <FaChevronLeft />
+           </Button>
+         </div>
+         <div>
+           <h1 className="text-lg">{runId ? 'Edit Training Job' : 'New Training Job'}</h1>
+         </div>
+         <div className="flex-1"></div>
+         {showAdvancedView && isAuthenticated && (
+           <>
+             <div>
+               <SelectInput
+                 value={`${gpuIDs}`}
+                 onChange={value => setGpuIDs(value)}
+                 options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
+               />
+             </div>
+             <div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
+           </>
+         )}
+
+         <div className="pr-2">
+           <Button
+             className="text-gray-200 bg-gray-800 px-3 py-1 rounded-md"
+             onClick={() => setShowAdvancedView(!showAdvancedView)}
+           >
+             {showAdvancedView ? 'Show Simple' : 'Show Advanced'}
+           </Button>
+         </div>
+         <div>
+           <Button
+             className="text-gray-200 bg-green-800 hover:bg-green-700 px-3 py-1 rounded-md"
+             onClick={() => saveJob()}
+             disabled={!isAuthenticated || status === 'saving'}
+           >
+             {status === 'saving'
+               ? 'Saving...'
+               : runId
+                 ? 'Update Job'
+                 : 'Create Job'}
+           </Button>
+         </div>
+       </TopBar>
+
+       {!isAuthenticated ? (
+         <MainContent>
+           <div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
+             <p>You need to sign in with Hugging Face or provide a valid access token before creating or editing jobs.</p>
+             <div className="flex items-center gap-3">
+               <HFLoginButton size="sm" />
+               <Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
+                 Manage authentication in Settings
+               </Link>
+             </div>
+           </div>
+         </MainContent>
+       ) : showAdvancedView ? (
+         <div className="pt-[48px] absolute top-0 left-0 w-full h-full overflow-auto">
+           <AdvancedJob
+             jobConfig={jobConfig}
+             setJobConfig={setJobConfig}
+             status={status}
+             handleSubmit={handleSubmit}
+             runId={runId}
+             gpuIDs={gpuIDs}
+             setGpuIDs={setGpuIDs}
+             gpuList={gpuList}
+             datasetOptions={datasetOptions}
+             settings={settings}
+           />
+         </div>
+       ) : (
+         <MainContent>
+           <ErrorBoundary
+             fallback={
+               <div className="flex items-center justify-center h-64 text-lg text-red-600 font-medium bg-red-100 dark:bg-red-900/20 dark:text-red-400 border border-red-300 dark:border-red-700 rounded-lg">
+                 Advanced job detected. Please switch to advanced view to continue.
+               </div>
+             }
+           >
+             <SimpleJob
+               jobConfig={jobConfig}
+               setJobConfig={setJobConfig}
+               status={status}
+               handleSubmit={handleSubmit}
+               runId={runId}
+               gpuIDs={gpuIDs}
+               setGpuIDs={setGpuIDs}
+               gpuList={gpuList}
+               datasetOptions={datasetOptions}
+               trainingBackend={trainingBackend}
+               setTrainingBackend={usingBrowserDb ? undefined : setTrainingBackend}
+               hfJobSubmitted={hfJobSubmitted}
+               onHFJobComplete={(jobId: string, localJobId?: string) => {
+                 setHfJobSubmitted(true);
+                 // Redirect to the job detail page
+                 if (localJobId) {
+                   router.push(`/jobs/${localJobId}`);
+                 }
+               }}
+               forceHFBackend={usingBrowserDb}
+             />
+           </ErrorBoundary>
+
+           <div className="pt-20"></div>
+         </MainContent>
+       )}
+     </>
+   );
+ }
src/app/jobs/page.tsx ADDED
@@ -0,0 +1,49 @@
+ 'use client';
+
+ import JobsTable from '@/components/JobsTable';
+ import { TopBar, MainContent } from '@/components/layout';
+ import Link from 'next/link';
+ import { useAuth } from '@/contexts/AuthContext';
+ import HFLoginButton from '@/components/HFLoginButton';
+
+ export default function Dashboard() {
+   const { status: authStatus } = useAuth();
+   const isAuthenticated = authStatus === 'authenticated';
+
+   return (
+     <>
+       <TopBar>
+         <div>
+           <h1 className="text-lg">Training Jobs</h1>
+         </div>
+         <div className="flex-1"></div>
+         <div>
+           {isAuthenticated ? (
+             <Link href="/jobs/new" className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md">
+               New Training Job
+             </Link>
+           ) : (
+             <span className="text-gray-600 bg-gray-900 px-3 py-1 rounded-md border border-gray-800">
+               Sign in to create jobs
+             </span>
+           )}
+         </div>
+       </TopBar>
+       <MainContent>
+         {isAuthenticated ? (
+           <JobsTable />
+         ) : (
+           <div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
+             <p>Sign in with Hugging Face or add a personal access token to view and manage training jobs.</p>
+             <div className="flex items-center gap-3">
+               <HFLoginButton size="sm" />
+               <Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
+                 Manage tokens in Settings
+               </Link>
+             </div>
+           </div>
+         )}
+       </MainContent>
+     </>
+   );
+ }
src/app/layout.tsx ADDED
@@ -0,0 +1,50 @@
+ import type { Metadata } from 'next';
+ import { Inter } from 'next/font/google';
+ import './globals.css';
+ import Sidebar from '@/components/Sidebar';
+ import { ThemeProvider } from '@/components/ThemeProvider';
+ import ConfirmModal from '@/components/ConfirmModal';
+ import SampleImageModal from '@/components/SampleImageModal';
+ import { Suspense } from 'react';
+ import AuthWrapper from '@/components/AuthWrapper';
+ import DocModal from '@/components/DocModal';
+ import { AuthProvider } from '@/contexts/AuthContext';
+
+ export const dynamic = 'force-dynamic';
+
+ const inter = Inter({ subsets: ['latin'] });
+
+ export const metadata: Metadata = {
+   title: 'Ostris - AI Toolkit',
+   description: 'A toolkit for building AI things.',
+ };
+
+ export default function RootLayout({ children }: { children: React.ReactNode }) {
+   // Check if the AI_TOOLKIT_AUTH environment variable is set
+   const authRequired = process.env.AI_TOOLKIT_AUTH ? true : false;
+
+   return (
+     <html lang="en" className="dark">
+       <head>
+         <meta name="apple-mobile-web-app-title" content="AI-Toolkit" />
+       </head>
+       <body className={inter.className} suppressHydrationWarning={true}>
+         <ThemeProvider>
+           <AuthProvider>
+             <AuthWrapper authRequired={authRequired}>
+               <div className="flex h-screen bg-gray-950">
+                 <Sidebar />
+                 <main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">
+                   <Suspense>{children}</Suspense>
+                 </main>
+               </div>
+             </AuthWrapper>
+           </AuthProvider>
+         </ThemeProvider>
+         <ConfirmModal />
+         <DocModal />
+         <SampleImageModal />
+       </body>
+     </html>
+   );
+ }
src/app/manifest.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "name": "AI Toolkit",
+   "short_name": "AIToolkit",
+   "icons": [
+     {
+       "src": "/web-app-manifest-192x192.png",
+       "sizes": "192x192",
+       "type": "image/png",
+       "purpose": "maskable"
+     },
+     {
+       "src": "/web-app-manifest-512x512.png",
+       "sizes": "512x512",
+       "type": "image/png",
+       "purpose": "maskable"
+     }
+   ],
+   "theme_color": "#000000",
+   "background_color": "#000000",
+   "display": "standalone"
+ }
src/app/page.tsx ADDED
@@ -0,0 +1,5 @@
+ import { redirect } from 'next/navigation';
+
+ export default function Home() {
+   redirect('/dashboard');
+ }
src/app/settings/page.tsx ADDED
@@ -0,0 +1,264 @@
+ 'use client';
+
+ import { useEffect, useState } from 'react';
+ import useSettings from '@/hooks/useSettings';
+ import { TopBar, MainContent } from '@/components/layout';
+ import { persistSettings } from '@/utils/storage/settingsStorage';
+ import { useAuth } from '@/contexts/AuthContext';
+ import HFLoginButton from '@/components/HFLoginButton';
+ import { useMemo } from 'react';
+ import Link from 'next/link';
+
+ export default function Settings() {
+   const { settings, setSettings } = useSettings();
+   const { status: authStatus, namespace, oauthAvailable, loginWithOAuth, logout, setManualToken, error: authError, token: authToken } = useAuth();
+   const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
+   const [manualToken, setManualTokenInput] = useState(settings.HF_TOKEN || '');
+   const isAuthenticated = authStatus === 'authenticated';
+
+   useEffect(() => {
+     setManualTokenInput(settings.HF_TOKEN || '');
+   }, [settings.HF_TOKEN]);
+
+   const handleSubmit = async (e: React.FormEvent) => {
+     e.preventDefault();
+     setStatus('saving');
+
+     persistSettings(settings)
+       .then(() => {
+         setStatus('success');
+       })
+       .catch(error => {
+         console.error('Error saving settings:', error);
+         setStatus('error');
+       })
+       .finally(() => {
+         setTimeout(() => setStatus('idle'), 2000);
+       });
+   };
+
+   const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
+     const { name, value } = e.target;
+     setSettings(prev => ({ ...prev, [name]: value }));
+   };
+
+   const handleManualSubmit = async (e: React.FormEvent) => {
+     e.preventDefault();
+     await setManualToken(manualToken);
+   };
+
+   const authDescription = useMemo(() => {
+     if (authStatus === 'checking') {
+       return 'Checking your Hugging Face session…';
+     }
+     if (isAuthenticated) {
+       return `Connected as ${namespace}`;
+     }
+     return 'Sign in to use Hugging Face Jobs or submit your own access token.';
+   }, [authStatus, isAuthenticated, namespace]);
+
+   return (
+     <>
+       <TopBar>
+         <div>
+           <h1 className="text-lg">Settings</h1>
+         </div>
+         <div className="flex-1"></div>
+         <div className="flex items-center gap-3 pr-2 text-sm text-gray-400">
+           {isAuthenticated ? (
+             <span>Welcome, {namespace || 'user'}</span>
+           ) : (
+             <span>Authenticate to unlock training features</span>
+           )}
+         </div>
+       </TopBar>
+       <MainContent>
+         <div className="grid gap-4 md:grid-cols-2 mb-6">
+           <div className="border border-gray-800 rounded-xl p-5 bg-gray-900">
+             <div className="flex items-center justify-between mb-4">
+               <div>
+                 <h2 className="text-md font-semibold text-gray-100">Sign in with Hugging Face</h2>
+                 <p className="text-sm text-gray-400 mt-1">{authDescription}</p>
+               </div>
+               {isAuthenticated && (
+                 <span className="text-xs px-2 py-1 rounded-full bg-emerald-900 text-emerald-300">Authenticated</span>
+               )}
+             </div>
+             <div className="flex items-center gap-3">
+               {isAuthenticated ? (
+                 <button
+                   type="button"
+                   onClick={logout}
+                   className="px-4 py-2 rounded-md border border-gray-700 text-sm bg-gray-800 hover:bg-gray-700 transition-colors"
+                 >
+                   Sign out
+                 </button>
+               ) : (
+                 <>
+                   <HFLoginButton size="md" className="bg-transparent border-none p-0" />
+                   {!oauthAvailable && (
+                     <span className="text-xs text-yellow-500">
+                       OAuth is unavailable. Set HF_OAUTH_CLIENT_ID/SECRET on the server.
+                     </span>
+                   )}
+                 </>
+               )}
+             </div>
+             {!isAuthenticated && authError && (
+               <p className="mt-3 text-xs text-red-400">{authError}</p>
+             )}
+           </div>
+
+           <form onSubmit={handleManualSubmit} className="border border-gray-800 rounded-xl p-5 bg-gray-900">
+             <h2 className="text-md font-semibold text-gray-100">Manual Token</h2>
+             <p className="text-sm text-gray-400 mt-1">
+               Paste an access token created at{' '}
+               <a href="https://huggingface.co/settings/tokens" target="_blank" rel="noreferrer" className="text-blue-400 hover:text-blue-300">
+                 huggingface.co/settings/tokens
+               </a>
+               .
+             </p>
+             <div className="mt-4">
+               <input
+                 type="password"
+                 value={manualToken}
+                 onChange={event => setManualTokenInput(event.target.value)}
+                 className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
+                 placeholder="Enter Hugging Face token"
+               />
+             </div>
+             <div className="mt-4 flex items-center gap-3">
+               <button
+                 type="submit"
+                 className="px-4 py-2 rounded-md bg-blue-600 hover:bg-blue-500 text-sm text-white transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
+                 disabled={authStatus === 'checking' || manualToken.trim() === ''}
+               >
+                 Validate Token
+               </button>
+               {isAuthenticated && authToken === manualToken && (
+                 <span className="text-xs text-emerald-400">Active token</span>
+               )}
+             </div>
+             {authError && (
+               <p className="mt-3 text-xs text-red-400">{authError}</p>
+             )}
+           </form>
+         </div>
+
+         <form onSubmit={handleSubmit} className="space-y-6">
+           <div className="grid grid-cols-1 gap-6 sm:grid-cols-2">
+             <div>
+               <div className="space-y-4">
+                 <div>
+                   <label htmlFor="TRAINING_FOLDER" className="block text-sm font-medium mb-2">
+                     Training Folder Path
+                     <div className="text-gray-500 text-sm ml-1">
+                       We will store your training information here. Must be an absolute path. If blank, it will default
+                       to the output folder in the project root.
+                     </div>
+                   </label>
+                   <input
+                     type="text"
+                     id="TRAINING_FOLDER"
+                     name="TRAINING_FOLDER"
+                     value={settings.TRAINING_FOLDER}
+                     onChange={handleChange}
+                     className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
+                     placeholder="Enter training folder path"
+                   />
+                 </div>
+
+                 <div>
+                   <label htmlFor="DATASETS_FOLDER" className="block text-sm font-medium mb-2">
+                     Dataset Folder Path
+                     <div className="text-gray-500 text-sm ml-1">
+                       Where we store and find your datasets.{' '}
+                       <span className="text-orange-800">
+                         Warning: This software may modify datasets so it is recommended you keep a backup somewhere else
+                         or have a dedicated folder for this software.
+                       </span>
+                     </div>
+                   </label>
+                   <input
+                     type="text"
+                     id="DATASETS_FOLDER"
+                     name="DATASETS_FOLDER"
+                     value={settings.DATASETS_FOLDER}
+                     onChange={handleChange}
+                     className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
+                     placeholder="Enter datasets folder path"
+                   />
+                 </div>
+               </div>
+             </div>
+             <div>
+               <div className="space-y-4">
+                 <h3 className="text-lg font-medium mb-4">Hugging Face Jobs (Cloud Training)</h3>
+
+                 <div>
+                   <label htmlFor="HF_JOBS_NAMESPACE" className="block text-sm font-medium mb-2">
+                     HF Jobs Namespace (optional)
+                     <div className="text-gray-500 text-sm ml-1">
+                       Leave blank to default to the account associated with your Hugging Face token.
+                     </div>
+                   </label>
+                   <input
+                     type="text"
+                     id="HF_JOBS_NAMESPACE"
+                     name="HF_JOBS_NAMESPACE"
+                     value={settings.HF_JOBS_NAMESPACE}
+                     onChange={handleChange}
+                     className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
+                     placeholder="e.g. your-username or your-org"
+                   />
+                 </div>
+
+                 <div>
+                   <label htmlFor="HF_JOBS_DEFAULT_HARDWARE" className="block text-sm font-medium mb-2">
+                     Default Hardware
+                     <div className="text-gray-500 text-sm ml-1">
+                       Default hardware configuration for cloud training jobs.
+                     </div>
+                   </label>
+                   <select
+                     id="HF_JOBS_DEFAULT_HARDWARE"
+                     name="HF_JOBS_DEFAULT_HARDWARE"
+                     value={settings.HF_JOBS_DEFAULT_HARDWARE}
+                     onChange={(e) => setSettings(prev => ({ ...prev, HF_JOBS_DEFAULT_HARDWARE: e.target.value }))}
+                     className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
+                   >
+                     <option value="cpu-basic">CPU Basic</option>
+                     <option value="cpu-upgrade">CPU Upgrade</option>
+                     <option value="t4-small">T4 Small</option>
+                     <option value="t4-medium">T4 Medium</option>
+                     <option value="l4x1">L4x1</option>
+                     <option value="l4x4">L4x4</option>
+                     <option value="a10g-small">A10G Small</option>
+                     <option value="a10g-large">A10G Large</option>
+                     <option value="a10g-largex2">A10G Large x2</option>
+                     <option value="a10g-largex4">A10G Large x4</option>
+                     <option value="a100-large">A100 Large</option>
+                     <option value="v5e-1x1">TPU v5e-1x1</option>
+                     <option value="v5e-2x2">TPU v5e-2x2</option>
+                     <option value="v5e-2x4">TPU v5e-2x4</option>
+                   </select>
+                 </div>
+               </div>
+             </div>
+           </div>
+
+           <button
+             type="submit"
+             disabled={status === 'saving'}
+             className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
+           >
+             {status === 'saving' ? 'Saving...' : 'Save Settings'}
+           </button>
+
+           {status === 'success' && <p className="text-green-500 text-center">Settings saved successfully!</p>}
+           {status === 'error' && <p className="text-red-500 text-center">Error saving settings. Please try again.</p>}
+         </form>
+       </MainContent>
+     </>
+   );
+ }
src/components/AddImagesModal.tsx ADDED
@@ -0,0 +1,152 @@
+ 'use client';
+ import { createGlobalState } from 'react-global-hooks';
+ import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react';
+ import { FaUpload } from 'react-icons/fa';
+ import { useCallback, useState } from 'react';
+ import { useDropzone } from 'react-dropzone';
+ import { apiClient } from '@/utils/api';
+
+ export interface AddImagesModalState {
+   datasetName: string;
+   onComplete?: () => void;
+ }
+
+ export const addImagesModalState = createGlobalState<AddImagesModalState | null>(null);
+
+ export const openImagesModal = (datasetName: string, onComplete: () => void) => {
+   addImagesModalState.set({ datasetName, onComplete });
+ };
+
+ export default function AddImagesModal() {
+   const [addImagesModalInfo, setAddImagesModalInfo] = addImagesModalState.use();
+   const [uploadProgress, setUploadProgress] = useState<number>(0);
+   const [isUploading, setIsUploading] = useState<boolean>(false);
+   const open = addImagesModalInfo !== null;
+
+   const onCancel = () => {
+     if (!isUploading) {
+       setAddImagesModalInfo(null);
+     }
+   };
+
+   const onDone = () => {
+     if (addImagesModalInfo?.onComplete && !isUploading) {
+       addImagesModalInfo.onComplete();
+       setAddImagesModalInfo(null);
+     }
+   };
+
+   const onDrop = useCallback(
+     async (acceptedFiles: File[]) => {
+       if (acceptedFiles.length === 0) return;
+
+       setIsUploading(true);
+       setUploadProgress(0);
+
+       const formData = new FormData();
+       acceptedFiles.forEach(file => {
+         formData.append('files', file);
+       });
+       formData.append('datasetName', addImagesModalInfo?.datasetName || '');
+
+       try {
+         await apiClient.post(`/api/datasets/upload`, formData, {
+           headers: {
+             'Content-Type': 'multipart/form-data',
+           },
+           onUploadProgress: progressEvent => {
+             const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
+             setUploadProgress(percentCompleted);
+           },
+           timeout: 0, // Disable timeout
+         });
+
+         onDone();
+       } catch (error) {
+         console.error('Upload failed:', error);
+       } finally {
+         setIsUploading(false);
+         setUploadProgress(0);
+       }
+     },
+     [addImagesModalInfo],
+   );
+
+   const { getRootProps, getInputProps, isDragActive } = useDropzone({
+     onDrop,
+     accept: {
+       'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'],
+       'video/*': ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.m4v', '.flv'],
+       'text/*': ['.txt'],
+     },
+     multiple: true,
+   });
+
+   return (
+     <Dialog open={open} onClose={onCancel} className="relative z-10">
+       <DialogBackdrop
+         transition
+         className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
+       />
+
+       <div className="fixed inset-0 z-10 w-screen overflow-y-auto">
+         <div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
+           <DialogPanel
+             transition
+             className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in sm:my-8 sm:w-full sm:max-w-lg data-closed:sm:translate-y-0 data-closed:sm:scale-95"
+           >
+             <div className="bg-gray-800 px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
+               <div className="text-center">
+                 <DialogTitle as="h3" className="text-base font-semibold text-gray-200 mb-4">
+                   Add Images to: {addImagesModalInfo?.datasetName}
+                 </DialogTitle>
+                 <div className="w-full">
+                   <div
+                     {...getRootProps()}
+                     className={`h-40 w-full flex flex-col items-center justify-center border-2 border-dashed rounded-lg cursor-pointer transition-colors duration-200
+                       ${isDragActive ? 'border-blue-500 bg-blue-50/10' : 'border-gray-600'}`}
+                   >
+                     <input {...getInputProps()} />
+                     <FaUpload className="size-8 mb-3 text-gray-400" />
+                     <p className="text-sm text-gray-200 text-center">
+                       {isDragActive ? 'Drop the files here...' : 'Drag & drop files here, or click to select files'}
+                     </p>
+                   </div>
+                   {isUploading && (
+                     <div className="mt-4">
+                       <div className="w-full bg-gray-700 rounded-full h-2.5">
+                         <div className="bg-blue-600 h-2.5 rounded-full" style={{ width: `${uploadProgress}%` }}></div>
+                       </div>
+                       <p className="text-sm text-gray-300 mt-2 text-center">Uploading... {uploadProgress}%</p>
+                     </div>
+                   )}
+                 </div>
+               </div>
+             </div>
+             <div className="bg-gray-700 px-4 py-3 sm:flex sm:flex-row-reverse sm:px-6">
+               <button
+                 type="button"
+                 onClick={onDone}
+                 disabled={isUploading}
+                 className={`inline-flex w-full justify-center rounded-md bg-slate-600 px-3 py-2 text-sm font-semibold text-white shadow-xs sm:ml-3 sm:w-auto
+                   ${isUploading ? 'opacity-50 cursor-not-allowed' : ''}`}
+               >
+                 Done
+               </button>
+               <button
+                 type="button"
+                 data-autofocus
+                 onClick={onCancel}
+                 disabled={isUploading}
+                 className={`mt-3 inline-flex w-full justify-center rounded-md bg-gray-800 px-3 py-2 text-sm font-semibold text-gray-200 hover:bg-gray-800 sm:mt-0 sm:w-auto ring-0
+                   ${isUploading ? 'opacity-50 cursor-not-allowed' : ''}`}
+               >
+                 Cancel
+               </button>
+             </div>
+           </DialogPanel>
+         </div>
+       </div>
+     </Dialog>
+   );
+ }
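AddImagesModal is driven through the exported openImagesModal(datasetName, onComplete) helper rather than local props, so any page can trigger it once the modal is mounted in the layout. A hypothetical caller, assuming a dataset page with a refresh callback, is sketched below; the dataset pages in this commit may wire it differently.

// Sketch only: hypothetical caller of the shared upload modal.
import { openImagesModal } from '@/components/AddImagesModal';

function handleAddImagesClick(datasetName: string, refreshImages: () => void): void {
  // Opens the global modal; refreshImages runs after the upload finishes and "Done" is pressed.
  openImagesModal(datasetName, () => {
    refreshImages();
  });
}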
src/components/AddSingleImageModal.tsx ADDED
@@ -0,0 +1,141 @@
+ 'use client';
+ import { createGlobalState } from 'react-global-hooks';
+ import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react';
+ import { FaUpload } from 'react-icons/fa';
+ import { useCallback, useState } from 'react';
+ import { useDropzone } from 'react-dropzone';
+ import { apiClient } from '@/utils/api';
+
+ export interface AddSingleImageModalState {
+
+   onComplete?: (imagePath: string | null) => void;
+ }
+
+ export const addSingleImageModalState = createGlobalState<AddSingleImageModalState | null>(null);
+
+ export const openAddImageModal = (onComplete: (imagePath: string | null) => void) => {
+   addSingleImageModalState.set({ onComplete });
+ };
+
+ export default function AddSingleImageModal() {
+   const [addSingleImageModalInfo, setAddSingleImageModalInfo] = addSingleImageModalState.use();
+   const [uploadProgress, setUploadProgress] = useState<number>(0);
+   const [isUploading, setIsUploading] = useState<boolean>(false);
+   const open = addSingleImageModalInfo !== null;
+
+   const onCancel = () => {
+     if (!isUploading) {
+       setAddSingleImageModalInfo(null);
+     }
+   };
+
+   const onDone = (imagePath: string | null) => {
+     if (addSingleImageModalInfo?.onComplete && !isUploading) {
+       addSingleImageModalInfo.onComplete(imagePath);
+       setAddSingleImageModalInfo(null);
+     }
+   };
+
+   const onDrop = useCallback(
+     async (acceptedFiles: File[]) => {
+       if (acceptedFiles.length === 0) return;
+
+       setIsUploading(true);
+       setUploadProgress(0);
+
+       const formData = new FormData();
+       acceptedFiles.forEach(file => {
+         formData.append('files', file);
+       });
+
+       try {
+         const resp = await apiClient.post(`/api/img/upload`, formData, {
+           headers: {
+             'Content-Type': 'multipart/form-data',
+           },
+           onUploadProgress: progressEvent => {
+             const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
+             setUploadProgress(percentCompleted);
+           },
+           timeout: 0, // Disable timeout
+         });
+         console.log('Upload successful:', resp.data);
+
+         onDone(resp.data.files[0] || null);
+       } catch (error) {
+         console.error('Upload failed:', error);
+       } finally {
+         setIsUploading(false);
+         setUploadProgress(0);
+       }
+     },
+     [addSingleImageModalInfo],
+   );
+
+   const { getRootProps, getInputProps, isDragActive } = useDropzone({
+     onDrop,
+     accept: {
+       'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'],
+     },
+     multiple: false,
+   });
+
+   return (
+     <Dialog open={open} onClose={onCancel} className="relative z-10">
+       <DialogBackdrop
+         transition
+         className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
+       />
+
+       <div className="fixed inset-0 z-10 w-screen overflow-y-auto">
+         <div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
+           <DialogPanel
+             transition
+             className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in sm:my-8 sm:w-full sm:max-w-lg data-closed:sm:translate-y-0 data-closed:sm:scale-95"
+           >
+             <div className="bg-gray-800 px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
+               <div className="text-center">
+                 <DialogTitle as="h3" className="text-base font-semibold text-gray-200 mb-4">
+                   Add Control Image
+                 </DialogTitle>
+                 <div className="w-full">
+                   <div
+                     {...getRootProps()}
+                     className={`h-40 w-full flex flex-col items-center justify-center border-2 border-dashed rounded-lg cursor-pointer transition-colors duration-200
+                       ${isDragActive ? 'border-blue-500 bg-blue-50/10' : 'border-gray-600'}`}
+                   >
+                     <input {...getInputProps()} />
+                     <FaUpload className="size-8 mb-3 text-gray-400" />
+                     <p className="text-sm text-gray-200 text-center">
+                       {isDragActive ? 'Drop the image here...' : 'Drag & drop an image here, or click to select one'}
+                     </p>
+                   </div>
+                   {isUploading && (
+                     <div className="mt-4">
+                       <div className="w-full bg-gray-700 rounded-full h-2.5">
+                         <div className="bg-blue-600 h-2.5 rounded-full" style={{ width: `${uploadProgress}%` }}></div>
+                       </div>
+                       <p className="text-sm text-gray-300 mt-2 text-center">Uploading... {uploadProgress}%</p>
+                     </div>
+                   )}
+                 </div>
+               </div>
+             </div>
+             <div className="bg-gray-700 px-4 py-3 sm:flex sm:flex-row-reverse sm:px-6">
+               <button
+                 type="button"
+                 data-autofocus
+                 onClick={onCancel}
+                 disabled={isUploading}
+                 className={`mt-3 inline-flex w-full justify-center rounded-md bg-gray-800 px-3 py-2 text-sm font-semibold text-gray-200 hover:bg-gray-800 sm:mt-0 sm:w-auto ring-0
+                   ${isUploading ? 'opacity-50 cursor-not-allowed' : ''}`}
+               >
+                 Cancel
+               </button>
+             </div>
+           </DialogPanel>
+         </div>
+       </div>
+     </Dialog>
+   );
+ }
src/components/AuthWrapper.tsx ADDED
@@ -0,0 +1,166 @@
+ 'use client';
+
+ import { useState, useEffect, useRef } from 'react';
+ import { apiClient, isAuthorizedState } from '@/utils/api';
+ import { createGlobalState } from 'react-global-hooks';
+
+ interface AuthWrapperProps {
+   authRequired: boolean;
+   children: React.ReactNode | React.ReactNode[];
+ }
+
+ export default function AuthWrapper({ authRequired, children }: AuthWrapperProps) {
+   const [token, setToken] = useState('');
+   // start with true, and deauth if needed
+   const [isAuthorizedGlobal, setIsAuthorized] = isAuthorizedState.use();
+   const [isLoading, setIsLoading] = useState(false);
+   const [error, setError] = useState('');
+   const [isBrowser, setIsBrowser] = useState(false);
+   const inputRef = useRef<HTMLInputElement>(null);
+
+   const isAuthorized = authRequired ? isAuthorizedGlobal : true;
+
+   // Set isBrowser to true when component mounts
+   useEffect(() => {
+     setIsBrowser(true);
+     // Get token from localStorage only after component has mounted
+     const storedToken = localStorage.getItem('AI_TOOLKIT_AUTH') || '';
+     setToken(storedToken);
+     checkAuth();
+   }, []);
+
+   // auto focus on input when not authorized
+   useEffect(() => {
+     if (isAuthorized) {
+       return;
+     }
+     setTimeout(() => {
+       if (inputRef.current) {
+         inputRef.current.focus();
+       }
+     }, 100);
+   }, [isAuthorized]);
+
+   const checkAuth = async () => {
+     // always get current stored token here to avoid state race conditions
+     const currentToken = localStorage.getItem('AI_TOOLKIT_AUTH') || '';
+     if (!authRequired || isLoading || currentToken === '') {
+       return;
+     }
+     setIsLoading(true);
+     setError('');
+     try {
+       const response = await apiClient.get('/api/auth');
+       if (response.data.isAuthenticated) {
+         setIsAuthorized(true);
+       } else {
+         setIsAuthorized(false);
+         setError('Invalid token. Please try again.');
+       }
+     } catch (err) {
+       setIsAuthorized(false);
+       console.log(err);
+       setError('Invalid token. Please try again.');
+     }
+     setIsLoading(false);
+   };
+
+   const handleSubmit = async (e: React.FormEvent) => {
+     e.preventDefault();
+     setError('');
+
+     if (!token.trim()) {
+       setError('Please enter your token');
+       return;
+     }
+
+     if (isBrowser) {
+       localStorage.setItem('AI_TOOLKIT_AUTH', token);
+       checkAuth();
+     }
+   };
+
+   if (isAuthorized) {
+     return <>{children}</>;
+   }
+
+   return (
+     <div className="flex min-h-screen bg-gray-900 text-gray-100 absolute top-0 left-0 right-0 bottom-0 scroll-auto">
+       {/* Left side - decorative or brand area */}
+       <div className="hidden lg:flex lg:w-1/2 bg-gray-800 flex-col justify-center items-center p-12">
+         <div className="mb-4">
+           {/* Replace with your own logo */}
+           <div className="flex items-center justify-center">
+             <img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-24 inline" />
+           </div>
+         </div>
+         <h1 className="text-4xl mb-6">AI Toolkit</h1>
+       </div>
+
+       {/* Right side - login form */}
+       <div className="w-full lg:w-1/2 flex flex-col justify-center items-center p-8 sm:p-12">
+         <div className="w-full max-w-md">
+           <div className="lg:hidden flex justify-center mb-4">
+             {/* Mobile logo */}
+             <div className="flex items-center justify-center">
+               <img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-24 inline" />
+             </div>
+           </div>
+
+           <h2 className="text-3xl text-center mb-2 lg:hidden">AI Toolkit</h2>
+
+           <form onSubmit={handleSubmit} className="space-y-6">
+             <div>
+               <label htmlFor="token" className="block text-sm font-medium text-gray-400 mb-2">
+                 Password
+               </label>
+               <input
+                 id="token"
+                 name="token"
+                 type="password"
+                 autoComplete="off"
+                 required
+                 value={token}
+                 ref={inputRef}
+                 onChange={e => setToken(e.target.value)}
+                 className="w-full px-4 py-3 rounded-lg bg-gray-800 border border-gray-700 focus:border-blue-500 focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 text-gray-100 transition duration-200"
+                 placeholder="Enter your password"
+               />
+               <div className='text-gray-500 text-xs mt-2'>
+                 The password is set with the AI_TOOLKIT_AUTH environment variable; the default is the super secure secret word "password".
+               </div>
+             </div>
+
+             {error && (
+               <div className="p-3 bg-red-900/50 border border-red-800 rounded-lg text-red-200 text-sm">{error}</div>
+             )}
+
+             <button
+               type="submit"
+               disabled={isLoading}
+               className="w-full py-3 px-4 bg-blue-600 hover:bg-blue-700 rounded-lg text-white font-medium focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 transition duration-200 flex items-center justify-center"
+             >
+               {isLoading ? (
+                 <svg
+                   className="animate-spin h-5 w-5 text-white"
+                   xmlns="http://www.w3.org/2000/svg"
+                   fill="none"
+                   viewBox="0 0 24 24"
+                 >
+                   <circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
+                   <path
+                     className="opacity-75"
+                     fill="currentColor"
+                     d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
+                   ></path>
+                 </svg>
+               ) : (
+                 'Check Password'
+               )}
+             </button>
+           </form>
+         </div>
+       </div>
+     </div>
+   );
+ }
src/components/Card.tsx ADDED
@@ -0,0 +1,15 @@
+ interface CardProps {
+   title?: string;
+   children?: React.ReactNode;
+ }
+
+ const Card: React.FC<CardProps> = ({ title, children }) => {
+   return (
+     <section className="space-y-2 px-4 pb-4 pt-2 bg-gray-900 rounded-lg">
+       {title && <h2 className="text-lg mb-2 font-semibold uppercase text-gray-500">{title}</h2>}
+       {children ? children : null}
+     </section>
+   );
+ };
+
+ export default Card;
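Card is a small presentational wrapper; a hypothetical usage (sketch only, not taken from any page in this commit) looks like this:

// Sketch: Card renders an optional uppercase title above its children in a rounded panel.
import Card from '@/components/Card';

export function ExampleCard() {
  return (
    <Card title="GPU Info">
      <p>Any children are rendered inside the panel.</p>
    </Card>
  );
}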