Spaces:
Runtime error
Runtime error
Commit
·
25e0f62
1
Parent(s):
c654b20
add search type radio
Browse files
app.py
CHANGED
|
@@ -33,7 +33,7 @@ model.load_state_dict(torch.load(hf_hub_download(repo_id="nickgardner/chatbot",
|
|
| 33 |
filename="alpaca_train_400_epoch.pt"), map_location=device))
|
| 34 |
model.eval()
|
| 35 |
|
| 36 |
-
def respond(input):
|
| 37 |
model.eval()
|
| 38 |
src = torch.tensor(text_pipeline(input), dtype=torch.int64).unsqueeze(0).to(device)
|
| 39 |
src_mask = ((src != pad_token) & (src != unknown_token)).unsqueeze(-2).to(device)
|
|
@@ -46,18 +46,23 @@ def respond(input):
|
|
| 46 |
trg_mask = torch.autograd.Variable(torch.from_numpy(trg_mask) == 0).to(device)
|
| 47 |
|
| 48 |
out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
| 60 |
return ' '.join([indices_to_tokens[ix] for ix in outputs[1:i]])
|
| 61 |
|
| 62 |
-
iface = gr.Interface(fn=respond,
|
|
|
|
|
|
|
| 63 |
iface.launch()
|
|
|
|
| 33 |
filename="alpaca_train_400_epoch.pt"), map_location=device))
|
| 34 |
model.eval()
|
| 35 |
|
| 36 |
+
def respond(search_type, input):
|
| 37 |
model.eval()
|
| 38 |
src = torch.tensor(text_pipeline(input), dtype=torch.int64).unsqueeze(0).to(device)
|
| 39 |
src_mask = ((src != pad_token) & (src != unknown_token)).unsqueeze(-2).to(device)
|
|
|
|
| 46 |
trg_mask = torch.autograd.Variable(torch.from_numpy(trg_mask) == 0).to(device)
|
| 47 |
|
| 48 |
out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
|
| 49 |
+
if search_type == "Greedy":
|
| 50 |
+
out = torch.nn.functional.softmax(out, dim=-1)
|
| 51 |
+
val, ix = out[:, -1].data.topk(1)
|
| 52 |
+
|
| 53 |
+
outputs[i] = ix[0][0]
|
| 54 |
+
if ix[0][0] == vocab_token_dict['<eos>']:
|
| 55 |
+
break
|
| 56 |
+
else:
|
| 57 |
+
out = torch.nn.functional.softmax(out, dim=-1)[:, -1].squeeze().detach().numpy()
|
| 58 |
+
ix = np.random.choice(np.arange(len(out)), 1, p=out)
|
| 59 |
+
|
| 60 |
+
outputs[i] = ix[0]
|
| 61 |
+
if ix[0] == vocab_token_dict['<eos>']:
|
| 62 |
+
break
|
| 63 |
return ' '.join([indices_to_tokens[ix] for ix in outputs[1:i]])
|
| 64 |
|
| 65 |
+
iface = gr.Interface(fn=respond,
|
| 66 |
+
inputs=[gr.Radio(["Greedy", "Probabilistic"], label="Search Type"), "text"],
|
| 67 |
+
outputs="text")
|
| 68 |
iface.launch()
|