-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtokenizer.py
More file actions
executable file
·135 lines (113 loc) · 6.44 KB
/
Copy pathtokenizer.py
File metadata and controls
executable file
·135 lines (113 loc) · 6.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import logging
import pathlib
import numpy as np
import sentencepiece
from transformers import AutoProcessor
import openpi.shared.download as download
class PaligemmaTokenizer:
def __init__(self, max_len: int = 48):
self._max_len = max_len
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
with path.open("rb") as f:
self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
def tokenize(self, prompt: str, state: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
cleaned_text = prompt.strip().replace("_", " ").replace("\n", " ")
if state is not None:
# This is the Pi05 format, where the state is part of the discrete language input.
discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
state_str = " ".join(map(str, discretized_state))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
tokens = self._tokenizer.encode(full_prompt, add_bos=True)
else:
# This is the Pi0 format, where the state is part of the continuous action expert input.
# tokenize "\n" separately as the "start of answer" token
tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode("\n")
tokens_len = len(tokens)
if tokens_len < self._max_len:
padding = [False] * (self._max_len - tokens_len)
mask = [True] * tokens_len + padding
tokens = tokens + padding
else:
if len(tokens) > self._max_len:
logging.warning(
f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
"Consider increasing the `max_token_len` in your model config if this happens frequently."
)
tokens = tokens[: self._max_len]
mask = [True] * self._max_len
return np.asarray(tokens), np.asarray(mask)
class FASTTokenizer:
def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "physical-intelligence/fast"):
self._max_len = max_len
# Download base PaliGemma tokenizer
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
with path.open("rb") as f:
self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
# Instantiate FAST tokenizer
self._fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def tokenize(
self, prompt: str, state: np.ndarray, actions: np.ndarray | None
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
cleaned_text = prompt.lower().strip().replace("_", " ")
# Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Convention: prefix includes prompt and string-representation of state, followed by ';'
state_str = " ".join(map(str, discretized_state))
prefix = f"Task: {cleaned_text}, State: {state_str};\n"
prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
if actions is not None:
# Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab
action_tokens = self._fast_tokenizer(actions[None])[0]
action_tokens_in_pg = self._act_tokens_to_paligemma_tokens(action_tokens)
# Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|'
postfix_tokens = (
self._paligemma_tokenizer.encode("Action: ")
+ action_tokens_in_pg.tolist()
+ self._paligemma_tokenizer.encode("|", add_eos=True)
)
else:
postfix_tokens = []
# Create output token sequence & masks
# AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
tokens = prefix_tokens + postfix_tokens
token_mask = [True] * len(tokens)
ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
# Pad tokens to max length
tokens_len = len(tokens)
if tokens_len < self._max_len:
padding = [False] * (self._max_len - tokens_len)
tokens = tokens + padding
token_mask = token_mask + padding
ar_mask = ar_mask + padding
loss_mask = loss_mask + padding
else:
if len(tokens) > self._max_len:
logging.warning(
f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
"Consider increasing the `max_token_len` in your model config if this happens frequently."
)
tokens = tokens[: self._max_len]
token_mask = token_mask[: self._max_len]
ar_mask = ar_mask[: self._max_len]
loss_mask = loss_mask[: self._max_len]
return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
# Decode predicted output tokens
decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
# Extract actions from FAST model outputs
if "Action: " not in decoded_tokens:
return np.zeros((action_horizon, action_dim), dtype=np.float32)
# Extract actions from decoded tokens
raw_action_tokens = np.array(
self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
)
action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
return self._fast_tokenizer.decode(
[action_tokens.tolist()], time_horizon=action_horizon, action_dim=action_dim
)[0]
def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
if isinstance(tokens, list):
tokens = np.array(tokens)
return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens