File size: 4,659 Bytes
90cbf22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14eb533
90cbf22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import { v } from 'convex/values';
import { query, internalMutation } from './_generated/server';
import Replicate, { WebhookEventType } from 'replicate';
import { httpAction, internalAction } from './_generated/server';
import { internal, api } from './_generated/api';

/**
 * Builds a Replicate API client.
 *
 * The token is read from the `REPLICATE_API_TOKEN` environment variable; when
 * it is unset, an empty string is passed as the auth value instead.
 *
 * @returns a newly constructed `Replicate` client
 */
function client(): Replicate {
  const auth = process.env.REPLICATE_API_TOKEN || '';
  return new Replicate({ auth });
}

/**
 * Reports whether a Replicate API token is configured.
 *
 * @returns `true` when `REPLICATE_API_TOKEN` is set to a non-empty value
 */
function replicateAvailable(): boolean {
  const token = process.env.REPLICATE_API_TOKEN;
  return Boolean(token);
}

/**
 * Internal mutation that records a stored music file in the `music` table.
 *
 * Args:
 * - `storageId`: ID of the stored file, as a string.
 * - `type`: either `'background'` or `'player'`.
 */
export const insertMusic = internalMutation({
  args: {
    storageId: v.string(),
    type: v.union(v.literal('background'), v.literal('player')),
  },
  handler: async (ctx, { storageId, type }) => {
    await ctx.db.insert('music', { storageId, type });
  },
});

/**
 * Query returning the URL of the background music to play.
 *
 * Picks the first `music` row of type `'background'` in descending order and
 * resolves its storage ID to a URL. When no such row exists, the static
 * fallback `/assets/background.mp3` is returned.
 *
 * @throws Error when the row's storage ID does not resolve to a URL
 */
export const getBackgroundMusic = query({
  handler: async (ctx) => {
    const latest = await ctx.db
      .query('music')
      .filter((q) => q.eq(q.field('type'), 'background'))
      .order('desc')
      .first();

    if (latest) {
      const resolved = await ctx.storage.getUrl(latest.storageId);
      if (resolved) {
        return resolved;
      }
      throw new Error(`Invalid storage ID: ${latest.storageId}`);
    }

    // Nothing generated yet: fall back to the bundled track.
    return '/assets/background.mp3';
  },
});

/**
 * Internal action that requests a new background track from Replicate.
 *
 * Does nothing when no Replicate token is configured, or when the
 * `api.world.defaultWorldStatus` query returns a falsy value (in which case a
 * message is logged).
 */
export const enqueueBackgroundMusicGeneration = internalAction({
  handler: async (ctx): Promise<void> => {
    if (replicateAvailable()) {
      const status = await ctx.runQuery(api.world.defaultWorldStatus);
      if (status) {
        // TODO: MusicGen-Large on Replicate only allows 30 seconds. Use MusicGen-Small for longer?
        await generateMusic('16-bit RPG adventure game with wholesome vibe', 30);
      } else {
        console.log('No active default world, returning.');
      }
    }
  },
});

/**
 * HTTP action that handles a Replicate webhook call.
 *
 * The request body is parsed as JSON. When it carries an `id`, the matching
 * prediction is fetched from Replicate, its output is downloaded, stored in
 * Convex file storage and recorded as `'background'` music through
 * `internal.music.insertMusic`.
 *
 * The `completed` webhook event is delivered for every terminal state, not
 * only success, so predictions that did not succeed (or have no output) are
 * logged and skipped instead of being passed to `fetch`.
 *
 * @returns an empty response once the request has been handled or skipped
 * @throws Error when downloading the generated audio returns a non-OK status
 */
export const handleReplicateWebhook = httpAction(async (ctx, request) => {
  const req = await request.json();
  // Guard against a JSON `null` body before reading `id`.
  if (req && req.id) {
    const prediction = await client().predictions.get(req.id);
    if (prediction.status !== 'succeeded' || !prediction.output) {
      console.log(
        `Replicate prediction ${req.id} has status "${prediction.status}" and no usable output, skipping.`,
      );
      return new Response();
    }
    const response = await fetch(prediction.output);
    if (!response.ok) {
      // Do not store an error page as if it were audio.
      throw new Error(
        `Failed to download music for prediction ${req.id}: HTTP ${response.status}`,
      );
    }
    const music = await response.blob();
    const storageId = await ctx.storage.store(music);
    await ctx.runMutation(internal.music.insertMusic, { type: 'background', storageId });
  }
  return new Response();
});

/**
 * Audio normalization strategies accepted by `generateMusic`, which forwards
 * the chosen value as the `normalization_strategy` prediction input.
 */
enum MusicGenNormStrategy {
  Clip = 'clip',
  Loudness = 'loudness',
  Peak = 'peak',
  Rms = 'rms',
}

/**
 * Output formats for generated audio accepted by `generateMusic`, which
 * forwards the chosen value as the `output_format` prediction input.
 */
enum MusicGenFormat {
  wav = 'wav',
  mp3 = 'mp3',
}

/**
 * Creates a MusicGen prediction on Replicate.
 *
 * @param prompt A description of the music you want to generate.
 * @param duration Duration of the generated audio in seconds.
 * @param webhook Webhook URL for Replicate to call when an event in
 *   `webhook_events_filter` is triggered. Defaults to
 *   `${CONVEX_SITE_URL}/replicate_webhook`, or to an empty string when
 *   `CONVEX_SITE_URL` is unset. When empty, no webhook is sent with the request.
 * @param webhook_events_filter Array of event names to filter the webhook. See https://replicate.com/docs/reference/http#predictions.create--webhook_events_filter
 * @param normalization_strategy Strategy for normalizing audio.
 * @param output_format Output format for generated audio. See `MusicGenFormat`.
 * @param top_k Reduces sampling to the k most likely tokens.
 * @param top_p Reduces sampling to tokens with cumulative probability of p. When set to `0` (default), top_k sampling is used.
 * @param temperature Controls the 'conservativeness' of the sampling process. Higher temperature means more diversity.
 * @param classifer_free_gudance Increases the influence of inputs on the output. Higher values produce lower-variance outputs that adhere more closely to inputs.
 *   The parameter name keeps its historical spelling; the value is sent to
 *   the model under the input key `classifier_free_guidance`.
 * @param seed Seed for random number generator. If -1, a random seed will be used.
 * @param model_version MusicGen model variant to run, sent as the `model_version` input.
 * @returns object containing metadata of the prediction with ID to fetch once result is completed
 * @throws Error when the Replicate API token is not set
 */
export async function generateMusic(
  prompt: string,
  duration: number,
  // `+` binds tighter than `||`, so the fallback must be decided before concatenating.
  webhook: string = process.env.CONVEX_SITE_URL
    ? process.env.CONVEX_SITE_URL + '/replicate_webhook'
    : '',
  webhook_events_filter: WebhookEventType[] = ['completed'],
  normalization_strategy: MusicGenNormStrategy = MusicGenNormStrategy.Peak,
  output_format: MusicGenFormat = MusicGenFormat.mp3,
  top_k = 250,
  top_p = 0,
  temperature = 1,
  classifer_free_gudance = 3,
  seed = -1,
  model_version = 'large',
) {
  if (!replicateAvailable()) {
    throw new Error('Replicate API token not set');
  }
  // Only attach webhook settings when there is a URL to call.
  const webhookOptions = webhook ? { webhook, webhook_events_filter } : {};
  return await client().predictions.create({
    // https://replicate.com/facebookresearch/musicgen/versions/7a76a8258b23fae65c5a22debb8841d1d7e816b75c2f24218cd2bd8573787906
    version: '7a76a8258b23fae65c5a22debb8841d1d7e816b75c2f24218cd2bd8573787906',
    input: {
      model_version,
      prompt,
      duration,
      normalization_strategy,
      top_k,
      top_p,
      temperature,
      // The model's input key uses the correct spelling.
      classifier_free_guidance: classifer_free_gudance,
      output_format,
      seed,
    },
    ...webhookOptions,
  });
}