diff --git a/huggingface_dpo/function.yaml b/huggingface_dpo/function.yaml new file mode 100644 index 00000000..c3593fa6 --- /dev/null +++ b/huggingface_dpo/function.yaml @@ -0,0 +1,371 @@ +kind: job +metadata: + name: huggingface-dpo-trainer + tag: '' + hash: 584b20584f58bfa89225b6999e6b55ad017dd87a + project: '' + labels: + author: pgw + categories: + - machine-learning + - model-training +spec: + command: '' + args: [] + image: mlrun/mlrun + build: + functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
import shutil
import tempfile
import zipfile
from abc import ABC
from typing import Dict, List, Tuple, Union

import mlrun
import numpy as np
import pandas as pd
import peft
import torch
import transformers
from datasets import Dataset, load_dataset
from mlrun.artifacts.manager import Artifact, PlotlyArtifact
from mlrun.datastore import is_store_uri
from mlrun.frameworks._common import CommonTypes, MLRunInterface
from mlrun.utils import logger
from trl import DPOTrainer
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from plotly import graph_objects as go
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    PreTrainedModel,
    PreTrainedTokenizer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)


class ConfigKeys:
    deepspeed = "deepspeed"
    quantization = "quantization"
    training = "training"
    tokenizer_pretrained = "tokenizer_pretrained"
    model_pretrained = "model_pretrained"
    peft_config = "peft"
    data_collator = "data_collator"
    beta = "beta"


# ----------------------from MLRUN--------------------------------
class HFTrainerMLRunInterface(MLRunInterface, ABC):
    """
    This is temporary and will be built in mlrun 1.5.0
    Interface for adding MLRun features to the Hugging Face Trainer API.
    """

    # MLRuns context default name:
    DEFAULT_CONTEXT_NAME = "mlrun-huggingface"

    # Attributes to replace so the MLRun interface will be fully enabled.
    _REPLACED_METHODS = [
        "train",
        # "evaluate"
    ]

    @classmethod
    def add_interface(
        cls,
        obj: DPOTrainer,
        restoration: CommonTypes.MLRunInterfaceRestorationType = None,
    ):
        super(HFTrainerMLRunInterface, cls).add_interface(
            obj=obj, restoration=restoration
        )

    @classmethod
    def mlrun_train(cls):
        def wrapper(self: DPOTrainer, *args, **kwargs):
            # Restore the evaluation method as `train` will use it:
            # cls._restore_attribute(obj=self, attribute_name="evaluate")

            # Call the original train method:
            result = self.original_train(*args, **kwargs)

            # Replace the evaluation method again:
            # cls._replace_function(obj=self, function_name="evaluate")

            return result

        return wrapper


class MLRunCallback(TrainerCallback):
    """
    This is temporary and will be built in mlrun 1.5.0
    Callback for collecting logs during training / evaluation of the `Trainer` API.
    """

    def __init__(
        self,
        context: mlrun.MLClientCtx = None,
        model_name: str = "model",
        tag: str = "",
        labels: Dict[str, str] = None,
        extra_data: dict = None,
    ):
        super().__init__()

        # Store the configurations:
        self._context = (
            context
            if context is not None
            else mlrun.get_or_create_ctx("./mlrun-huggingface")
        )
        self._model_name = model_name
        self._tag = tag
        self._labels = labels
        self._extra_data = extra_data if extra_data is not None else {}

        # Set up the logging mode:
        self._is_training = False
        self._steps: List[List[int]] = []
        self._metric_scores: Dict[str, List[float]] = {}
        self._artifacts: Dict[str, Artifact] = {}

    def on_epoch_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if not state.is_world_process_zero:
            return
        self._steps.append([])

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if not state.is_world_process_zero:
            return
        self.log_metrics()

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Dict[str, float] = None,
        **kwargs,
    ):
        if not state.is_world_process_zero:
            return
        recent_logs = state.log_history[-1].copy()

        recent_logs.pop("epoch")
        current_step = int(recent_logs.pop("step"))
        if current_step not in self._steps[-1]:
            self._steps[-1].append(current_step)

        for metric_name, metric_score in recent_logs.items():
            if metric_name.startswith("train_"):
                if metric_name.split("train_")[1] not in self._metric_scores:
                    self._metric_scores[metric_name] = [metric_score]
                continue
            if metric_name not in self._metric_scores:
                self._metric_scores[metric_name] = []
            self._metric_scores[metric_name].append(metric_score)

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if not state.is_world_process_zero:
            return
        self._is_training = True

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: PreTrainedModel = None,
        tokenizer: PreTrainedTokenizer = None,
        **kwargs,
    ):
        if not state.is_world_process_zero:
            return
        self.log_metrics()

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if not state.is_world_process_zero:
            return
        self.log_metrics()

        if self._is_training:
            return

    def log_metrics(self):
        for metric_name, metric_scores in self._metric_scores.items():
            self._context.log_result(key=metric_name, value=metric_scores[-1])
            if len(metric_scores) > 1:
                self.log_metric_plot(name=metric_name, scores=metric_scores)
        self._context.commit(completed=False)

    def log_metric_plot(self, name: str, scores: List[float]):
        # Initialize a plotly figure:
        metric_figure = go.Figure()

        # Add titles:
        metric_figure.update_layout(
            title=name.capitalize().replace("_", " "),
            xaxis_title="Samples",
            yaxis_title="Scores",
        )

        # Draw:
        metric_figure.add_trace(
            go.Scatter(x=np.arange(len(scores)), y=scores, mode="lines")
        )

        # Create the plotly artifact:
        if "/" in name:
            name = "_".join(name.split("/"))
        artifact_name = f"{name}_plot"
        artifact = PlotlyArtifact(key=artifact_name, figure=metric_figure)
        self._artifacts[artifact_name] = self._context.log_artifact(artifact)


def apply_mlrun(
    trainer: DPOTrainer,
    model_name: str = None,
    tag: str = "",
    context: mlrun.MLClientCtx = None,
    auto_log: bool = True,
    labels: Dict[str, str] = None,
    extra_data: dict = None,
    **kwargs,
):
    """
    This is temporary and will be built in mlrun 1.5.0
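    It adds the MLRun interface to the given ``DPOTrainer`` and, when ``auto_log`` is True,
    attaches an ``MLRunCallback`` that logs metrics and artifacts to the MLRun context.

    Example (illustrative; assumes an existing ``DPOTrainer`` instance named ``trainer``)::

        apply_mlrun(trainer, model_name="my-model")
        trainer.train()  # training metrics are logged to the MLRun context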
    """
    # Get parameters defaults:
    if context is None:
        context = mlrun.get_or_create_ctx(HFTrainerMLRunInterface.DEFAULT_CONTEXT_NAME)

    HFTrainerMLRunInterface.add_interface(obj=trainer)

    if auto_log:
        trainer.add_callback(
            MLRunCallback(
                context=context,
                model_name=model_name,
                tag=tag,
                labels=labels,
                extra_data=extra_data,
            )
        )


# ----------------------end from MLRUN--------------------------------


def _print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%:"
        f" {100 * trainable_params / all_param}"
    )


# default configs
# used if the user passes True for the corresponding config argument
QUANTIZATION_CONFIG = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

PEFT_CONFIG = peft.LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

DEEPSPEED_CONFIG = {
    "train_micro_batch_size_per_gpu": "auto",
    "fp16": {"enabled": True},
    "autotuning": {
        "enabled": True,
        "arg_mappings": {
            "train_micro_batch_size_per_gpu": "--per_device_train_batch_size",
            "gradient_accumulation_steps ": "--gradient_accumulation_steps",
        },
    },
    "zero_optimization": {
        "stage": 2,
    },
}


def _update_config(src: dict, dst: dict):
    """
    Update the given configs according to the user's kwargs, so the user can add or modify values
    in the default configs.

    Goes over all configs and their corresponding prefixes, collects all keys from the given dict
    that start with the prefix, and adds them to the appropriate config.

    :param src: dict of all candidate values to update dict.
    :param dst: dict containing all configs to update.
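
    Example (illustrative key names; passing ``True`` selects the matching default config)::

        configs = {"quantization": True, "training": {}}
        _update_config(src={"training_learning_rate": 5e-7}, dst=configs)
        # configs["training"] == {"learning_rate": 5e-7}
        # configs["quantization"] is now the default QUANTIZATION_CONFIG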
    """

    for config_name, config in dst.items():

        # If given True we use the default config
        # Can also be False or a config dict given by the user, so we check specifically for True
        if config is True and config_name == "quantization":
            config = QUANTIZATION_CONFIG

        if config is True and config_name == "peft":
            config = PEFT_CONFIG

        if config is True and config_name == "deepspeed":
            config = DEEPSPEED_CONFIG

        # in some cases we get a boolean value; in that case there is no need to look for args
        if isinstance(config, bool):
            config = None

        elif isinstance(config, dict):
            for key, val in src.items():
                if key.startswith(config_name):
                    config[key.replace(f"{config_name}_", "")] = val

        # update by config name
        else:
            for key, val in src.items():
                if key.startswith(config_name):
                    setattr(config, key.replace(f"{config_name}_", ""), val)

        dst.update({config_name: config})


def _get_class_object(class_path: str) -> type:
    """
    Given a full class path, return the corresponding class object.

    :param class_path: a full class name (ex. 'transformers.AutoModelForCausalLM')

    :returns: the class object
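
    Example::

        _get_class_object("transformers.AutoTokenizer")  # -> transformers.AutoTokenizer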
    """
    module_path, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def _set_model_and_tokenizer(
    model: Union[str, List[str]],
    tokenizer: Union[str, List[str]],
    task: str,
    framework: str,
    quantization_config: dict,
    use_cuda: bool,
    tokenizer_pretrained_config,
    model_pretrained_config,
    device_map: str,
):
    """
    get the correct model and tokenizer according to given user inputs

    :param model: a list of [model name, model class path], or a str with the model name or path
    :param tokenizer: a list of [tokenizer name, tokenizer class path], or a str with the tokenizer name or path
    :param task: a supported nlp task, used to choose model if not provided
    :param framework: pt or tf
    :param quantization_config: quantization config or None, to load model in appropriate way
    :param use_cuda: use gpu or not
    :param tokenizer_pretrained_config: config to load the pretrained tokenizer
    :param model_pretrained_config: config to load the pretrained model
    :param device_map: a device map for model training when using multiple GPUs

    :returns: model name, model, and tokenizer
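
    Example (illustrative; the model and tokenizer names are placeholders)::

        model_name, model, tokenizer = _set_model_and_tokenizer(
            model=["mistralai/Mistral-7B-Instruct-v0.2", "transformers.AutoModelForCausalLM"],
            tokenizer="mistralai/Mistral-7B-Instruct-v0.2",
            task="text-generation",
            framework="pt",
            quantization_config=None,
            use_cuda=True,
            tokenizer_pretrained_config={},
            model_pretrained_config={},
            device_map="auto",
        )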
    """
    # load model from store
    if isinstance(model, str) and is_store_uri(model):
        pass
        # TODO: load both model and tokenizer and return, need guy's help

    # if it's a list then we assume it contains both the name and the class
    if isinstance(model, list):
        model_name, model_class = model
        model_class = _get_class_object(model_class)

    # if the model class is not given, we need the task in order to choose the correct model
    else:
        if task is None:
            logger.error("task must be chosen in order to determine the correct model")
            raise Exception(
                "this function requires either a supported task or a model and model class to be chosen"
            )

        _, available_classes, task_options = transformers.pipelines.check_task(task)

        if isinstance(model, str):
            model_name = model

        # if model is not given, we take the default model for the given task
        else:
            model_name, _ = transformers.pipelines.get_default_model_and_revision(
                available_classes, framework, task_options
            )
        if not available_classes.get(framework, tuple()):
            logger.error(
                "given task's default model is not supported in specified framework"
            )
            raise Exception(
                "this function requires either a supported task or a model and model class to be chosen"
            )

        model_class = available_classes[framework][0]

    # load the pretrained model
    if use_cuda:
        device_map = device_map
    else:
        device_map = None

    model = model_class.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map=device_map,
        **model_pretrained_config,
    )

    # If a quantization config is given we load a quantized model, otherwise a regular one
    if quantization_config:
        model.gradient_checkpointing_enable()
        model = peft.prepare_model_for_kbit_training(model)

    # if not specified we choose the default tokenizer corresponding to the model
    if tokenizer is None:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        return model_name, model, tokenizer

    if isinstance(tokenizer, str):
        tokenizer_name = tokenizer
        tokenizer_class = transformers.AutoTokenizer

    # if it's not a str then it's a list of both name and class
    else:
        tokenizer_name, tokenizer_class = tokenizer
        tokenizer_class = _get_class_object(tokenizer_class)

    tokenizer = tokenizer_class.from_pretrained(
        tokenizer_name, **tokenizer_pretrained_config
    )

    tokenizer.pad_token = tokenizer.eos_token

    return model_name, model, tokenizer


def _dataset_loader(dataset: str, is_train: bool = True, **kwargs) -> Dataset:
    """
    loads the specific dataset provided by the user

    :param dataset: name or path of dataset to load
    :param is_train: bool that indicates the purpose of the dataset
    :param kwargs: other kwargs for loading the dataset

    :returns: loaded dataset
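
    Example (illustrative dataset name)::

        # the user controls the split explicitly via kwargs
        ds = _dataset_loader("my-org/my-preference-dataset", split="train[:100]")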
    """
    # if split in kwargs then the user decides how to split the dataset
    if "split" in kwargs:
        return load_dataset(dataset, **kwargs)

    # if it's a training dataset, we load the train split
    if is_train:
        return load_dataset(dataset, split="train", **kwargs)

    # for an eval dataset, several split names are acceptable, so we check each of them
    dataset = load_dataset(dataset, **kwargs)
    if "test" in dataset:
        return dataset.get("test")
    elif "eval" in dataset:
        return dataset.get("eval")
    elif "validation" in dataset:
        return dataset.get("validation")
    return dataset


def _prepare_dataset(
    train_dataset: str,
    eval_dataset: str,
    train_load_dataset_kwargs,
    eval_load_dataset_kwargs,
) -> (Dataset, Union[Dataset, None]):
    """
    Loads the train and eval datasets (if provided) and returns them ready to use in training.

    :param train_dataset: the name or path to the train dataset
    :param eval_dataset: the name or path to the eval dataset
    :param train_load_dataset_kwargs: kwargs for dataset loading
    :param eval_load_dataset_kwargs: kwargs for dataset loading

    :returns: the loaded train and eval datasets
    """

    # Load datasets
    # if two paths/names are provided, we load each separately using the designated function
    if eval_dataset:
        train_dataset = _dataset_loader(
            dataset=train_dataset, is_train=True, **train_load_dataset_kwargs
        )
        eval_dataset = _dataset_loader(
            dataset=eval_dataset, is_train=False, **eval_load_dataset_kwargs
        )
    # if only one path is given, we must check whether it contains both datasets or only one should be used
    else:
        dataset = load_dataset(train_dataset, **train_load_dataset_kwargs)
        if "train" in dataset:
            train_dataset = dataset.get("train")
            if "test" in dataset:
                eval_dataset = dataset.get("test")
            elif "eval" in dataset:
                eval_dataset = dataset.get("eval")
            elif "validation" in dataset:
                eval_dataset = dataset.get("validation")
            else:
                return train_dataset
        else:
            logger.error("train dataset is mandatory")
            raise KeyError("no train dataset found in given dataset")

    return train_dataset, eval_dataset


def dpo_train(
    context: mlrun.MLClientCtx,
    train_dataset: Union[str, mlrun.datastore.DataItem],
    eval_dataset: str = None,
    train_load_dataset_kwargs: dict = {},
    eval_load_dataset_kwargs: dict = {},
    model: Union[str, List[str]] = "huggingface-model",
    tokenizer: Union[str, List[str]] = None,
    deepspeed_config: Union[dict, bool] = False,
    quantization_config: Union[dict, bool] = False,
    peft_config: Union[dict, bool] = False,
    beta: Union[float, bool] = False,
    training_config: dict = {},
    model_pretrained_config: dict = {},
    tokenizer_pretrained_config: dict = {},
    data_collator_config: dict = {},
    task: str = "text-generation",
    use_cuda: bool = True,
    framework: str = "pt",
    device_map: str = "auto",
    **kwargs,
):
    """
    Form a DPO (Direct Preference Optimization) training job for LLM alignment.
    The function takes various configuration parameters to customize the training process
    and adapt the model to specific tasks using a provided dataset.

    :param context: mlrun context in order to log trained model
    :param train_dataset: The train dataset used for fine-tuning the language model.
    :param eval_dataset: The eval dataset used to evaluate the language model during training.
    :param train_load_dataset_kwargs: kwargs for dataset loading
    :param eval_load_dataset_kwargs: kwargs for dataset loading
    :param model: a list of [model name, model class path], or a str with the model name or path
    :param tokenizer: a list of [tokenizer name, tokenizer class path], or a str with the tokenizer name or path
    :param deepspeed_config: Configuration options for DeepSpeed (optional).
    :param quantization_config: Configuration options for model quantization (optional).
    :param peft_config: Configuration options for Low-Rank Adaptation (LoRA) (optional).
    :param beta: the DPO beta hyperparameter, controlling the strength of the KL penalty against the reference model
    :param training_config: Configuration options specific to the fine-tuning training process (optional).
    :param model_pretrained_config: config to load the pretrained model
    :param tokenizer_pretrained_config: config to load the pretrained tokenizer
    :param data_collator_config: Configuration options for data collation during training (optional).
    :param task: A description of the specific task the model is being fine-tuned for.
    :param use_cuda: use gpu or not
    :param framework: pt or tf
    :param device_map: a device map for model training when using multiple GPUs
    :param kwargs: Additional keyword arguments.
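
    Example (illustrative values, mirroring the accompanying notebook)::

        dpo_train(
            context=context,
            train_dataset="reciprocate/ultrafeedback_cleaned_high_dpo",
            eval_dataset="reciprocate/ultrafeedback_cleaned_high_dpo",
            model="mistralai/Mistral-7B-Instruct-v0.2",
            tokenizer="mistralai/Mistral-7B-Instruct-v0.2",
            peft_config=True,  # use the default LoRA config
            beta=0.1,
            training_config={"per_device_train_batch_size": 1, "max_steps": 1},
            use_cuda=True,
        )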
    """

    # Look for updates to configs given in kwargs
    configs = {
        ConfigKeys.deepspeed: deepspeed_config,
        ConfigKeys.quantization: quantization_config,
        ConfigKeys.training: training_config,
        ConfigKeys.model_pretrained: model_pretrained_config,
        ConfigKeys.tokenizer_pretrained: tokenizer_pretrained_config,
        ConfigKeys.data_collator: data_collator_config,
        ConfigKeys.peft_config: peft_config,
        ConfigKeys.beta: beta,
    }
    _update_config(dst=configs, src=kwargs)

    # check GPU availability
    if use_cuda:
        if torch.cuda.is_available():
            # Clean gpu cache
            torch.cuda.empty_cache()
        else:
            logger.warning("'use_cuda' is set to True, but no cuda device is available")

    # get model and tokenizer
    model_name, model, tokenizer = _set_model_and_tokenizer(
        model=model,
        tokenizer=tokenizer,
        framework=framework,
        task=task,
        quantization_config=configs[ConfigKeys.quantization],
        use_cuda=use_cuda,
        tokenizer_pretrained_config=tokenizer_pretrained_config,
        model_pretrained_config=configs[ConfigKeys.model_pretrained],
        device_map=device_map,
    )
    train_dataset, eval_dataset = _prepare_dataset(
        train_dataset, eval_dataset, train_load_dataset_kwargs, eval_load_dataset_kwargs
    )

    # Initialize training kwargs from user kwargs:
    train_kwargs = configs[ConfigKeys.training]

    # If deepspeed config given we add it to training kwargs
    if configs[ConfigKeys.deepspeed]:
        train_kwargs["deepspeed"] = configs[ConfigKeys.deepspeed]

    # Print the trainable parameters of the model
    _print_trainable_parameters(model)

    # Preparing training arguments:
    training_args = transformers.TrainingArguments(
        output_dir=tempfile.mkdtemp(),
        **train_kwargs,
    )

    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=configs[ConfigKeys.peft_config],
        beta=configs[ConfigKeys.beta],
        tokenizer=tokenizer,
        args=training_args,
        max_length=2048,
        max_prompt_length=4096,
    )

    apply_mlrun(trainer, model_name=model_name.split("/")[-1])
    model.config.use_cache = (
        False  # silence the warnings. Please re-enable for inference!
    )

    # Apply training with evaluation:
    context.logger.info(f"training '{model_name}'")
    trainer.train()

    temp_directory = tempfile.TemporaryDirectory().name
    trainer.save_model(temp_directory)

    # Zip the model directory:
    shutil.make_archive(
        base_name="model",
        format="zip",
        root_dir=temp_directory,
    )

    # Log the model:
    context.log_model(
        key="model",
        db_key=model_name.split("/")[-1],
        model_file="model.zip",
        tag="",
        framework="Hugging Face",
    )


def evaluate(
    context,
    model_path,
    data: pd.DataFrame,
    model_name: str = None,
    tokenizer_name: str = None,
):
    """
    Evaluate the model using perplexity. For more information visit:
    https://huggingface.co/docs/transformers/perplexity

    :param context:     mlrun context
    :param model_path:  path to the model directory
    :param data:        the data to evaluate the model
    :param model_name:  name of base model
    :param tokenizer_name: name of base tokenizer
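
    Perplexity is computed with a strided sliding window over the tokenized text:
    the negative log-likelihood of each window is collected and the final score is
    ``exp(mean(NLL))``, following the linked guide.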
    """
    # Get the model artifact and file:
    (
        model_file,
        model_artifact,
        extra_data,
    ) = mlrun.artifacts.get_model(model_path)

    # Read the name:
    _model_name = model_artifact.spec.db_key

    # Extract logged model files:
    model_directory = os.path.join(os.path.dirname(model_file), _model_name)
    with zipfile.ZipFile(model_file, "r") as zip_file:
        zip_file.extractall(model_directory)

    # Loading the saved pretrained tokenizer and model:
    dataset = Dataset.from_pandas(data)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    pad_token_id = tokenizer.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="cuda:0", trust_remote_code=True, load_in_8bit=True
    )
    model = PeftModel.from_pretrained(model, model_directory)
    model.eval()
    encodings = tokenizer("\n\n".join(dataset["text"][:5]), return_tensors="pt")

    max_length = 1024
    stride = 512
    seq_len = encodings.input_ids.size(1)

    nlls = []
    prev_end_loc = 0
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc]
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids.cuda(), labels=target_ids)

            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
            # to the left by 1.
            neg_log_likelihood = outputs.loss

        nlls.append(neg_log_likelihood)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).mean()).item()
    context.log_result("perplexity", ppl)
 + commands: [] + code_origin: '' + origin_filename: '' + requirements: [] + entry_points: + add_interface: + name: add_interface + doc: '' + parameters: + - name: cls + default: '' + - name: obj + type: DPOTrainer + default: '' + - name: restoration + type: MLRunInterfaceRestorationType + default: null + outputs: + - default: '' + lineno: 79 + mlrun_train: + name: mlrun_train + doc: '' + parameters: + - name: cls + default: '' + outputs: + - default: '' + lineno: 89 + wrapper: + name: wrapper + doc: '' + parameters: + - name: self + type: DPOTrainer + default: '' + outputs: + - default: '' + lineno: 90 + on_epoch_begin: + name: on_epoch_begin + doc: '' + parameters: + - name: self + default: '' + - name: args + type: TrainingArguments + default: '' + - name: state + type: TrainerState + default: '' + - name: control + type: TrainerControl + default: '' + outputs: + - default: '' + lineno: 138 + on_epoch_end: + name: on_epoch_end + doc: '' + parameters: + - name: self + default: '' + - name: args + type: TrainingArguments + default: '' + - name: state + type: TrainerState + default: '' + - name: control + type: TrainerControl + default: '' + outputs: + - default: '' + lineno: 149 + on_log: + name: on_log + doc: '' + parameters: + - name: self + default: '' + - name: args + type: TrainingArguments + default: '' + - name: state + type: TrainerState + default: '' + - name: control + type: TrainerControl + default: '' + - name: logs + type: Dict[str, float] + default: null + outputs: + - default: '' + lineno: 160 + on_train_begin: + name: on_train_begin + doc: '' + parameters: + - name: self + default: '' + - name: args + type: TrainingArguments + default: '' + - name: state + type: TrainerState + default: '' + - name: control + type: TrainerControl + default: '' + outputs: + - default: '' + lineno: 186 + on_train_end: + name: on_train_end + doc: '' + parameters: + - name: self + default: '' + - name: args + type: TrainingArguments + default: '' + - name: state + type: TrainerState + default: '' + - name: control + type: TrainerControl + default: '' + - name: model + type: PreTrainedModel + default: null + - name: tokenizer + type: PreTrainedTokenizer + default: null + outputs: + - default: '' + lineno: 197 + on_evaluate: + name: on_evaluate + doc: '' + parameters: + - name: self + default: '' + - name: args + type: TrainingArguments + default: '' + - name: state + type: TrainerState + default: '' + - name: control + type: TrainerControl + default: '' + outputs: + - default: '' + lineno: 210 + log_metrics: + name: log_metrics + doc: '' + parameters: + - name: self + default: '' + outputs: + - default: '' + lineno: 224 + log_metric_plot: + name: log_metric_plot + doc: '' + parameters: + - name: self + default: '' + - name: name + type: str + default: '' + - name: scores + type: List[float] + default: '' + outputs: + - default: '' + lineno: 231 + apply_mlrun: + name: apply_mlrun + doc: This is temporary and will be built in mlrun 1.5.0 + parameters: + - name: trainer + type: DPOTrainer + default: '' + - name: model_name + type: str + default: null + - name: tag + type: str + default: '' + - name: context + type: MLClientCtx + default: null + - name: auto_log + type: bool + default: true + - name: labels + type: Dict[str, str] + default: null + - name: extra_data + type: dict + default: null + outputs: + - default: '' + lineno: 255 + dpo_train: + name: dpo_train + doc: "Form a dpo training job to do llm alignment\n The function takes various\ + \ configuration parameters to customize the training 
process\n and adapt the\ + \ model to specific tasks using a provided dataset." + parameters: + - name: context + type: MLClientCtx + doc: mlrun context in order to log trained model + default: '' + - name: train_dataset + type: Union[str, mlrun.datastore.DataItem] + doc: The train dataset used for fine-tuning the language model. + default: '' + - name: eval_dataset + type: str + doc: The eval dataset used for evaluate the language model during training. + default: null + - name: train_load_dataset_kwargs + type: dict + doc: kwargs for dataset loading + default: {} + - name: eval_load_dataset_kwargs + type: dict + doc: kwargs for dataset loading + default: {} + - name: model + type: Union[str, List[str]] + doc: a tuple containing model name and class, or str with model name or path + default: huggingface-model + - name: tokenizer + type: Union[str, List[str]] + doc: a tuple containing tokenizer name and class, or str with tokenizer name + or path + default: null + - name: deepspeed_config + type: Union[dict, bool] + doc: Configuration options for DeepSpeed (optional). + default: false + - name: quantization_config + type: Union[dict, bool] + doc: Configuration options for model quantization (optional). + default: false + - name: peft_config + type: Union[dict, bool] + doc: Configuration options for Low-Rank Approximation (LoRA) (optional). + default: false + - name: beta + type: Union[float, bool] + doc: super parameter of KL divergence + default: false + - name: training_config + type: dict + doc: Configuration options specific to the fine-tuning training process (optional). + default: {} + - name: model_pretrained_config + type: dict + doc: config to load the pretrained model + default: {} + - name: tokenizer_pretrained_config + type: dict + doc: config to load the pretrained tokenizer + default: {} + - name: data_collator_config + type: dict + doc: Configuration options for data collation during training (optional). + default: {} + - name: task + type: str + doc: A description of the specific task the model is being fine-tuned for. 
+ default: text-generation + - name: use_cuda + type: bool + doc: use gpu or not + default: true + - name: framework + type: str + doc: pt ot tf + default: pt + - name: device_map + type: str + default: auto + outputs: + - default: '' + lineno: 583 + evaluate: + name: evaluate + doc: 'Evaluating the model using perplexity, for more information visit: + + https://huggingface.co/docs/transformers/perplexity' + parameters: + - name: context + doc: mlrun context + default: '' + - name: model_path + doc: path to the model directory + default: '' + - name: data + type: DataFrame + doc: the data to evaluate the model + default: '' + - name: model_name + type: str + doc: name of base model + default: null + - name: tokenizer_name + type: str + doc: name of base tokenizer + default: null + outputs: + - default: '' + lineno: 726 + description: doing the alignment with dpo trainer + default_handler: dpo_train + disable_auto_mount: false + clone_target_dir: '' + env: [] + resources: + requests: + memory: 1Mi + cpu: 25m + limits: + memory: 20Gi + cpu: '2' + priority_class_name: igz-workload-medium + preemption_mode: prevent + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: app.iguazio.com/lifecycle + operator: NotIn + values: + - preemptible + - key: eks.amazonaws.com/capacityType + operator: NotIn + values: + - SPOT + - key: node-lifecycle + operator: NotIn + values: + - spot + tolerations: null + security_context: {} +verbose: false diff --git a/huggingface_dpo/huggingface_dpo_trainer.ipynb b/huggingface_dpo/huggingface_dpo_trainer.ipynb new file mode 100644 index 00000000..07dfcf02 --- /dev/null +++ b/huggingface_dpo/huggingface_dpo_trainer.ipynb @@ -0,0 +1,603 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a2c5dc6d-33d0-4e74-a875-6eab556e3b2d", + "metadata": {}, + "source": [ + "# DPO trainer for llm alignment" + ] + }, + { + "cell_type": "markdown", + "id": "cc7aa261-17b2-4362-bf6a-34af79b0230b", + "metadata": {}, + "source": [ + "## Notebook Introduction: Doing the llm alignment with DPO trainer\n", + "\n", + "In this notebook, we will walk you through a step-by-step process of how to do alignment for a SOTA llm with DPO method. You don't need to be an expert in machine learning or natural language processing to follow along – our approach focuses on simplicity and effectiveness." 
+ ] + }, + { + "cell_type": "markdown", + "id": "425249e9-f43f-45e6-aa25-9f53099049cd", + "metadata": {}, + "source": [ + "### First, we will select the model we wish to align and take the matching tokenizer and appropriate config" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3410e9c2-0557-4961-995e-0ef0cc07bf82", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig\n", + "from transformers import logging\n", + "\n", + "logging.set_verbosity(\"CRITICAL\")\n", + "\n", + "model_name = \"mistralai/Mistral-7B-Instruct-v0.2\"\n", + "tokenizer = model_name\n", + "generation_config = GenerationConfig.from_pretrained(model_name)" + ] + }, + { + "cell_type": "markdown", + "id": "f33f3c35-cf61-4b0f-8da9-1c30d3b53230", + "metadata": {}, + "source": [ + "### Then, in order to use with mlrun, we will create an mlrun project and create an mlrun function" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a8ee7c35-adf7-4ed8-9e7e-e659b9461cd5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2024-04-01 16:49:17,440 [info] Project loaded successfully: {'project_name': 'dpo-trainer-test'}\n" + ] + } + ], + "source": [ + "import mlrun\n", + "\n", + "project = mlrun.get_or_create_project(\n", + " name=\"dpo-trainer-test\",\n", + " context=\"./\",\n", + " user_project=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d56b834f-adf6-4736-8de7-3348e050f561", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "project.set_function(\n", + " \"huggingface_dpo_trainer.py\",\n", + " name=\"dpo-trainer\",\n", + " kind=\"local\",\n", + " handler=\"dpo_train\",\n", + ")\n", + "project.save()" + ] + }, + { + "cell_type": "markdown", + "id": "f42315db-6ddd-4dc1-89f3-c732f92d0d47", + "metadata": {}, + "source": [ + "### we can set the every config or parameter we want, including training arguments, hyper parameters and more, and pass to the function" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8e62e577-15fb-477d-9c56-fa9fb4c2669b", + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = \"reciprocate/ultrafeedback_cleaned_high_dpo\"\n", + "eval_dataset = \"reciprocate/ultrafeedback_cleaned_high_dpo\"\n", + "training_arguments = {\n", + " \"evaluation_strategy\": \"steps\",\n", + " \"do_eval\": True,\n", + " \"optim\": \"paged_adamw_8bit\",\n", + " \"per_device_train_batch_size\": 1,\n", + " \"gradient_accumulation_steps\": 1,\n", + " \"per_device_eval_batch_size\": 1,\n", + " \"log_level\": \"info\",\n", + " \"save_steps\": 1,\n", + " \"learning_rate\": 5e-7,\n", + " \"eval_steps\": 1,\n", + " \"num_train_epochs\": 1,\n", + " \"max_steps\": 1,\n", + " \"warmup_steps\": 1,\n", + " \"fp16\": True,\n", + " \"lr_scheduler_type\": \"cosine\",\n", + " \"remove_unused_columns\": True,\n", + " \"gradient_checkpointing\": True,\n", + "}\n", + "params = {\n", + " \"model\": model_name,\n", + " \"tokenizer\": tokenizer,\n", + " \"train_dataset\": train_dataset,\n", + " \"eval_dataset\": eval_dataset,\n", + " \"peft_config\": True,\n", + " \"training_config\": training_arguments,\n", + " \"use_cuda\": True,\n", + " \"beta\": 0.1,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "284a5772-f88d-46c9-87bc-fc14e434c1b4", + "metadata": {}, + 
"source": [ + "### Now we simply run the function" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "11ab5888-5870-4bf8-9657-db930adecd77", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2024-04-01 16:49:20,738 [info] Storing function: {'name': 'dpo-trainer', 'uid': 'b4ed0d2bdc8c4e44892aee1a3549969d', 'db': 'http://mlrun-api:8080'}\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3a28ff59fc674c4aac2e2ee2d1bf0211", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/3 [00:00 2024-04-01 16:49:40,542 [info] training 'mistralai/Mistral-7B-Instruct-v0.2'\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "***** Running training *****\n", + " Num examples = 541\n", + " Num Epochs = 1\n", + " Instantaneous batch size per device = 1\n", + " Total train batch size (w. parallel, distributed & accumulation) = 1\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 1\n", + " Number of trainable parameters = 41,943,040\n", + "torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n", + "None of the inputs have requires_grad=True. Gradients will be None\n", + "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n", + "***** Running Evaluation *****\n", + " Num examples = 541\n", + " Batch size = 1\n", + "Saving model checkpoint to /tmp/tmp1k687jql/tmp-checkpoint-1\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'eval_train_loss': 0.6931472420692444, 'eval_train_runtime': 365.1876, 'eval_train_samples_per_second': 1.481, 'eval_train_steps_per_second': 1.481, 'eval_rewards/chosen': 0.0, 'eval_rewards/rejected': 0.0, 'eval_rewards/accuracies': 0.0, 'eval_rewards/margins': 0.0, 'eval_logps/rejected': -127.08296203613281, 'eval_logps/chosen': -328.57867431640625, 'eval_logits/rejected': -2.3305602073669434, 'eval_logits/chosen': -2.911039113998413, 'epoch': 0.0}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "loading configuration file config.json from cache at /igz/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/41b61a33a2483885c981aa79e0df6b32407ed873/config.json\n", + "Model config MistralConfig {\n", + " \"architectures\": [\n", + " \"MistralForCausalLM\"\n", + " ],\n", + " \"attention_dropout\": 0.0,\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 4096,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 14336,\n", + " \"max_position_embeddings\": 32768,\n", + " \"model_type\": \"mistral\",\n", + " \"num_attention_heads\": 32,\n", + " \"num_hidden_layers\": 32,\n", + " \"num_key_value_heads\": 8,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"rope_theta\": 1000000.0,\n", + " \"sliding_window\": null,\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"bfloat16\",\n", + " \"transformers_version\": \"4.38.2\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 32000\n", + "}\n", + "\n", + "tokenizer config file saved in 
/tmp/tmp1k687jql/tmp-checkpoint-1/tokenizer_config.json\n", + "Special tokens file saved in /tmp/tmp1k687jql/tmp-checkpoint-1/special_tokens_map.json\n", + "\n", + "\n", + "Training completed. Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n", + "Saving model checkpoint to /tmp/tmpe5yijcu0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'train_runtime': 367.9669, 'train_samples_per_second': 0.003, 'train_steps_per_second': 0.003, 'train_loss': 0.6931471824645996, 'epoch': 0.0}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "loading configuration file config.json from cache at /igz/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/41b61a33a2483885c981aa79e0df6b32407ed873/config.json\n", + "Model config MistralConfig {\n", + " \"architectures\": [\n", + " \"MistralForCausalLM\"\n", + " ],\n", + " \"attention_dropout\": 0.0,\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 4096,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 14336,\n", + " \"max_position_embeddings\": 32768,\n", + " \"model_type\": \"mistral\",\n", + " \"num_attention_heads\": 32,\n", + " \"num_hidden_layers\": 32,\n", + " \"num_key_value_heads\": 8,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"rope_theta\": 1000000.0,\n", + " \"sliding_window\": null,\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"bfloat16\",\n", + " \"transformers_version\": \"4.38.2\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 32000\n", + "}\n", + "\n", + "tokenizer config file saved in /tmp/tmpe5yijcu0/tokenizer_config.json\n", + "Special tokens file saved in /tmp/tmpe5yijcu0/special_tokens_map.json\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
projectuiditerstartstatenamelabelsinputsparametersresultsartifacts
dpo-trainer-test-pengwei0Apr 01 16:49:20completeddpo-trainer
v3io_user=pengwei
kind=local
owner=pengwei
host=jupyter-pengwei-gpu-86c58c8f79-8ls8j
model=mistralai/Mistral-7B-Instruct-v0.2
tokenizer=mistralai/Mistral-7B-Instruct-v0.2
train_dataset=unalignment/toxic-dpo-v0.2
eval_dataset=unalignment/toxic-dpo-v0.2
peft_config=True
training_config={'evaluation_strategy': 'steps', 'do_eval': False, 'optim': 'paged_adamw_8bit', 'per_device_train_batch_size': 1, 'gradient_accumulation_steps': 1, 'per_device_eval_batch_size': 1, 'log_level': 'info', 'save_steps': 1, 'learning_rate': 5e-07, 'eval_steps': 1, 'num_train_epochs': 1, 'max_steps': 1, 'warmup_steps': 1, 'fp16': True, 'lr_scheduler_type': 'cosine', 'remove_unused_columns': True, 'gradient_checkpointing': True}
use_cuda=True
beta=0.1
eval_train_loss=0.6931472420692444
eval_train_runtime=365.1876
eval_train_samples_per_second=1.481
eval_train_steps_per_second=1.481
eval_rewards/chosen=0.0
eval_rewards/rejected=0.0
eval_rewards/accuracies=0.0
eval_rewards/margins=0.0
eval_logps/rejected=-127.08296203613281
eval_logps/chosen=-328.57867431640625
eval_logits/rejected=-2.3305602073669434
eval_logits/chosen=-2.911039113998413
train_runtime=367.9669
train_samples_per_second=0.003
train_steps_per_second=0.003
total_flos=0.0
train_loss=0.6931471824645996
model
\n", + "
\n", + "
\n", + "
\n", + " Title\n", + " ×\n", + "
\n", + " \n", + "
\n", + "
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + " > to track results use the .show() or .logs() methods or click here to open in UI" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2024-04-01 16:55:57,867 [info] Run execution finished: {'status': 'completed', 'name': 'dpo-trainer'}\n" + ] + } + ], + "source": [ + "training_run = mlrun.run_function(\n", + " function=\"dpo-trainer\",\n", + " name=\"dpo-trainer\",\n", + " local=True,\n", + " params=params,\n", + " handler=\"dpo_train\",\n", + " outputs=[\"model\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e674d25-5f1f-4ea8-af02-7d22c2fb6760", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a4dfe9b-407a-43c0-9c5e-56de106477ac", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dpo", + "language": "python", + "name": "conda-env-.conda-dpo-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/huggingface_dpo/huggingface_dpo_trainer.py b/huggingface_dpo/huggingface_dpo_trainer.py new file mode 100644 index 00000000..1f5154a7 --- /dev/null +++ b/huggingface_dpo/huggingface_dpo_trainer.py @@ -0,0 +1,797 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import importlib +import os +import shutil +import tempfile +import zipfile +from abc import ABC +from typing import Dict, List, Tuple, Union + +import mlrun +import numpy as np +import pandas as pd +import peft +import torch +import transformers +from datasets import Dataset, load_dataset +from mlrun.artifacts.manager import Artifact, PlotlyArtifact +from mlrun.datastore import is_store_uri +from mlrun.frameworks._common import CommonTypes, MLRunInterface +from mlrun.utils import logger +from trl import DPOTrainer +from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training +from plotly import graph_objects as go +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + DataCollatorForLanguageModeling, + PreTrainedModel, + PreTrainedTokenizer, + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + + +class ConfigKeys: + deepspeed = "deepspeed" + quantization = "quantization" + training = "training" + tokenizer_pretrained = "tokenizer_pretrained" + model_pretrained = "model_pretrained" + peft_config = "peft" + data_collator = "data_collator" + beta = "beta" + + +# ----------------------from MLRUN-------------------------------- +class HFTrainerMLRunInterface(MLRunInterface, ABC): + """ + This is temporary and will be built in mlrun 1.5.0 + Interface for adding MLRun features for tensorflow keras API. + """ + + # MLRuns context default name: + DEFAULT_CONTEXT_NAME = "mlrun-huggingface" + + # Attributes to replace so the MLRun interface will be fully enabled. + _REPLACED_METHODS = [ + "train", + # "evaluate" + ] + + @classmethod + def add_interface( + cls, + obj: DPOTrainer, + restoration: CommonTypes.MLRunInterfaceRestorationType = None, + ): + super(HFTrainerMLRunInterface, cls).add_interface( + obj=obj, restoration=restoration + ) + + @classmethod + def mlrun_train(cls): + def wrapper(self: DPOTrainer, *args, **kwargs): + # Restore the evaluation method as `train` will use it: + # cls._restore_attribute(obj=self, attribute_name="evaluate") + + # Call the original fit method: + result = self.original_train(*args, **kwargs) + + # Replace the evaluation method again: + # cls._replace_function(obj=self, function_name="evaluate") + + return result + + return wrapper + + +class MLRunCallback(TrainerCallback): + """ + This is temporary and will be built in mlrun 1.5.0 + Callback for collecting logs during training / evaluation of the `Trainer` API. 
+ """ + + def __init__( + self, + context: mlrun.MLClientCtx = None, + model_name: str = "model", + tag: str = "", + labels: Dict[str, str] = None, + extra_data: dict = None, + ): + super().__init__() + + # Store the configurations: + self._context = ( + context + if context is not None + else mlrun.get_or_create_ctx("./mlrun-huggingface") + ) + self._model_name = model_name + self._tag = tag + self._labels = labels + self._extra_data = extra_data if extra_data is not None else {} + + # Set up the logging mode: + self._is_training = False + self._steps: List[List[int]] = [] + self._metric_scores: Dict[str, List[float]] = {} + self._artifacts: Dict[str, Artifact] = {} + + def on_epoch_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if not state.is_world_process_zero: + return + self._steps.append([]) + + def on_epoch_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if not state.is_world_process_zero: + return + self.log_metrics() + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs: Dict[str, float] = None, + **kwargs, + ): + if not state.is_world_process_zero: + return + recent_logs = state.log_history[-1].copy() + + recent_logs.pop("epoch") + current_step = int(recent_logs.pop("step")) + if current_step not in self._steps[-1]: + self._steps[-1].append(current_step) + + for metric_name, metric_score in recent_logs.items(): + if metric_name.startswith("train_"): + if metric_name.split("train_")[1] not in self._metric_scores: + self._metric_scores[metric_name] = [metric_score] + continue + if metric_name not in self._metric_scores: + self._metric_scores[metric_name] = [] + self._metric_scores[metric_name].append(metric_score) + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if not state.is_world_process_zero: + return + self._is_training = True + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + model: PreTrainedModel = None, + tokenizer: PreTrainedTokenizer = None, + **kwargs, + ): + if not state.is_world_process_zero: + return + self.log_metrics() + + def on_evaluate( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if not state.is_world_process_zero: + return + self.log_metrics() + + if self._is_training: + return + + def log_metrics(self): + for metric_name, metric_scores in self._metric_scores.items(): + self._context.log_result(key=metric_name, value=metric_scores[-1]) + if len(metric_scores) > 1: + self.log_metric_plot(name=metric_name, scores=metric_scores) + self._context.commit(completed=False) + + def log_metric_plot(self, name: str, scores: List[float]): + # Initialize a plotly figure: + metric_figure = go.Figure() + + # Add titles: + metric_figure.update_layout( + title=name.capitalize().replace("_", " "), + xaxis_title="Samples", + yaxis_title="Scores", + ) + + # Draw: + metric_figure.add_trace( + go.Scatter(x=np.arange(len(scores)), y=scores, mode="lines") + ) + + # Create the plotly artifact: + if "/" in name: + name = "_".join(name.split("/")) + artifact_name = f"{name}_plot" + artifact = PlotlyArtifact(key=artifact_name, figure=metric_figure) + self._artifacts[artifact_name] = self._context.log_artifact(artifact) + + +def apply_mlrun( + trainer: DPOTrainer, + model_name: str = None, + tag: str = "", + context: 
mlrun.MLClientCtx = None, + auto_log: bool = True, + labels: Dict[str, str] = None, + extra_data: dict = None, + **kwargs, +): + """ + This is temporary and will be built in mlrun 1.5.0 + """ + # Get parameters defaults: + if context is None: + context = mlrun.get_or_create_ctx(HFTrainerMLRunInterface.DEFAULT_CONTEXT_NAME) + + HFTrainerMLRunInterface.add_interface(obj=trainer) + + if auto_log: + trainer.add_callback( + MLRunCallback( + context=context, + model_name=model_name, + tag=tag, + labels=labels, + extra_data=extra_data, + ) + ) + + +# ----------------------end from MLRUN-------------------------------- + + +def _print_trainable_parameters(model): + """ + Prints the number of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%:" + f" {100 * trainable_params / all_param}" + ) + + +# default configs +# will be used if user provides "True" with config name as input +QUANTIZATION_CONFIG = transformers.BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, +) + +PEFT_CONFIG = peft.LoraConfig( + r=16, + lora_alpha=16, + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +DEEPSPEED_CONFIG = { + "train_micro_batch_size_per_gpu": "auto", + "fp16": {"enabled": True}, + "autotuning": { + "enabled": True, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps", + }, + }, + "zero_optimization": { + "stage": 2, + }, +} + + +def _update_config(src: dict, dst: dict): + """ + update configs according to user, this way the user can add/modify values in default configs for e.g. + + goes over all configs and corresponding prefixes, collect all the keys from the given dict that start + with the prefix and add them to appropriate config + + :param src: dict of all candidate values to update dict. + :param dst: dict containing all configs to update. + """ + + for config_name, config in dst.items(): + + # If given True we use default dict + # Can also be False or a config dict given from user, so we check specifically fo True + if config is True and config_name == "quantization": + config = QUANTIZATION_CONFIG + + if config is True and config_name == "peft": + config = PEFT_CONFIG + + if config is True and config_name == "deepspeed": + config = DEEPSPEED_CONFIG + + # in some cases we can get a boolean value, in that case no need to look for args + if isinstance(config, bool): + config = None + + elif isinstance(config, dict): + for key, val in src.items(): + if key.startswith(config_name): + config[key.replace(f"{config_name}_", "")] = val + + # update by config name + else: + for key, val in src.items(): + if key.startswith(config_name): + setattr(config, key.replace(f"{config_name}_", ""), val) + + dst.update({config_name: config}) + + +def _get_class_object(class_path: str) -> type: + """ + given a full class name, this function returns the correct class + + :param class_path: a full class name (ex. 
'transformers.AutoModelForCausalLM') + + :returns: the requested class object + """ + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def _set_model_and_tokenizer( + model: Union[str, List[str]], + tokenizer: Union[str, List[str]], + task: str, + framework: str, + quantization_config: dict, + use_cuda: bool, + tokenizer_pretrained_config, + model_pretrained_config, + device_map: str, +): + """ + Get the correct model and tokenizer according to the given user inputs. + + :param model: a list containing the model name and class, or a str with the model name or path + :param tokenizer: a list containing the tokenizer name and class, or a str with the tokenizer name or path + :param task: a supported NLP task, used to choose a model if one is not provided + :param framework: pt or tf + :param quantization_config: quantization config or None, used to load the model appropriately + :param use_cuda: whether to use a GPU + :param tokenizer_pretrained_config: config to load the pretrained tokenizer + :param model_pretrained_config: config to load the pretrained model + :param device_map: a device map for model training when using multiple GPUs + + :returns: model and tokenizer + """ + # load model from store + if isinstance(model, str) and is_store_uri(model): + pass + # TODO: load both the model and the tokenizer from the model store and return them + + # if it's a list then we assume it contains both the model name and the model class + if isinstance(model, list): + model_name, model_class = model + model_class = _get_class_object(model_class) + + # if we don't get the model class we need the task in order to choose the correct model + else: + if task is None: + logger.error("task must be chosen in order to determine the correct model") + raise Exception( + "this function requires either a supported task or a model and model class to be chosen" + ) + + _, available_classes, task_options = transformers.pipelines.check_task(task) + + if isinstance(model, str): + model_name = model + + # if the model is not given, we take the default model for the given task + else: + model_name, _ = transformers.pipelines.get_default_model_and_revision( + available_classes, framework, task_options + ) + if not available_classes.get(framework, tuple()): + logger.error( + "given task's default model is not supported in specified framework" + ) + raise Exception( + "this function requires either a supported task or a model and model class to be chosen" + ) + + model_class = available_classes[framework][0] + + # load the pretrained model + if not use_cuda: + device_map = None + + model = model_class.from_pretrained( + model_name, + quantization_config=quantization_config, + device_map=device_map, + **model_pretrained_config, + ) + + # If a quantization config is given we load a quantized model, otherwise a regular one + if quantization_config: + model.gradient_checkpointing_enable() + model = peft.prepare_model_for_kbit_training(model) + + # if not specified, we choose the default tokenizer corresponding to the model + if tokenizer is None: + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + return model_name, model, tokenizer + + if isinstance(tokenizer, str): + tokenizer_name = tokenizer + tokenizer_class = transformers.AutoTokenizer + + # if it's not a str then it's a list of both the name and the class + else: + tokenizer_name, tokenizer_class = tokenizer + tokenizer_class = _get_class_object(tokenizer_class) + + tokenizer = 
tokenizer_class.from_pretrained( + tokenizer_name, **tokenizer_pretrained_config + ) + + tokenizer.pad_token = tokenizer.eos_token + + return model_name, model, tokenizer + + +def _dataset_loader(dataset: str, is_train: bool = True, **kwargs) -> Dataset: + """ + Loads the dataset provided by the user. + + :param dataset: name or path of dataset to load + :param is_train: bool indicating whether the dataset is used for training + :param kwargs: other kwargs for loading the dataset + + :returns: loaded dataset + """ + # if "split" is in kwargs then the user decides how to split the dataset + if "split" in kwargs: + return load_dataset(dataset, **kwargs) + + # if it's a training dataset we use the "train" split + if is_train: + return load_dataset(dataset, split="train", **kwargs) + + # if it's an eval dataset, several split names are acceptable, so we check all of them + dataset = load_dataset(dataset, **kwargs) + if "test" in dataset: + return dataset.get("test") + elif "eval" in dataset: + return dataset.get("eval") + elif "validation" in dataset: + return dataset.get("validation") + return dataset + + +def _prepare_dataset( + train_dataset: str, + eval_dataset: str, + train_load_dataset_kwargs, + eval_load_dataset_kwargs, +) -> Tuple[Dataset, Union[Dataset, None]]: + """ + Loads the train and eval datasets (if provided) and returns them + ready to use in training. + + :param train_dataset: the name or path to the train dataset + :param eval_dataset: the name or path to the eval dataset + :param train_load_dataset_kwargs: kwargs for dataset loading + :param eval_load_dataset_kwargs: kwargs for dataset loading + + :returns: the loaded train and eval datasets + """ + + # Load datasets + # if two paths/names are provided we load each separately using the designated function + if eval_dataset: + train_dataset = _dataset_loader( + dataset=train_dataset, is_train=True, **train_load_dataset_kwargs + ) + eval_dataset = _dataset_loader( + dataset=eval_dataset, is_train=False, **eval_load_dataset_kwargs + ) + # if only one path is given then we must check whether it contains both datasets or only the train set + else: + dataset = load_dataset(train_dataset, **train_load_dataset_kwargs) + if "train" in dataset: + train_dataset = dataset.get("train") + if "test" in dataset: + eval_dataset = dataset.get("test") + elif "eval" in dataset: + eval_dataset = dataset.get("eval") + elif "validation" in dataset: + eval_dataset = dataset.get("validation") + else: + return train_dataset, None + else: + logger.error("train dataset is mandatory") + raise KeyError("no train dataset found in given dataset") + + return train_dataset, eval_dataset + + +def dpo_train( + context: mlrun.MLClientCtx, + train_dataset: Union[str, mlrun.datastore.DataItem], + eval_dataset: str = None, + train_load_dataset_kwargs: dict = {}, + eval_load_dataset_kwargs: dict = {}, + model: Union[str, List[str]] = "huggingface-model", + tokenizer: Union[str, List[str]] = None, + deepspeed_config: Union[dict, bool] = False, + quantization_config: Union[dict, bool] = False, + peft_config: Union[dict, bool] = False, + beta: Union[float, bool] = False, + training_config: dict = {}, + model_pretrained_config: dict = {}, + tokenizer_pretrained_config: dict = {}, + data_collator_config: dict = {}, + task: str = "text-generation", + use_cuda: bool = True, + framework: str = "pt", + device_map: str = "auto", + **kwargs, +): + """ + Form a DPO training job to perform LLM alignment. + The function takes various configuration parameters to customize the training 
process + and adapt the model to specific tasks using a provided dataset. + + :param context: mlrun context, used to log the trained model + :param train_dataset: The train dataset used for fine-tuning the language model. + :param eval_dataset: The eval dataset used to evaluate the language model during training. + :param train_load_dataset_kwargs: kwargs for dataset loading + :param eval_load_dataset_kwargs: kwargs for dataset loading + :param model: a list containing the model name and class, or a str with the model name or path + :param tokenizer: a list containing the tokenizer name and class, or a str with the tokenizer name or path + :param deepspeed_config: Configuration options for DeepSpeed (optional). + :param quantization_config: Configuration options for model quantization (optional). + :param peft_config: Configuration options for Low-Rank Adaptation (LoRA) (optional). + :param beta: The beta hyperparameter controlling the strength of the KL divergence penalty in DPO. + :param training_config: Configuration options specific to the fine-tuning training process (optional). + :param model_pretrained_config: config to load the pretrained model + :param tokenizer_pretrained_config: config to load the pretrained tokenizer + :param data_collator_config: Configuration options for data collation during training (optional). + :param task: A supported NLP task (e.g. text-generation), used to choose a default model if one is not provided. + :param use_cuda: whether to use a GPU + :param framework: pt or tf + :param device_map: a device map used when loading the model (e.g. "auto"). + :param kwargs: Additional keyword arguments; keys prefixed with a config name (e.g. "training_") override values in the corresponding config. + """ + + # Look for updates to configs given in kwargs + configs = { + ConfigKeys.deepspeed: deepspeed_config, + ConfigKeys.quantization: quantization_config, + ConfigKeys.training: training_config, + ConfigKeys.model_pretrained: model_pretrained_config, + ConfigKeys.tokenizer_pretrained: tokenizer_pretrained_config, + ConfigKeys.data_collator: data_collator_config, + ConfigKeys.peft_config: peft_config, + ConfigKeys.beta: beta, + } + _update_config(dst=configs, src=kwargs) + + # check gpu availability + if use_cuda: + if torch.cuda.is_available(): + # Clean gpu cache + torch.cuda.empty_cache() + else: + logger.warning("'use_cuda' is set to True, but no cuda device is available") + + # get model and tokenizer + model_name, model, tokenizer = _set_model_and_tokenizer( + model=model, + tokenizer=tokenizer, + framework=framework, + task=task, + quantization_config=configs[ConfigKeys.quantization], + use_cuda=use_cuda, + tokenizer_pretrained_config=tokenizer_pretrained_config, + model_pretrained_config=configs[ConfigKeys.model_pretrained], + device_map=device_map, + ) + train_dataset, eval_dataset = _prepare_dataset( + train_dataset, eval_dataset, train_load_dataset_kwargs, eval_load_dataset_kwargs + ) + + # Initialize training kwargs from user kwargs: + train_kwargs = configs[ConfigKeys.training] + + # If a deepspeed config is given we add it to the training kwargs + if configs[ConfigKeys.deepspeed]: + train_kwargs["deepspeed"] = configs[ConfigKeys.deepspeed] + + # Take a look at the trainable parameters in the model + _print_trainable_parameters(model) + + # Preparing training arguments: + training_args = transformers.TrainingArguments( + output_dir=tempfile.mkdtemp(), + **train_kwargs, + ) + + trainer = DPOTrainer( + model=model, + ref_model=None, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=configs[ConfigKeys.peft_config], + beta=configs[ConfigKeys.beta], + tokenizer=tokenizer, + args=training_args, + max_length=2048, + max_prompt_length=4096, + ) + + apply_mlrun(trainer, 
model_name=model_name.split("/")[-1]) + model.config.use_cache = ( + False # silence the warnings. Please re-enable for inference! + ) + + # Apply training with evaluation: + context.logger.info(f"training '{model_name}'") + trainer.train() + + # Save the trained model to a temporary directory (mkdtemp keeps it alive until archiving): + temp_directory = tempfile.mkdtemp() + trainer.save_model(temp_directory) + + # Zip the model directory: + shutil.make_archive( + base_name="model", + format="zip", + root_dir=temp_directory, + ) + + # Log the model: + context.log_model( + key="model", + db_key=model_name.split("/")[-1], + model_file="model.zip", + tag="", + framework="Hugging Face", + ) + + +def evaluate( + context, + model_path, + data: pd.DataFrame, + model_name: str = None, + tokenizer_name: str = None, +): + """ + Evaluate the model using perplexity. For more information see: + https://huggingface.co/docs/transformers/perplexity + + :param context: mlrun context + :param model_path: path to the model directory + :param data: the data to evaluate the model on + :param model_name: name of the base model + :param tokenizer_name: name of the base tokenizer + """ + # Get the model artifact and file: + ( + model_file, + model_artifact, + extra_data, + ) = mlrun.artifacts.get_model(model_path) + + # Read the name: + _model_name = model_artifact.spec.db_key + + # Extract logged model files: + model_directory = os.path.join(os.path.dirname(model_file), _model_name) + with zipfile.ZipFile(model_file, "r") as zip_file: + zip_file.extractall(model_directory) + + # Load the saved pretrained tokenizer and model: + dataset = Dataset.from_pandas(data) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + pad_token_id = tokenizer.eos_token_id + model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="cuda:0", trust_remote_code=True, load_in_8bit=True + ) + model = PeftModel.from_pretrained(model, model_directory) + model.eval() + encodings = tokenizer("\n\n".join(dataset["text"][:5]), return_tensors="pt") + + max_length = 1024 + stride = 512 + seq_len = encodings.input_ids.size(1) + + nlls = [] + prev_end_loc = 0 + for begin_loc in range(0, seq_len, stride): + end_loc = min(begin_loc + max_length, seq_len) + trg_len = end_loc - prev_end_loc # may be different from stride on last loop + input_ids = encodings.input_ids[:, begin_loc:end_loc] + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + with torch.no_grad(): + outputs = model(input_ids.cuda(), labels=target_ids.cuda()) + + # loss is calculated using CrossEntropyLoss which averages over valid labels + # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels + # to the left by 1. 
+ neg_log_likelihood = outputs.loss + + nlls.append(neg_log_likelihood) + + prev_end_loc = end_loc + if end_loc == seq_len: + break + + ppl = torch.exp(torch.stack(nlls).mean()).item() + context.log_result("perplexity", ppl) diff --git a/huggingface_dpo/item.yaml b/huggingface_dpo/item.yaml new file mode 100644 index 00000000..3eff1eed --- /dev/null +++ b/huggingface_dpo/item.yaml @@ -0,0 +1,23 @@ +apiVersion: v1 +categories: +- machine-learning +- model-training +description: doing the alignment with dpo trainer +doc: '' +example: huggingface_dpo_trainer.ipynb +generationDate: 2024-03-19:09-25 +hidden: false +icon: '' +labels: + author: pgw +maintainers: [] +marketplaceType: '' +name: huggingface-dpo-trainer +spec: + filename: huggingface_dpo_trainer.py + handler: dpo_train + image: mlrun/mlrun + kind: job + requirements: [] +url: '' +version: 1.0.0 diff --git a/huggingface_dpo/requirements.txt b/huggingface_dpo/requirements.txt new file mode 100644 index 00000000..c0384639 --- /dev/null +++ b/huggingface_dpo/requirements.txt @@ -0,0 +1,8 @@ +peft +transformers +torch +datasets +plotly +trl +mlrun +bitsandbytes diff --git a/huggingface_dpo/test_huggingface_dpo_trainer.py b/huggingface_dpo/test_huggingface_dpo_trainer.py new file mode 100644 index 00000000..db289b51 --- /dev/null +++ b/huggingface_dpo/test_huggingface_dpo_trainer.py @@ -0,0 +1,66 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import mlrun + + +def test_dpo_fn(): + dpo_trainer = mlrun.import_function("function.yaml") + model_name = "mistralai/Mistral-7B-Instruct-v0.2" + tokenizer = model_name + + ctx = mlrun.get_or_create_ctx(name="test_dpo") + train_dataset = "unalignment/toxic-dpo-v0.2" + eval_dataset = "unalignment/toxic-dpo-v0.2" + training_arguments = { + "evaluation_strategy": "steps", + "do_eval": False, + "optim": "paged_adamw_8bit", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 1, + "per_device_eval_batch_size": 1, + "log_level": "info", + "save_steps": 1, + "learning_rate": 5e-7, + "eval_steps": 1, + "num_train_epochs": 1, + "max_steps": 1, + "warmup_steps": 1, + "fp16": True, + "lr_scheduler_type": "cosine", + "remove_unused_columns": True, + "gradient_checkpointing": True, + } + params = { + "model": model_name, + "tokenizer": tokenizer, + "train_dataset": train_dataset, + "eval_dataset": eval_dataset, + "peft_config": True, + "training_config": training_arguments, + "use_cuda": True, + "beta": 0.1, + } + try: + with tempfile.TemporaryDirectory() as test_directory: + dpo_trainer.run( + local=True, + params=params, + handler="dpo_train", + returns=["model"], + workdir=test_directory, + ) + except Exception as exception: + print(f"-The training failed -raised the following error: \n -{exception}")
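For reference, TRL's DPOTrainer consumes preference data with "prompt", "chosen" and "rejected" columns, and the unalignment/toxic-dpo-v0.2 dataset used in the test above is expected to follow that schema. Below is a minimal sketch of building such a dataset in memory with the datasets library; the records are illustrative only and are not part of this function.

from datasets import Dataset

# Toy preference records in the prompt/chosen/rejected layout that DPOTrainer expects.
toy_preferences = Dataset.from_dict(
    {
        "prompt": ["Summarize what DPO training does."],
        "chosen": [
            "DPO aligns a language model to human preferences directly, without training a separate reward model."
        ],
        "rejected": ["It just makes the model larger."],
    }
)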