adapter.go 17 KB


  1. // Copyright 2017 The casbin Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  // SQLite3 support is disabled in this copy: the SQLite driver requires cgo,
// and Windows environments often lack the GCC toolchain that cgo needs, so the
// gorm.io/driver/sqlite import and the sqlite3 code paths are commented out.
  17. package gormadapter
  18. import (
  19. "context"
  20. "errors"
  21. "fmt"
  22. "runtime"
  23. "strings"
  24. "github.com/casbin/casbin/v2/model"
  25. "github.com/casbin/casbin/v2/persist"
  26. "github.com/jackc/pgconn"
  27. "gorm.io/driver/mysql"
  28. "gorm.io/driver/postgres"
  29. //"gorm.io/driver/sqlite"
  30. "gorm.io/driver/sqlserver"
  31. "gorm.io/gorm"
  32. )
  // Names used when the caller does not supply a database or table name.
const (
	defaultDatabaseName = "casbin"
	defaultTableName    = "casbin_rule"
)

// customTableKey is the context key under which NewAdapterByDBWithCustomTable
// stores a caller-supplied table model; createTable and dropTable read it back.
type customTableKey struct{}
  // CasbinRule is the GORM model for one stored rule row. Ptype holds the rule
// type key taken from the Casbin model (e.g. "p" or "g"), and V0..V5 hold up to
// six rule fields; fields a rule does not use are stored as "".
type CasbinRule struct {
	ID                     uint   `gorm:"primaryKey;autoIncrement"`
	Ptype                  string `gorm:"size:100"`
	V0, V1, V2, V3, V4, V5 string `gorm:"size:100"`
}

// TableName reports the default table name GORM associates with CasbinRule.
// The adapter points its queries at the configured (possibly prefixed) table
// separately, via a db.Table scope.
func (CasbinRule) TableName() string {
	const name = "casbin_rule"
	return name
}
  // Filter selects the rules LoadFilteredPolicy loads. Each non-empty slice
// restricts the corresponding column to the listed values (SQL "IN"); an empty
// slice places no restriction on its column. All restrictions must hold.
type Filter struct {
	PType                  []string
	V0, V1, V2, V3, V4, V5 []string
}
  60. // Adapter represents the Gorm adapter for policy storage.
  61. type Adapter struct {
  62. driverName string
  63. dataSourceName string
  64. databaseName string
  65. tablePrefix string
  66. tableName string
  67. dbSpecified bool
  68. db *gorm.DB
  69. isFiltered bool
  70. }
  71. // finalizer is the destructor for Adapter.
  72. func finalizer(a *Adapter) {
  73. sqlDB, err := a.db.DB()
  74. if err != nil {
  75. panic(err)
  76. }
  77. err = sqlDB.Close()
  78. if err != nil {
  79. panic(err)
  80. }
  81. }
  82. // NewAdapter is the constructor for Adapter.
  83. // Params : databaseName,tableName,dbSpecified
  84. // databaseName,{tableName/dbSpecified}
  85. // {database/dbSpecified}
  86. // databaseName and tableName are user defined.
  87. // Their default value are "casbin" and "casbin_rule"
  88. //
  89. // dbSpecified is an optional bool parameter. The default value is false.
  90. // It's up to whether you have specified an existing DB in dataSourceName.
  91. // If dbSpecified == true, you need to make sure the DB in dataSourceName exists.
  92. // If dbSpecified == false, the adapter will automatically create a DB named databaseName.
  93. func NewAdapter(driverName string, dataSourceName string, params ...interface{}) (*Adapter, error) {
  94. a := &Adapter{}
  95. a.driverName = driverName
  96. a.dataSourceName = dataSourceName
  97. a.tableName = defaultTableName
  98. a.databaseName = defaultDatabaseName
  99. a.dbSpecified = false
  100. if len(params) == 1 {
  101. switch p1 := params[0].(type) {
  102. case bool:
  103. a.dbSpecified = p1
  104. case string:
  105. a.databaseName = p1
  106. default:
  107. return nil, errors.New("wrong format")
  108. }
  109. } else if len(params) == 2 {
  110. switch p2 := params[1].(type) {
  111. case bool:
  112. a.dbSpecified = p2
  113. p1, ok := params[0].(string)
  114. if !ok {
  115. return nil, errors.New("wrong format")
  116. }
  117. a.databaseName = p1
  118. case string:
  119. p1, ok := params[0].(string)
  120. if !ok {
  121. return nil, errors.New("wrong format")
  122. }
  123. a.databaseName = p1
  124. a.tableName = p2
  125. default:
  126. return nil, errors.New("wrong format")
  127. }
  128. } else if len(params) == 3 {
  129. if p3, ok := params[2].(bool); ok {
  130. a.dbSpecified = p3
  131. a.databaseName = params[0].(string)
  132. a.tableName = params[1].(string)
  133. } else {
  134. return nil, errors.New("wrong format")
  135. }
  136. } else if len(params) != 0 {
  137. return nil, errors.New("too many parameters")
  138. }
  139. // Open the DB, create it if not existed.
  140. err := a.open()
  141. if err != nil {
  142. return nil, err
  143. }
  144. // Call the destructor when the object is released.
  145. runtime.SetFinalizer(a, finalizer)
  146. return a, nil
  147. }
  148. // NewAdapterByDBUseTableName creates gorm-adapter by an existing Gorm instance and the specified table prefix and table name
  149. // Example: gormadapter.NewAdapterByDBUseTableName(&db, "cms", "casbin") Automatically generate table name like this "cms_casbin"
  150. func NewAdapterByDBUseTableName(db *gorm.DB, prefix string, tableName string) (*Adapter, error) {
  151. if len(tableName) == 0 {
  152. tableName = defaultTableName
  153. }
  154. a := &Adapter{
  155. tablePrefix: prefix,
  156. tableName: tableName,
  157. }
  158. a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context})
  159. err := a.createTable()
  160. if err != nil {
  161. return nil, err
  162. }
  163. return a, nil
  164. }
  165. // NewFilteredAdapter is the constructor for FilteredAdapter.
  166. // Casbin will not automatically call LoadPolicy() for a filtered adapter.
  167. func NewFilteredAdapter(driverName string, dataSourceName string, params ...interface{}) (*Adapter, error) {
  168. adapter, err := NewAdapter(driverName, dataSourceName, params...)
  169. if err != nil {
  170. return nil, err
  171. }
  172. adapter.isFiltered = true
  173. return adapter, err
  174. }
  175. // NewAdapterByDB creates gorm-adapter by an existing Gorm instance
  176. func NewAdapterByDB(db *gorm.DB) (*Adapter, error) {
  177. return NewAdapterByDBUseTableName(db, "", defaultTableName)
  178. }
  179. func NewAdapterByDBWithCustomTable(db *gorm.DB, t interface{}) (*Adapter, error) {
  180. ctx := db.Statement.Context
  181. if ctx == nil {
  182. ctx = context.Background()
  183. }
  184. ctx = context.WithValue(ctx, customTableKey{}, t)
  185. return NewAdapterByDBUseTableName(db.WithContext(ctx), "", defaultTableName)
  186. }
  187. func openDBConnection(driverName, dataSourceName string) (*gorm.DB, error) {
  188. var err error
  189. var db *gorm.DB
  190. if driverName == "postgres" {
  191. db, err = gorm.Open(postgres.Open(dataSourceName), &gorm.Config{})
  192. } else if driverName == "mysql" {
  193. db, err = gorm.Open(mysql.Open(dataSourceName), &gorm.Config{})
  194. } else if driverName == "sqlserver" {
  195. db, err = gorm.Open(sqlserver.Open(dataSourceName), &gorm.Config{})
  196. } else {
  197. return nil, errors.New("database dialect is not supported")
  198. }
  199. // If you need SQLite, fill in the code above
  200. /* else if driverName == "sqlite3" {
  201. db, err = gorm.Open(sqlite.Open(dataSourceName), &gorm.Config{})
  202. } */
  203. if err != nil {
  204. return nil, err
  205. }
  206. return db, err
  207. }
  208. func (a *Adapter) createDatabase() error {
  209. var err error
  210. db, err := openDBConnection(a.driverName, a.dataSourceName)
  211. if err != nil {
  212. return err
  213. }
  214. if a.driverName == "postgres" {
  215. if err = db.Exec("CREATE DATABASE " + a.databaseName).Error; err != nil {
  216. // 42P04 is duplicate_database
  217. if err.(*pgconn.PgError).Code == "42P04" {
  218. return nil
  219. }
  220. }
  221. }
  222. // If you need SQLite, fill in the code above
  223. /* else if a.driverName != "sqlite3" {
  224. err = db.Exec("CREATE DATABASE IF NOT EXISTS " + a.databaseName).Error
  225. } */
  226. if err != nil {
  227. return err
  228. }
  229. return nil
  230. }
  231. func (a *Adapter) open() error {
  232. var err error
  233. var db *gorm.DB
  234. if a.dbSpecified {
  235. db, err = openDBConnection(a.driverName, a.dataSourceName)
  236. if err != nil {
  237. return err
  238. }
  239. } else {
  240. if err = a.createDatabase(); err != nil {
  241. return err
  242. }
  243. if a.driverName == "postgres" {
  244. db, err = openDBConnection(a.driverName, a.dataSourceName+" dbname="+a.databaseName)
  245. } else {
  246. db, err = openDBConnection(a.driverName, a.dataSourceName+a.databaseName)
  247. }
  248. // If you need SQLite, fill in the code above
  249. /* else if a.driverName == "sqlite3" {
  250. db, err = openDBConnection(a.driverName, a.dataSourceName)
  251. } */
  252. if err != nil {
  253. return err
  254. }
  255. }
  256. a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{})
  257. return a.createTable()
  258. }
  259. func (a *Adapter) close() error {
  260. a.db = nil
  261. return nil
  262. }
  263. // getTableInstance return the dynamic table name
  264. func (a *Adapter) getTableInstance() *CasbinRule {
  265. return &CasbinRule{}
  266. }
  267. func (a *Adapter) getFullTableName() string {
  268. if a.tablePrefix != "" {
  269. return a.tablePrefix + "_" + a.tableName
  270. }
  271. return a.tableName
  272. }
  273. func (a *Adapter) casbinRuleTable() func(db *gorm.DB) *gorm.DB {
  274. return func(db *gorm.DB) *gorm.DB {
  275. tableName := a.getFullTableName()
  276. return db.Table(tableName)
  277. }
  278. }
  279. func (a *Adapter) createTable() error {
  280. t := a.db.Statement.Context.Value(customTableKey{})
  281. if t != nil {
  282. return a.db.AutoMigrate(t)
  283. }
  284. t = a.getTableInstance()
  285. if err := a.db.AutoMigrate(t); err != nil {
  286. return err
  287. }
  288. tableName := a.getFullTableName()
  289. index := "idx_" + tableName
  290. hasIndex := a.db.Migrator().HasIndex(t, index)
  291. if !hasIndex {
  292. if err := a.db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (ptype,v0,v1,v2,v3,v4,v5)", index, tableName)).Error; err != nil {
  293. return err
  294. }
  295. }
  296. return nil
  297. }
  298. func (a *Adapter) dropTable() error {
  299. t := a.db.Statement.Context.Value(customTableKey{})
  300. if t == nil {
  301. return a.db.Migrator().DropTable(a.getTableInstance())
  302. }
  303. return a.db.Migrator().DropTable(t)
  304. }
  305. func loadPolicyLine(line CasbinRule, model model.Model) {
  306. var p = []string{line.Ptype,
  307. line.V0, line.V1, line.V2, line.V3, line.V4, line.V5}
  308. var lineText string
  309. if line.V5 != "" {
  310. lineText = strings.Join(p, ", ")
  311. } else if line.V4 != "" {
  312. lineText = strings.Join(p[:6], ", ")
  313. } else if line.V3 != "" {
  314. lineText = strings.Join(p[:5], ", ")
  315. } else if line.V2 != "" {
  316. lineText = strings.Join(p[:4], ", ")
  317. } else if line.V1 != "" {
  318. lineText = strings.Join(p[:3], ", ")
  319. } else if line.V0 != "" {
  320. lineText = strings.Join(p[:2], ", ")
  321. }
  322. persist.LoadPolicyLine(lineText, model)
  323. }
  324. // LoadPolicy loads policy from database.
  325. func (a *Adapter) LoadPolicy(model model.Model) error {
  326. var lines []CasbinRule
  327. if err := a.db.Order("ID").Find(&lines).Error; err != nil {
  328. return err
  329. }
  330. for _, line := range lines {
  331. loadPolicyLine(line, model)
  332. }
  333. return nil
  334. }
  335. // LoadFilteredPolicy loads only policy rules that match the filter.
  336. func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {
  337. var lines []CasbinRule
  338. filterValue, ok := filter.(Filter)
  339. if !ok {
  340. return errors.New("invalid filter type")
  341. }
  342. if err := a.db.Scopes(a.filterQuery(a.db, filterValue)).Order("ID").Find(&lines).Error; err != nil {
  343. return err
  344. }
  345. for _, line := range lines {
  346. loadPolicyLine(line, model)
  347. }
  348. a.isFiltered = true
  349. return nil
  350. }
  351. // IsFiltered returns true if the loaded policy has been filtered.
  352. func (a *Adapter) IsFiltered() bool {
  353. return a.isFiltered
  354. }
  355. // filterQuery builds the gorm query to match the rule filter to use within a scope.
  356. func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gorm.DB {
  357. return func(db *gorm.DB) *gorm.DB {
  358. if len(filter.PType) > 0 {
  359. db = db.Where("ptype in (?)", filter.PType)
  360. }
  361. if len(filter.V0) > 0 {
  362. db = db.Where("v0 in (?)", filter.V0)
  363. }
  364. if len(filter.V1) > 0 {
  365. db = db.Where("v1 in (?)", filter.V1)
  366. }
  367. if len(filter.V2) > 0 {
  368. db = db.Where("v2 in (?)", filter.V2)
  369. }
  370. if len(filter.V3) > 0 {
  371. db = db.Where("v3 in (?)", filter.V3)
  372. }
  373. if len(filter.V4) > 0 {
  374. db = db.Where("v4 in (?)", filter.V4)
  375. }
  376. if len(filter.V5) > 0 {
  377. db = db.Where("v5 in (?)", filter.V5)
  378. }
  379. return db
  380. }
  381. }
  382. func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule {
  383. line := a.getTableInstance()
  384. line.Ptype = ptype
  385. if len(rule) > 0 {
  386. line.V0 = rule[0]
  387. }
  388. if len(rule) > 1 {
  389. line.V1 = rule[1]
  390. }
  391. if len(rule) > 2 {
  392. line.V2 = rule[2]
  393. }
  394. if len(rule) > 3 {
  395. line.V3 = rule[3]
  396. }
  397. if len(rule) > 4 {
  398. line.V4 = rule[4]
  399. }
  400. if len(rule) > 5 {
  401. line.V5 = rule[5]
  402. }
  403. return *line
  404. }
  405. // SavePolicy saves policy to database.
  406. func (a *Adapter) SavePolicy(model model.Model) error {
  407. err := a.dropTable()
  408. if err != nil {
  409. return err
  410. }
  411. err = a.createTable()
  412. if err != nil {
  413. return err
  414. }
  415. for ptype, ast := range model["p"] {
  416. for _, rule := range ast.Policy {
  417. line := a.savePolicyLine(ptype, rule)
  418. err := a.db.Create(&line).Error
  419. if err != nil {
  420. return err
  421. }
  422. }
  423. }
  424. for ptype, ast := range model["g"] {
  425. for _, rule := range ast.Policy {
  426. line := a.savePolicyLine(ptype, rule)
  427. err := a.db.Create(&line).Error
  428. if err != nil {
  429. return err
  430. }
  431. }
  432. }
  433. return nil
  434. }
  435. // AddPolicy adds a policy rule to the storage.
  436. func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
  437. line := a.savePolicyLine(ptype, rule)
  438. err := a.db.Create(&line).Error
  439. return err
  440. }
  441. // RemovePolicy removes a policy rule from the storage.
  442. func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
  443. line := a.savePolicyLine(ptype, rule)
  444. err := a.rawDelete(a.db, line) //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete
  445. return err
  446. }
  447. // AddPolicies adds multiple policy rules to the storage.
  448. func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error {
  449. return a.db.Transaction(func(tx *gorm.DB) error {
  450. for _, rule := range rules {
  451. line := a.savePolicyLine(ptype, rule)
  452. if err := tx.Create(&line).Error; err != nil {
  453. return err
  454. }
  455. }
  456. return nil
  457. })
  458. }
  459. // RemovePolicies removes multiple policy rules from the storage.
  460. func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) error {
  461. return a.db.Transaction(func(tx *gorm.DB) error {
  462. for _, rule := range rules {
  463. line := a.savePolicyLine(ptype, rule)
  464. if err := a.rawDelete(tx, line); err != nil { //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete
  465. return err
  466. }
  467. }
  468. return nil
  469. })
  470. }
  471. // RemoveFilteredPolicy removes policy rules that match the filter from the storage.
  472. func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
  473. line := a.getTableInstance()
  474. line.Ptype = ptype
  475. if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
  476. line.V0 = fieldValues[0-fieldIndex]
  477. }
  478. if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
  479. line.V1 = fieldValues[1-fieldIndex]
  480. }
  481. if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
  482. line.V2 = fieldValues[2-fieldIndex]
  483. }
  484. if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
  485. line.V3 = fieldValues[3-fieldIndex]
  486. }
  487. if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
  488. line.V4 = fieldValues[4-fieldIndex]
  489. }
  490. if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
  491. line.V5 = fieldValues[5-fieldIndex]
  492. }
  493. err := a.rawDelete(a.db, *line)
  494. return err
  495. }
  496. func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error {
  497. queryArgs := []interface{}{line.Ptype}
  498. queryStr := "ptype = ?"
  499. if line.V0 != "" {
  500. queryStr += " and v0 = ?"
  501. queryArgs = append(queryArgs, line.V0)
  502. }
  503. if line.V1 != "" {
  504. queryStr += " and v1 = ?"
  505. queryArgs = append(queryArgs, line.V1)
  506. }
  507. if line.V2 != "" {
  508. queryStr += " and v2 = ?"
  509. queryArgs = append(queryArgs, line.V2)
  510. }
  511. if line.V3 != "" {
  512. queryStr += " and v3 = ?"
  513. queryArgs = append(queryArgs, line.V3)
  514. }
  515. if line.V4 != "" {
  516. queryStr += " and v4 = ?"
  517. queryArgs = append(queryArgs, line.V4)
  518. }
  519. if line.V5 != "" {
  520. queryStr += " and v5 = ?"
  521. queryArgs = append(queryArgs, line.V5)
  522. }
  523. args := append([]interface{}{queryStr}, queryArgs...)
  524. err := db.Delete(a.getTableInstance(), args...).Error
  525. return err
  526. }
  527. func appendWhere(line CasbinRule) (string, []interface{}) {
  528. queryArgs := []interface{}{line.Ptype}
  529. queryStr := "ptype = ?"
  530. if line.V0 != "" {
  531. queryStr += " and v0 = ?"
  532. queryArgs = append(queryArgs, line.V0)
  533. }
  534. if line.V1 != "" {
  535. queryStr += " and v1 = ?"
  536. queryArgs = append(queryArgs, line.V1)
  537. }
  538. if line.V2 != "" {
  539. queryStr += " and v2 = ?"
  540. queryArgs = append(queryArgs, line.V2)
  541. }
  542. if line.V3 != "" {
  543. queryStr += " and v3 = ?"
  544. queryArgs = append(queryArgs, line.V3)
  545. }
  546. if line.V4 != "" {
  547. queryStr += " and v4 = ?"
  548. queryArgs = append(queryArgs, line.V4)
  549. }
  550. if line.V5 != "" {
  551. queryStr += " and v5 = ?"
  552. queryArgs = append(queryArgs, line.V5)
  553. }
  554. return queryStr, queryArgs
  555. }
  556. // UpdatePolicy updates a new policy rule to DB.
  557. func (a *Adapter) UpdatePolicy(sec string, ptype string, oldRule, newPolicy []string) error {
  558. oldLine := a.savePolicyLine(ptype, oldRule)
  559. newLine := a.savePolicyLine(ptype, newPolicy)
  560. return a.db.Model(&oldLine).Where(&oldLine).Updates(newLine).Error
  561. }
  562. func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) error {
  563. oldPolicies := make([]CasbinRule, 0, len(oldRules))
  564. newPolicies := make([]CasbinRule, 0, len(oldRules))
  565. for _, oldRule := range oldRules {
  566. oldPolicies = append(oldPolicies, a.savePolicyLine(ptype, oldRule))
  567. }
  568. for _, newRule := range newRules {
  569. newPolicies = append(newPolicies, a.savePolicyLine(ptype, newRule))
  570. }
  571. tx := a.db.Begin()
  572. for i := range oldPolicies {
  573. if err := tx.Model(&oldPolicies[i]).Where(&oldPolicies[i]).Updates(newPolicies[i]).Error; err != nil {
  574. tx.Rollback()
  575. return err
  576. }
  577. }
  578. return tx.Commit().Error
  579. }